@@ -203,7 +203,7 @@ def resubmitter(executable_with_state: ExecutableWithState) -> None:
203203 ThreadPoolExecutor (max_workers = max_workers ) as thread_executor ,
204204 ):
205205
206- def submit_task (executable_with_state : ExecutableWithState ) -> None :
206+ def submit_task (executable_with_state : ExecutableWithState ) -> Future :
207207 """Submit task to the thread executor and mark its state as started."""
208208 future = thread_executor .submit (
209209 self ._execute_item_in_child_context ,
@@ -216,14 +216,21 @@ def on_done(future: Future) -> None:
216216 self ._on_task_complete (executable_with_state , future , scheduler )
217217
218218 future .add_done_callback (on_done )
219+ return future
219220
220221 # Submit initial tasks
221- for exe_state in self .executables_with_state :
222- submit_task (exe_state )
222+ futures = [
223+ submit_task (exe_state ) for exe_state in self .executables_with_state
224+ ]
223225
224226 # Wait for completion
225227 self ._completion_event .wait ()
226228
229+ # Cancel remaining futures so
230+ # that we don't wait for them to join.
231+ for future in futures :
232+ future .cancel ()
233+
227234 # Suspend execution if everything done and at least one of the tasks raised a suspend exception.
228235 if self ._suspend_exception :
229236 raise self ._suspend_exception
@@ -275,6 +282,11 @@ def _on_task_complete(
275282 scheduler : TimerScheduler ,
276283 ) -> None :
277284 """Handle task completion, suspension, or failure."""
285+
286+ if future .cancelled ():
287+ exe_state .suspend ()
288+ return
289+
278290 try :
279291 result = future .result ()
280292 exe_state .complete (result )
0 commit comments