diff --git a/unblob/pool.py b/unblob/pool.py index 810011a209..eeee58ca7a 100644 --- a/unblob/pool.py +++ b/unblob/pool.py @@ -1,11 +1,13 @@ import abc +import contextlib import multiprocessing as mp import os import queue +import signal import sys import threading from multiprocessing.queues import JoinableQueue -from typing import Any, Callable, Union +from typing import Any, Callable, Set, Union from .logging import multiprocessing_breakpoint @@ -13,6 +15,10 @@ class PoolBase(abc.ABC): + def __init__(self): + with pools_lock: + pools.add(self) + @abc.abstractmethod def submit(self, args): pass @@ -24,15 +30,20 @@ def process_until_done(self): def start(self): pass - def close(self): - pass + def close(self, *, immediate=False): # noqa: ARG002 + with pools_lock: + pools.remove(self) def __enter__(self): self.start() return self - def __exit__(self, *args): - self.close() + def __exit__(self, exc_type, _exc_value, _tb): + self.close(immediate=exc_type is not None) + + +pools_lock = threading.Lock() +pools: Set[PoolBase] = set() class Queue(JoinableQueue): @@ -53,9 +64,15 @@ class _Sentinel: def _worker_process(handler, input_, output): - # Creates a new process group, making sure no signals are propagated from the main process to the worker processes. + # Creates a new process group, making sure no signals are + # propagated from the main process to the worker processes. os.setpgrp() + # Restore default signal handlers, otherwise workers would inherit + # them from main process + signal.signal(signal.SIGTERM, signal.SIG_DFL) + signal.signal(signal.SIGINT, signal.SIG_DFL) + sys.breakpointhook = multiprocessing_breakpoint while (args := input_.get()) is not _SENTINEL: result = handler(args) @@ -71,11 +88,14 @@ def __init__( *, result_callback: Callable[["MultiPool", Any], Any], ): + super().__init__() if process_num <= 0: raise ValueError("At process_num must be greater than 0") + self._running = False self._result_callback = result_callback self._input = Queue(ctx=mp.get_context()) + self._input.cancel_join_thread() self._output = mp.SimpleQueue() self._procs = [ mp.Process( @@ -87,14 +107,32 @@ def __init__( self._tid = threading.get_native_id() def start(self): + self._running = True for p in self._procs: p.start() - def close(self): - self._clear_input_queue() - self._request_workers_to_quit() - self._clear_output_queue() + def close(self, *, immediate=False): + if not self._running: + return + self._running = False + + if immediate: + self._terminate_workers() + else: + self._clear_input_queue() + self._request_workers_to_quit() + self._clear_output_queue() + self._wait_for_workers_to_quit() + super().close(immediate=immediate) + + def _terminate_workers(self): + for proc in self._procs: + proc.terminate() + + self._input.close() + if sys.version_info >= (3, 9): + self._output.close() def _clear_input_queue(self): try: @@ -129,14 +167,16 @@ def submit(self, args): self._input.put(args) def process_until_done(self): - while not self._input.is_empty(): - result = self._output.get() - self._result_callback(self, result) - self._input.task_done() + with contextlib.suppress(EOFError): + while not self._input.is_empty(): + result = self._output.get() + self._result_callback(self, result) + self._input.task_done() class SinglePool(PoolBase): def __init__(self, handler, *, result_callback): + super().__init__() self._handler = handler self._result_callback = result_callback @@ -157,3 +197,20 @@ def make_pool(process_num, handler, result_callback) -> Union[SinglePool, MultiP handler=handler, result_callback=result_callback, ) + + +orig_signal_handlers = {} + + +def _on_terminate(signum, frame): + with contextlib.suppress(StopIteration): + while True: + pool = next(iter(pools)) + pool.close(immediate=True) + + if callable(orig_signal_handlers[signum]): + orig_signal_handlers[signum](signum, frame) + + +orig_signal_handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, _on_terminate) +orig_signal_handlers[signal.SIGINT] = signal.signal(signal.SIGINT, _on_terminate) diff --git a/unblob/processing.py b/unblob/processing.py index 393807842d..b4b0961860 100644 --- a/unblob/processing.py +++ b/unblob/processing.py @@ -44,7 +44,6 @@ StatReport, UnknownError, ) -from .signals import terminate_gracefully from .ui import NullProgressReporter, ProgressReporter logger = get_logger() diff --git a/unblob/signals.py b/unblob/signals.py deleted file mode 100644 index 76b70a4dbe..0000000000 --- a/unblob/signals.py +++ /dev/null @@ -1,51 +0,0 @@ -import functools -import signal - -from structlog import get_logger - -logger = get_logger() - - -class ShutDownRequired(BaseException): - def __init__(self, signal: str): - super().__init__() - self.signal = signal - - -def terminate_gracefully(func): - @functools.wraps(func) - def decorator(*args, **kwargs): - signals_fired = [] - - def _handle_signal(signum: int, frame): - nonlocal signals_fired - signals_fired.append((signum, frame)) - raise ShutDownRequired(signal=signal.Signals(signum).name) - - original_signal_handlers = { - signal.SIGINT: signal.signal(signal.SIGINT, _handle_signal), - signal.SIGTERM: signal.signal(signal.SIGTERM, _handle_signal), - } - - logger.debug( - "Setting up signal handlers", - original_signal_handlers=original_signal_handlers, - _verbosity=2, - ) - - try: - return func(*args, **kwargs) - except ShutDownRequired as exc: - logger.warning("Shutting down", signal=exc.signal) - finally: - # Set back the original signal handlers - for sig, handler in original_signal_handlers.items(): - signal.signal(sig, handler) - - # Call the original signal handler with the fired and catched signal(s) - for sig, frame in signals_fired: - handler = original_signal_handlers.get(sig) - if callable(handler): - handler(sig, frame) - - return decorator