-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CudaGraph] Handle exceptions thrown while capturing cuda graph (#17113)
* [CudaGraph] Handle exceptions thrown while capturing cuda graph Prior to this commit, an exception thrown during the capture of a cuda graph would result in `std::terminate` being called. This commit updates the implementation of `"vm.builtin.cuda_graph.run_or_capture"` such that a thrown exception can be recovered from, and does not cause any changes to the state of TVM's cuda graph cache. - Call to `cudaStreamDestroy` was previously skipped, now moved to a RAII-style destructor in a `ScopedCUDAStream` class. - Call to `cudaStreamEndCapture` was previously skipped, end of cuda graph capture now performed as part of RAII-style destructor for `CUDACaptureStream` class. - Restoration of `CUDAThreadEntry::ThreadLocal()->stream` was previously skipped, now restored as part of RAII-style destructor for `CUDACaptureStream` class. - Previously, an error raised from `cudaGraphInstantiate` would leave the `capture_cache_` in an ill-formed state. Now, the `capture_cache_` is only updated after a valid `CUDAGraphCapturedState` has been fully constructed. * lint fix * Unit test fix
- Loading branch information
1 parent
3c6ca5d
commit a84adaf
Showing
2 changed files
with
140 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters