11import asyncio
22import threading
33from datetime import timedelta
4- from typing import Optional , TypeVar
4+ from typing import Callable , Optional , TypeVar
55from unittest .mock import Mock
66
7+ import torch
78from torch .futures import Future
89
910T = TypeVar ("T" )
@@ -17,7 +18,6 @@ def __init__(self) -> None:
1718
def set_timer(self, timer_handle: asyncio.TimerHandle) -> None:
    """
    Store the scheduled timer on this handle and release the handle's lock.

    The lock must already be held when this is called; the assertion below
    enforces that precondition.

    Args:
        timer_handle: The asyncio timer handle to remember on this object.
    """
    lock = self._lock
    # The lock is expected to be held until the timer has been recorded.
    assert lock.locked()
    self._timer_handle = timer_handle
    # Publish the timer: anything blocked on the lock may proceed from here.
    lock.release()
2323
@@ -99,6 +99,18 @@ def callback(fut: Future[T]) -> None:
9999 fut .add_done_callback (callback )
100100 return timed_fut
101101
def stream_timeout(self, callback: Callable[[], None], timeout: timedelta) -> None:
    """
    Invoke ``callback`` if the current CUDA stream has not finished the work
    queued so far once ``timeout`` has elapsed.

    A CUDA event is recorded immediately; when the timer fires the event is
    queried, and ``callback`` runs only if the event has not completed.

    Args:
        callback: Zero-argument callable to run when the stream is late.
        timeout: How long to wait before checking the recorded event.
    """
    event_loop = self._maybe_start_event_loop()

    # Marks the point in the stream that must be reached before the deadline.
    marker = torch.cuda.Event()
    marker.record()

    def on_deadline() -> None:
        # Work up to the marker already finished: nothing to report.
        if marker.query():
            return
        callback()

    # Hand the scheduling over to the loop in a thread-safe way.
    event_loop.call_soon_threadsafe(
        self._register_handler, event_loop, on_deadline, timeout
    )
113+
102114 @classmethod
103115 def _register (
104116 cls ,
@@ -116,6 +128,18 @@ def _register(
116128 )
117129 handle .set_timer (timer_handle )
118130
@classmethod
def _register_handler(
    cls,
    loop,
    handler: Callable[[], None],
    timeout: timedelta,
) -> None:
    """
    Schedule ``handler`` to run on ``loop`` after ``timeout`` has elapsed.

    ``stream_timeout`` dispatches this method through
    ``loop.call_soon_threadsafe`` rather than calling it directly.

    Args:
        loop: Object providing ``call_later`` (presumably an asyncio event
            loop -- confirm against the caller).
        handler: Zero-argument callable to invoke when the delay expires.
        timeout: Delay before ``handler`` is invoked.
    """
    delay_seconds = timeout.total_seconds()
    # The handle returned by call_later is intentionally not kept.
    loop.call_later(delay_seconds, handler)
142+
119143
# Shared module-level manager instance; the module-level helper functions
# (e.g. ``stream_timeout``) delegate to it.
_TIMEOUT_MANAGER = _TimeoutManager()
121145
@@ -163,3 +187,18 @@ def callback(fut: Future[T]) -> T:
163187 raise TimeoutError (f"future did not complete within { timeout } " )
164188
165189 return fut .wait ()
190+
191+
def stream_timeout(callback: Callable[[], None], timeout: timedelta) -> None:
    """
    Arm a timeout for the work queued on the current stream.

    A cuda Event is used to track completion of the current stream. If the
    stream has not completed once the timeout has elapsed, the callback is
    called; otherwise nothing happens.

    Args:
        callback: Called with no arguments if the stream doesn't complete in
            time.
        timeout: How long to wait for the stream to complete.
    """
    manager = _TIMEOUT_MANAGER
    manager.stream_timeout(callback, timeout)
0 commit comments