Skip to content

Commit

Permalink
[CudaGraph] Handle exceptions thrown while capturing cuda graph (#17113)
Browse files Browse the repository at this point in the history
* [CudaGraph] Handle exceptions thrown while capturing cuda graph

Prior to this commit, an exception thrown during the capture of a CUDA
graph would result in `std::terminate` being called.  This commit
updates the implementation of `"vm.builtin.cuda_graph.run_or_capture"`
such that a thrown exception can be recovered from and does not cause
any changes to the state of TVM's CUDA graph cache.

- Call to `cudaStreamDestroy` was previously skipped, now moved to a
  RAII-style destructor in a `ScopedCUDAStream` class.

- Call to `cudaStreamEndCapture` was previously skipped, end of cuda
  graph capture now performed as part of RAII-style destructor for
  `CUDACaptureStream` class.

- Restoration of `CUDAThreadEntry::ThreadLocal()->stream` was
  previously skipped, now restored as part of RAII-style destructor
  for `CUDACaptureStream` class.

- Previously, an error raised from `cudaGraphInstantiate` would leave
  the `capture_cache_` in an ill-formed state.  Now, the
  `capture_cache_` is only updated after a valid
  `CUDAGraphCapturedState` has been fully constructed.

* lint fix

* Unit test fix
  • Loading branch information
Lunderberg committed Jun 27, 2024
1 parent 3c6ca5d commit a84adaf
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 18 deletions.
81 changes: 65 additions & 16 deletions src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ namespace tvm {
namespace runtime {
namespace relax_vm {

namespace {

struct CUDAGraphCaptureKey {
// The unique index of the capture function within the module
int64_t index;
Expand Down Expand Up @@ -67,6 +69,18 @@ struct CUDAGraphCaptureKeyEqual {

/*! \brief The captured state of a CUDA graph */
struct CUDAGraphCapturedState {
CUDAGraphCapturedState() {}

CUDAGraphCapturedState(const CUDAGraphCapturedState&) = delete;
CUDAGraphCapturedState(CUDAGraphCapturedState&& other) { *this = std::move(other); }

CUDAGraphCapturedState& operator=(const CUDAGraphCapturedState&) = delete;
CUDAGraphCapturedState& operator=(CUDAGraphCapturedState&& other) {
std::swap(states, other.states);
std::swap(exec, other.exec);
return *this;
}

~CUDAGraphCapturedState() {
if (exec) {
CUDA_CALL(cudaGraphExecDestroy(exec));
Expand All @@ -82,6 +96,43 @@ struct CUDAGraphCapturedState {
cudaGraphExec_t exec = nullptr;
};

class ScopedCUDAStream {
public:
ScopedCUDAStream() { CUDA_CALL(cudaStreamCreate(&stream_)); }
~ScopedCUDAStream() { cudaStreamDestroy(stream_); }
ScopedCUDAStream(const ScopedCUDAStream&) = delete;
ScopedCUDAStream(ScopedCUDAStream&&) = delete;
ScopedCUDAStream& operator=(const ScopedCUDAStream&) = delete;
ScopedCUDAStream& operator=(ScopedCUDAStream&&) = delete;

operator cudaStream_t() const { return stream_; }

private:
cudaStream_t stream_;
};

/*! \brief RAII guard that captures a CUDA graph on a private stream.
 *
 * On construction, swaps TVM's thread-local default stream for a
 * freshly created capture stream and begins cudaGraph capture on it.
 * On destruction, ends the capture (writing the captured graph into
 * the pointer supplied by the caller) and restores the previous
 * thread-local stream.  Because both steps live in the destructor,
 * they run even when an exception is thrown mid-capture — the fix
 * this commit introduces.
 */
class CUDACaptureStream {
 public:
  /*!
   * \param graph Output location for the captured graph.  Written by
   *        the destructor; only valid afterwards if capture succeeded.
   */
  explicit CUDACaptureStream(cudaGraph_t* graph)
      : prev_default_stream_(CUDAThreadEntry::ThreadLocal()->stream), output_graph_(graph) {
    // Redirect TVM's thread-local stream so all launched work is
    // recorded into the capture stream.
    CUDAThreadEntry::ThreadLocal()->stream = capture_stream_;

    CUDA_CALL(cudaStreamBeginCapture(capture_stream_, cudaStreamCaptureModeGlobal));
  }
  ~CUDACaptureStream() {
    // Deliberately not CUDA_CALL-wrapped: this destructor may run
    // during stack unwinding, and throwing here would terminate.
    cudaStreamEndCapture(capture_stream_, output_graph_);
    CUDAThreadEntry::ThreadLocal()->stream = prev_default_stream_;
  }

 private:
  cudaStream_t prev_default_stream_;   // stream to restore on exit
  ScopedCUDAStream capture_stream_;    // owned stream used for capture

  cudaGraph_t* output_graph_;          // caller-owned output slot
};

} // namespace

/*! \brief The VM extension of CUDA graph. */
class CUDAGraphExtensionNode : public VMExtensionNode {
public:
Expand All @@ -107,10 +158,6 @@ class CUDAGraphExtensionNode : public VMExtensionNode {
return states;
}

cudaStream_t capture_stream;
CUDA_CALL(cudaStreamCreate(&capture_stream));
CUDAGraphCapturedState entry;

// Set up arguments for the graph execution
Array<ObjectRef> tuple_args = Downcast<Array<ObjectRef>>(args);
int nargs = static_cast<int>(tuple_args.size());
Expand All @@ -130,21 +177,23 @@ class CUDAGraphExtensionNode : public VMExtensionNode {

// Run the graph in capture mode
cudaGraph_t graph;
std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream);
CUDA_CALL(cudaStreamBeginCapture(CUDAThreadEntry::ThreadLocal()->stream,
cudaStreamCaptureModeGlobal));

vm->InvokeClosurePacked(capture_func, TVMArgs(values.data(), tcodes.data(), nargs),
&capture_func_rv);
entry.states = capture_func_rv;
CUDA_CALL(cudaStreamEndCapture(CUDAThreadEntry::ThreadLocal()->stream, &graph));
std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream);
{
CUDACaptureStream capture_stream(&graph);
vm->InvokeClosurePacked(capture_func, TVMArgs(values.data(), tcodes.data(), nargs),
&capture_func_rv);
}

capture_cache_[entry_key] = entry;
CUDA_CALL(cudaGraphInstantiate(&capture_cache_[entry_key].exec, graph, NULL, NULL, 0));
CUDA_CALL(cudaStreamDestroy(capture_stream));
CUDAGraphCapturedState entry;
entry.states = capture_func_rv;
CUDA_CALL(cudaGraphInstantiate(&entry.exec, graph, NULL, NULL, 0));
CUDA_CALL(cudaGraphDestroy(graph));
return entry.states;

ObjectRef states = entry.states;

capture_cache_[entry_key] = std::move(entry);

return states;
}

/*!
Expand Down
77 changes: 75 additions & 2 deletions tests/python/relax/test_vm_cuda_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
# under the License.

import tvm
from tvm.script import tir as T, relax as R, ir as I
from tvm import relax
import tvm.testing

from tvm import relax
from tvm.script import tir as T, relax as R, ir as I

import numpy as np
import pytest


# fmt: off
Expand Down Expand Up @@ -104,5 +107,75 @@ def test_vm_run():
tvm.testing.assert_allclose(y.asnumpy(), y_np, rtol=1e-5, atol=1e-5)


@tvm.testing.requires_cudagraph
def test_capture_error_is_recoverable():
    """Function calls while capturing a cudaGraph may throw exceptions.

    Calls to PackedFuncs may occur within a captured cudaGraph.  If a
    call to that PackedFunc raises an exception while capturing the
    cudaGraph, throwing the exception should cleanly unwind the stack,
    and the exception may be caught in the calling scope.

    This is a regression test.  In previous implementations, an
    exception thrown while capturing a cudaGraph would skip the call
    to `cudaStreamEndCapture`, causing additional exceptions to be
    thrown while freeing memory in TVM destructors.  Since C++ does
    not support stack unwinding from multiple simultaneous exceptions,
    this would result in immediate `std::terminate`, making it
    difficult to debug the original error.
    """

    target = tvm.target.Target("cuda")
    dev = tvm.cuda()

    @tvm.register_func("test_vm_cuda_graph.invalid_impl_for_cudagraph", override=True)
    def invalid_impl_for_cudagraph(arg_tensor):
        # Memory allocation/deallocation may not be performed while
        # capturing a cudaGraph.  This passes the warm-up run
        # performed by "vm.builtin.cuda_graph.run_or_capture", but
        # throws an exception when the cudaGraph is being captured.
        _dummy_workspace = tvm.nd.empty([16], "float16", dev)
        return arg_tensor

    @I.ir_module
    class Module:
        @R.function
        def main(A: R.Tensor([16], "float16")):
            B = R.add(A, A)
            # This packed call is what RewriteCUDAGraph must capture;
            # its implementation above fails only during capture.
            C = R.call_pure_packed(
                "test_vm_cuda_graph.invalid_impl_for_cudagraph",
                B,
                sinfo_args=R.Tensor([16], "float16"),
            )
            D = R.add(C, C)
            return D

    # Lower just far enough that RewriteCUDAGraph can wrap the packed
    # call in a cudaGraph capture region.
    with target, tvm.ir.transform.PassContext(config={"relax.backend.use_cuda_graph": True}):
        Module = tvm.ir.transform.Sequential(
            [
                tvm.relax.transform.LegalizeOps(),
                tvm.tir.transform.DefaultGPUSchedule(),
                tvm.relax.transform.RemovePurityChecking(),
                tvm.relax.transform.CallTIRRewrite(),
                tvm.relax.transform.StaticPlanBlockMemory(),
                tvm.relax.transform.RewriteCUDAGraph(),
            ]
        )(Module)

    # Guard against the rewrite silently not firing, which would make
    # the test pass without exercising the capture path.
    assert "cuda_graph_alloc" in Module, (
        "Validity of unit test requires the call to `invalid_impl_for_cudagraph` "
        "to have been captured by RewriteCUDAGraph."
    )

    built = tvm.relax.build(Module, target=target)
    vm = tvm.relax.VirtualMachine(built, dev)

    arg = tvm.nd.array(np.arange(16).astype("float16"), dev)

    # The error raised during capture must surface as a catchable
    # TVMError rather than aborting the process.
    with pytest.raises(tvm.TVMError):
        vm["main"](arg)


if __name__ == "__main__":
    # Run all tests in this file when executed directly.
    tvm.testing.main()

0 comments on commit a84adaf

Please sign in to comment.