|
| 1 | +import contextlib |
| 2 | +import faulthandler |
1 | 3 | import logging |
| 4 | +import os |
2 | 5 | import queue |
| 6 | +import sys |
3 | 7 | import threading |
4 | 8 | import time |
5 | 9 | from datetime import datetime |
|
12 | 16 | from ..common.callback_registrar import CallbackRegistrar |
13 | 17 | from .task_queue import TaskQueueReader, TaskQueueWriter |
14 | 18 |
|
| 19 | +SHUTDOWN_TIME = 60.0 # Wait up to 60 seconds for workers to finish their tasks |
| 20 | + |
15 | 21 |
|
16 | 22 | class TaskDistributor: |
17 | 23 | """ |
@@ -123,17 +129,40 @@ def stop(self) -> None: |
123 | 129 |
|
124 | 130 | # Signalize stop to worker threads |
125 | 131 | self.running = False |
| 132 | + deadline_ts = time.monotonic() + SHUTDOWN_TIME |
126 | 133 |
|
127 | | - # Wait until all workers stopped |
| 134 | + # Wait until all workers stopped (bounded) |
128 | 135 | for worker in self._worker_threads: |
129 | | - worker.join() |
| 136 | + remaining = max(0.0, deadline_ts - time.monotonic()) |
| 137 | + try: |
| 138 | + worker.join(timeout=remaining) |
| 139 | + except Exception as e: |
| 140 | + self.log.exception( |
| 141 | + "join() failed for %s: %s", getattr(worker, "name", "<unnamed>"), str(e) |
| 142 | + ) |
130 | 143 |
|
131 | 144 | self._task_queue_reader.disconnect() |
132 | 145 | self._task_queue_writer.disconnect() |
133 | 146 |
|
| 147 | + alive = [w for w in self._worker_threads if w.is_alive()] |
| 148 | + if alive: |
| 149 | + self._dump_thread_stacks() |
| 150 | + self.log.critical("Forcing shutdown with %d workers still alive", len(alive)) |
| 151 | + with contextlib.suppress(Exception): |
| 152 | + logging.shutdown() # flush logs |
| 153 | + os._exit(1) # nuke entire process |
| 154 | + |
134 | 155 | # Cleanup |
135 | 156 | self._worker_threads = [] |
136 | 157 |
|
| 158 | + def _dump_thread_stacks(self) -> None: |
| 159 | + self.log.critical("=== Graceful shutdown failed, thread stack dump follows ===") |
| 160 | + for worker in self._worker_threads: |
| 161 | + if not worker.is_alive(): |
| 162 | + continue |
| 163 | + self.log.error("Thread %s (ident=0x%x) still alive", worker.name, worker.ident) |
| 164 | + faulthandler.dump_traceback(file=sys.stderr, all_threads=True) |
| 165 | + |
137 | 166 | def _distribute_task(self, msg_id, task: DataPointTask): |
138 | 167 | """ |
139 | 168 | Puts given task into local queue of the corresponding thread. |
|
0 commit comments