Skip to content

Commit 3aa2d7d

Browse files
committed
TaskDistributor: Ensure process shutdown
1 parent 4846440 commit 3aa2d7d

File tree

1 file changed

+31
-2
lines changed

1 file changed

+31
-2
lines changed

dp3/task_processing/task_distributor.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
import contextlib
2+
import faulthandler
13
import logging
4+
import os
25
import queue
6+
import sys
37
import threading
48
import time
59
from datetime import datetime
@@ -12,6 +16,8 @@
1216
from ..common.callback_registrar import CallbackRegistrar
1317
from .task_queue import TaskQueueReader, TaskQueueWriter
1418

19+
# Total time budget, in seconds, shared by all worker joins in TaskDistributor.stop():
# a single deadline is computed once and each join() only gets the time still remaining.
# Workers that are still alive after that are not waited for; the process is force-exited.
SHUTDOWN_TIME = 60.0 # Wait up to 60 seconds for workers to finish their tasks
20+
1521

1622
class TaskDistributor:
1723
"""
@@ -123,17 +129,40 @@ def stop(self) -> None:
123129

124130
# Signalize stop to worker threads
125131
self.running = False
132+
deadline_ts = time.monotonic() + SHUTDOWN_TIME
126133

127-
# Wait until all workers stopped
134+
# Wait until all workers stopped (bounded)
128135
for worker in self._worker_threads:
129-
worker.join()
136+
remaining = max(0.0, deadline_ts - time.monotonic())
137+
try:
138+
worker.join(timeout=remaining)
139+
except Exception as e:
140+
self.log.exception(
141+
"join() failed for %s: %s", getattr(worker, "name", "<unnamed>"), str(e)
142+
)
130143

131144
self._task_queue_reader.disconnect()
132145
self._task_queue_writer.disconnect()
133146

147+
alive = [w for w in self._worker_threads if w.is_alive()]
148+
if alive:
149+
self._dump_thread_stacks()
150+
self.log.critical("Forcing shutdown with %d workers still alive", len(alive))
151+
with contextlib.suppress(Exception):
152+
logging.shutdown() # flush logs
153+
os._exit(1) # nuke entire process
154+
134155
# Cleanup
135156
self._worker_threads = []
136157

158+
def _dump_thread_stacks(self) -> None:
    """
    Log every worker thread that is still alive and dump all thread stacks to stderr.

    This is a diagnostic aid for the forced-shutdown path and is strictly best-effort:
    if the stack dump cannot be written, the failure is logged and swallowed, so the
    caller can still proceed to terminate the process.
    """
    self.log.critical("=== Graceful shutdown failed, thread stack dump follows ===")
    for worker in self._worker_threads:
        if worker.is_alive():
            self.log.error("Thread %s (ident=0x%x) still alive", worker.name, worker.ident)

    # A single dump with all_threads=True already covers every thread,
    # so it is issued once, after the per-worker log lines.
    # faulthandler writes through a raw file descriptor: it raises when sys.stderr is None
    # or is an in-memory replacement without a usable fileno(). That must not abort shutdown.
    try:
        faulthandler.dump_traceback(file=sys.stderr, all_threads=True)
    except (AttributeError, OSError, RuntimeError, ValueError):
        self.log.exception("Unable to dump thread stacks to stderr")
165+
137166
def _distribute_task(self, msg_id, task: DataPointTask):
138167
"""
139168
Puts given task into local queue of the corresponding thread.

0 commit comments

Comments
 (0)