Skip to content

Commit

Permalink
refact(pool): explicitly terminate workers on interrupt
Browse files Browse the repository at this point in the history
Instead of juggling with signal handlers and hoping that
`ShutDownRequired` will be fired in the appropriate place in
`multiprocessing.BasePool`, on exceptional termination, we signal
workers via `SIGTERM`.

As a side-effect this makes it possible to run `process_file` in
non-main thread.
  • Loading branch information
vlaci committed Dec 3, 2024
1 parent 6930324 commit 32bc017
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 66 deletions.
85 changes: 71 additions & 14 deletions unblob/pool.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
import abc
import contextlib
import multiprocessing as mp
import os
import queue
import signal
import sys
import threading
from multiprocessing.queues import JoinableQueue
from typing import Any, Callable, Union
from typing import Any, Callable, Set, Union

from .logging import multiprocessing_breakpoint

mp.set_start_method("fork")


class PoolBase(abc.ABC):
def __init__(self):
with pools_lock:
pools.add(self)

@abc.abstractmethod
def submit(self, args):
pass
Expand All @@ -24,15 +30,20 @@ def process_until_done(self):
def start(self):
pass

def close(self):
pass
def close(self, *, immediate=False): # noqa: ARG002
with pools_lock:
pools.remove(self)

def __enter__(self):
self.start()
return self

def __exit__(self, *args):
self.close()
def __exit__(self, exc_type, _exc_value, _tb):
self.close(immediate=exc_type is not None)


pools_lock = threading.Lock()
pools: Set[PoolBase] = set()


class Queue(JoinableQueue):
Expand All @@ -53,9 +64,15 @@ class _Sentinel:


def _worker_process(handler, input_, output):
# Creates a new process group, making sure no signals are propagated from the main process to the worker processes.
# Creates a new process group, making sure no signals are
# propagated from the main process to the worker processes.
os.setpgrp()

# Restore default signal handlers, otherwise workers would inherit
# them from main process
signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGINT, signal.SIG_DFL)

sys.breakpointhook = multiprocessing_breakpoint
while (args := input_.get()) is not _SENTINEL:
result = handler(args)
Expand All @@ -71,11 +88,14 @@ def __init__(
*,
result_callback: Callable[["MultiPool", Any], Any],
):
super().__init__()
if process_num <= 0:
raise ValueError("At process_num must be greater than 0")

self._running = False
self._result_callback = result_callback
self._input = Queue(ctx=mp.get_context())
self._input.cancel_join_thread()
self._output = mp.SimpleQueue()
self._procs = [
mp.Process(
Expand All @@ -87,14 +107,32 @@ def __init__(
self._tid = threading.get_native_id()

def start(self):
self._running = True
for p in self._procs:
p.start()

def close(self):
self._clear_input_queue()
self._request_workers_to_quit()
self._clear_output_queue()
def close(self, *, immediate=False):
if not self._running:
return
self._running = False

if immediate:
self._terminate_workers()
else:
self._clear_input_queue()
self._request_workers_to_quit()
self._clear_output_queue()

self._wait_for_workers_to_quit()
super().close(immediate=immediate)

def _terminate_workers(self):
for proc in self._procs:
proc.terminate()

self._input.close()
if sys.version_info >= (3, 9):
self._output.close()

def _clear_input_queue(self):
try:
Expand Down Expand Up @@ -129,14 +167,16 @@ def submit(self, args):
self._input.put(args)

def process_until_done(self):
while not self._input.is_empty():
result = self._output.get()
self._result_callback(self, result)
self._input.task_done()
with contextlib.suppress(EOFError):
while not self._input.is_empty():
result = self._output.get()
self._result_callback(self, result)
self._input.task_done()


class SinglePool(PoolBase):
def __init__(self, handler, *, result_callback):
super().__init__()
self._handler = handler
self._result_callback = result_callback

Expand All @@ -157,3 +197,20 @@ def make_pool(process_num, handler, result_callback) -> Union[SinglePool, MultiP
handler=handler,
result_callback=result_callback,
)


orig_signal_handlers = {}


def _on_terminate(signum, frame):
with contextlib.suppress(StopIteration):
while True:
pool = next(iter(pools))
pool.close(immediate=True)

if callable(orig_signal_handlers[signum]):
orig_signal_handlers[signum](signum, frame)


orig_signal_handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, _on_terminate)
orig_signal_handlers[signal.SIGINT] = signal.signal(signal.SIGINT, _on_terminate)
1 change: 0 additions & 1 deletion unblob/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
StatReport,
UnknownError,
)
from .signals import terminate_gracefully
from .ui import NullProgressReporter, ProgressReporter

logger = get_logger()
Expand Down
51 changes: 0 additions & 51 deletions unblob/signals.py

This file was deleted.

0 comments on commit 32bc017

Please sign in to comment.