Skip to content

Commit fcc4ed7

Browse files
committed
refact(pool): explicitly terminate workers on interrupt
Instead of juggling signal handlers and hoping that `ShutDownRequired` will be fired in the appropriate place in `multiprocessing.BasePool`, we now explicitly signal the workers via `SIGTERM` on exceptional termination. As a side effect, this makes it possible to run `process_file` in a non-main thread.
1 parent c8bbf96 commit fcc4ed7

File tree

3 files changed

+71
-66
lines changed

3 files changed

+71
-66
lines changed

unblob/pool.py

+71-14
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
import abc
2+
import contextlib
23
import multiprocessing as mp
34
import os
45
import queue
6+
import signal
57
import sys
68
import threading
79
from multiprocessing.queues import JoinableQueue
8-
from typing import Any, Callable, Union
10+
from typing import Any, Callable, Set, Union
911

1012
from .logging import multiprocessing_breakpoint
1113

1214
mp.set_start_method("fork")
1315

1416

1517
class PoolBase(abc.ABC):
18+
def __init__(self):
19+
with pools_lock:
20+
pools.add(self)
21+
1622
@abc.abstractmethod
1723
def submit(self, args):
1824
pass
@@ -24,15 +30,20 @@ def process_until_done(self):
2430
def start(self):
2531
pass
2632

27-
def close(self):
28-
pass
33+
def close(self, *, immediate=False): # noqa: ARG002
34+
with pools_lock:
35+
pools.remove(self)
2936

3037
def __enter__(self):
3138
self.start()
3239
return self
3340

34-
def __exit__(self, *args):
35-
self.close()
41+
def __exit__(self, exc_type, _exc_value, _tb):
42+
self.close(immediate=exc_type is not None)
43+
44+
45+
pools_lock = threading.Lock()
46+
pools: Set[PoolBase] = set()
3647

3748

3849
class Queue(JoinableQueue):
@@ -53,9 +64,15 @@ class _Sentinel:
5364

5465

5566
def _worker_process(handler, input_, output):
56-
# Creates a new process group, making sure no signals are propagated from the main process to the worker processes.
67+
# Creates a new process group, making sure no signals are
68+
# propagated from the main process to the worker processes.
5769
os.setpgrp()
5870

71+
# Restore default signal handlers, otherwise workers would inherit
72+
# them from main process
73+
signal.signal(signal.SIGTERM, signal.SIG_DFL)
74+
signal.signal(signal.SIGINT, signal.SIG_DFL)
75+
5976
sys.breakpointhook = multiprocessing_breakpoint
6077
while (args := input_.get()) is not _SENTINEL:
6178
result = handler(args)
@@ -71,11 +88,14 @@ def __init__(
7188
*,
7289
result_callback: Callable[["MultiPool", Any], Any],
7390
):
91+
super().__init__()
7492
if process_num <= 0:
7593
raise ValueError("At process_num must be greater than 0")
7694

95+
self._running = False
7796
self._result_callback = result_callback
7897
self._input = Queue(ctx=mp.get_context())
98+
self._input.cancel_join_thread()
7999
self._output = mp.SimpleQueue()
80100
self._procs = [
81101
mp.Process(
@@ -87,14 +107,32 @@ def __init__(
87107
self._tid = threading.get_native_id()
88108

89109
def start(self):
110+
self._running = True
90111
for p in self._procs:
91112
p.start()
92113

93-
def close(self):
94-
self._clear_input_queue()
95-
self._request_workers_to_quit()
96-
self._clear_output_queue()
114+
def close(self, *, immediate=False):
115+
if not self._running:
116+
return
117+
self._running = False
118+
119+
if immediate:
120+
self._terminate_workers()
121+
else:
122+
self._clear_input_queue()
123+
self._request_workers_to_quit()
124+
self._clear_output_queue()
125+
97126
self._wait_for_workers_to_quit()
127+
super().close(immediate=immediate)
128+
129+
def _terminate_workers(self):
130+
for proc in self._procs:
131+
proc.terminate()
132+
133+
self._input.close()
134+
if sys.version_info >= (3, 9):
135+
self._output.close()
98136

99137
def _clear_input_queue(self):
100138
try:
@@ -129,14 +167,16 @@ def submit(self, args):
129167
self._input.put(args)
130168

131169
def process_until_done(self):
132-
while not self._input.is_empty():
133-
result = self._output.get()
134-
self._result_callback(self, result)
135-
self._input.task_done()
170+
with contextlib.suppress(EOFError):
171+
while not self._input.is_empty():
172+
result = self._output.get()
173+
self._result_callback(self, result)
174+
self._input.task_done()
136175

137176

138177
class SinglePool(PoolBase):
139178
def __init__(self, handler, *, result_callback):
179+
super().__init__()
140180
self._handler = handler
141181
self._result_callback = result_callback
142182

@@ -157,3 +197,20 @@ def make_pool(process_num, handler, result_callback) -> Union[SinglePool, MultiP
157197
handler=handler,
158198
result_callback=result_callback,
159199
)
200+
201+
202+
orig_signal_handlers = {}
203+
204+
205+
def _on_terminate(signum, frame):
206+
with contextlib.suppress(StopIteration):
207+
while True:
208+
pool = next(iter(pools))
209+
pool.close(immediate=True)
210+
211+
if callable(orig_signal_handlers[signum]):
212+
orig_signal_handlers[signum](signum, frame)
213+
214+
215+
orig_signal_handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, _on_terminate)
216+
orig_signal_handlers[signal.SIGINT] = signal.signal(signal.SIGINT, _on_terminate)

unblob/processing.py

-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
StatReport,
4444
UnknownError,
4545
)
46-
from .signals import terminate_gracefully
4746
from .ui import NullProgressReporter, ProgressReporter
4847

4948
logger = get_logger()

unblob/signals.py

-51
This file was deleted.

0 commit comments

Comments
 (0)