1
1
import abc
2
+ import contextlib
2
3
import multiprocessing as mp
3
4
import os
4
5
import queue
6
+ import signal
5
7
import sys
6
8
import threading
7
9
from multiprocessing .queues import JoinableQueue
8
- from typing import Any , Callable , Union
10
+ from typing import Any , Callable , Set , Union
9
11
10
12
from .logging import multiprocessing_breakpoint
11
13
12
14
mp .set_start_method ("fork" )
13
15
14
16
15
17
class PoolBase (abc .ABC ):
18
+ def __init__ (self ):
19
+ with pools_lock :
20
+ pools .add (self )
21
+
16
22
@abc .abstractmethod
17
23
def submit (self , args ):
18
24
pass
@@ -24,15 +30,20 @@ def process_until_done(self):
24
30
def start (self ):
25
31
pass
26
32
27
- def close (self ):
28
- pass
33
+ def close (self , * , immediate = False ): # noqa: ARG002
34
+ with pools_lock :
35
+ pools .remove (self )
29
36
30
37
def __enter__ (self ):
31
38
self .start ()
32
39
return self
33
40
34
- def __exit__ (self , * args ):
35
- self .close ()
41
+ def __exit__ (self , exc_type , _exc_value , _tb ):
42
+ self .close (immediate = exc_type is not None )
43
+
44
+
45
+ pools_lock = threading .Lock ()
46
+ pools : Set [PoolBase ] = set ()
36
47
37
48
38
49
class Queue (JoinableQueue ):
@@ -53,9 +64,15 @@ class _Sentinel:
53
64
54
65
55
66
def _worker_process (handler , input_ , output ):
56
- # Creates a new process group, making sure no signals are propagated from the main process to the worker processes.
67
+ # Creates a new process group, making sure no signals are
68
+ # propagated from the main process to the worker processes.
57
69
os .setpgrp ()
58
70
71
+ # Restore default signal handlers, otherwise workers would inherit
72
+ # them from main process
73
+ signal .signal (signal .SIGTERM , signal .SIG_DFL )
74
+ signal .signal (signal .SIGINT , signal .SIG_DFL )
75
+
59
76
sys .breakpointhook = multiprocessing_breakpoint
60
77
while (args := input_ .get ()) is not _SENTINEL :
61
78
result = handler (args )
@@ -71,11 +88,14 @@ def __init__(
71
88
* ,
72
89
result_callback : Callable [["MultiPool" , Any ], Any ],
73
90
):
91
+ super ().__init__ ()
74
92
if process_num <= 0 :
75
93
raise ValueError ("At process_num must be greater than 0" )
76
94
95
+ self ._running = False
77
96
self ._result_callback = result_callback
78
97
self ._input = Queue (ctx = mp .get_context ())
98
+ self ._input .cancel_join_thread ()
79
99
self ._output = mp .SimpleQueue ()
80
100
self ._procs = [
81
101
mp .Process (
@@ -87,14 +107,32 @@ def __init__(
87
107
self ._tid = threading .get_native_id ()
88
108
89
109
def start (self ):
110
+ self ._running = True
90
111
for p in self ._procs :
91
112
p .start ()
92
113
93
- def close (self ):
94
- self ._clear_input_queue ()
95
- self ._request_workers_to_quit ()
96
- self ._clear_output_queue ()
114
+ def close (self , * , immediate = False ):
115
+ if not self ._running :
116
+ return
117
+ self ._running = False
118
+
119
+ if immediate :
120
+ self ._terminate_workers ()
121
+ else :
122
+ self ._clear_input_queue ()
123
+ self ._request_workers_to_quit ()
124
+ self ._clear_output_queue ()
125
+
97
126
self ._wait_for_workers_to_quit ()
127
+ super ().close (immediate = immediate )
128
+
129
+ def _terminate_workers (self ):
130
+ for proc in self ._procs :
131
+ proc .terminate ()
132
+
133
+ self ._input .close ()
134
+ if sys .version_info >= (3 , 9 ):
135
+ self ._output .close ()
98
136
99
137
def _clear_input_queue (self ):
100
138
try :
@@ -129,14 +167,16 @@ def submit(self, args):
129
167
self ._input .put (args )
130
168
131
169
def process_until_done (self ):
132
- while not self ._input .is_empty ():
133
- result = self ._output .get ()
134
- self ._result_callback (self , result )
135
- self ._input .task_done ()
170
+ with contextlib .suppress (EOFError ):
171
+ while not self ._input .is_empty ():
172
+ result = self ._output .get ()
173
+ self ._result_callback (self , result )
174
+ self ._input .task_done ()
136
175
137
176
138
177
class SinglePool (PoolBase ):
139
178
def __init__ (self , handler , * , result_callback ):
179
+ super ().__init__ ()
140
180
self ._handler = handler
141
181
self ._result_callback = result_callback
142
182
@@ -157,3 +197,20 @@ def make_pool(process_num, handler, result_callback) -> Union[SinglePool, MultiP
157
197
handler = handler ,
158
198
result_callback = result_callback ,
159
199
)
200
+
201
+
202
+ orig_signal_handlers = {}
203
+
204
+
205
+ def _on_terminate (signum , frame ):
206
+ with contextlib .suppress (StopIteration ):
207
+ while True :
208
+ pool = next (iter (pools ))
209
+ pool .close (immediate = True )
210
+
211
+ if callable (orig_signal_handlers [signum ]):
212
+ orig_signal_handlers [signum ](signum , frame )
213
+
214
+
215
+ orig_signal_handlers [signal .SIGTERM ] = signal .signal (signal .SIGTERM , _on_terminate )
216
+ orig_signal_handlers [signal .SIGINT ] = signal .signal (signal .SIGINT , _on_terminate )
0 commit comments