6
6
from deephyper .evaluator .evaluate import Evaluator
7
7
8
8
logger = logging .getLogger (__name__ )
9
- WaitResult = namedtuple (' WaitResult' , [' active' , ' done' , ' failed' , ' cancelled' ])
9
+ WaitResult = namedtuple (" WaitResult" , [" active" , " done" , " failed" , " cancelled" ])
10
10
11
11
12
- class MPIFuture () :
12
+ class MPIFuture :
13
13
"""MPIFuture is a class meant to track a pending evaluation.
14
14
It record whether it was posted to a worker, the associated
15
15
MPI request, the tag, and the command that was sent."""
@@ -28,7 +28,7 @@ def posted(self):
28
28
def post (self , comm , worker , tag ):
29
29
"""Posts the request to a particular worker,
30
30
with a particular tag."""
31
- if ( self .posted ) :
31
+ if self .posted :
32
32
raise ValueError ("Request already posted" )
33
33
comm .send (self ._cmd , dest = worker , tag = tag )
34
34
self ._worker = worker
@@ -57,7 +57,7 @@ def _set_result(self, value):
57
57
def test (self ):
58
58
"""Tests if the request has completed."""
59
59
completed , result = MPI .Request .test (self ._request )
60
- if ( completed ) :
60
+ if completed :
61
61
self ._set_result (result )
62
62
return completed
63
63
@@ -66,7 +66,7 @@ def waitany(futures):
66
66
"""Waits for any of the provided futures to complete
67
67
and sets the result of the one that completed."""
68
68
status = MPI .Status ()
69
- requests = [ f ._request for f in futures ]
69
+ requests = [f ._request for f in futures ]
70
70
idx , result = MPI .Request .waitany (requests , status = status )
71
71
f = futures [idx ]
72
72
f ._set_result (result )
@@ -76,10 +76,11 @@ def waitany(futures):
76
76
def waitall (futures ):
77
77
"""Waits for all the provided futures to complete and
78
78
sets their result."""
79
- results = MPI .Request .waitall ([ f ._request for f in futures ])
79
+ results = MPI .Request .waitall ([f ._request for f in futures ])
80
80
for r , f in zip (results , futures ):
81
81
f ._set_result (r )
82
82
83
+
83
84
class MPIWorkerPool (Evaluator ):
84
85
"""Evaluator using a pool of MPI workers.
85
86
@@ -91,21 +92,33 @@ class MPIWorkerPool(Evaluator):
91
92
If ``None``, then cache_key defaults to a lossless (identity)
92
93
encoding of the input dict.
93
94
"""
94
- def __init__ (self , run_function , cache_key = None , comm = None , ** kwargs ):
95
+
96
+ def __init__ (
97
+ self ,
98
+ run_function ,
99
+ cache_key = None ,
100
+ comm = None ,
101
+ num_nodes_master = 1 ,
102
+ num_nodes_per_eval = 1 ,
103
+ num_ranks_per_node = 1 ,
104
+ num_evals_per_node = 1 ,
105
+ num_threads_per_rank = 64 ,
106
+ ** kwargs
107
+ ):
95
108
"""Constructor."""
96
109
super ().__init__ (run_function , cache_key )
97
- if ( comm is None ) :
110
+ if comm is None :
98
111
self .comm = MPI .COMM_WORLD
99
112
else :
100
113
self .comm = comm
101
- self .num_workers = self .comm .Get_size ()- 1
114
+ self .num_workers = self .comm .Get_size () - 1
102
115
self .avail_workers = []
103
- for tag in range (0 , self . WORKERS_PER_NODE ):
116
+ for tag in range (0 , num_ranks_per_node ):
104
117
for rank in range (0 , self .num_workers ):
105
- self .avail_workers .append ((rank + 1 , tag + 1 ))
118
+ self .avail_workers .append ((rank + 1 , tag + 1 ))
106
119
funcName = self ._run_function .__name__
107
120
moduleName = self ._run_function .__module__
108
- self .appName = '.' .join ((moduleName , funcName ))
121
+ self .appName = "." .join ((moduleName , funcName ))
109
122
110
123
def _try_posting (self , unposted ):
111
124
"""This function takes a list of MPIFuture instances that aren't
@@ -115,7 +128,7 @@ def _try_posting(self, unposted):
115
128
now_posted = []
116
129
now_unposted = []
117
130
for f in unposted :
118
- if ( len (self .avail_workers ) > 0 ) :
131
+ if len (self .avail_workers ) > 0 :
119
132
worker , tag = self .avail_workers .pop ()
120
133
f .post (self .comm , worker , tag )
121
134
now_posted .append (f )
@@ -128,29 +141,29 @@ def _eval_exec(self, x):
128
141
with the provided point x as argument. Returns an instance
129
142
of MPIFuture. If possible, this future will have been posted."""
130
143
assert isinstance (x , dict )
131
- cmd = {' cmd' : ' exec' , ' args' : [x ] }
144
+ cmd = {" cmd" : " exec" , " args" : [x ]}
132
145
future = MPIFuture (cmd )
133
- if ( len (self .avail_workers ) > 0 ) :
146
+ if len (self .avail_workers ) > 0 :
134
147
worker , tag = self .avail_workers .pop ()
135
148
future .post (self .comm , worker , tag )
136
149
return future
137
150
138
- def wait (self , futures , timeout = None , return_when = ' ANY_COMPLETED' ):
151
+ def wait (self , futures , timeout = None , return_when = " ANY_COMPLETED" ):
139
152
"""Waits for a set of futures to complete. If return_when == ANY_COMPLETED,
140
153
this function will return as soon as at least one of the futures has completed.
141
154
Otherwise it will wait for all the futures to have completed."""
142
155
# TODO: for now the timeout is not taken into account and
143
156
# the failed and cancelled lists will always be empty.
144
- done , failed , cancelled , active = [],[],[],[]
157
+ done , failed , cancelled , active = [], [], [], []
145
158
posted = [f for f in futures if f .posted ]
146
159
unposted = [f for f in futures if not f .posted ]
147
160
148
- if ( len (posted ) == 0 ) :
161
+ if len (posted ) == 0 :
149
162
newly_posted , unposted = self ._try_posting (unposted )
150
163
posted .extend (newly_posted )
151
164
152
- if ( return_when == ' ALL_COMPLETED' ) :
153
- while ( len (posted ) > 0 or len (unposted ) > 0 ) :
165
+ if return_when == " ALL_COMPLETED" :
166
+ while len (posted ) > 0 or len (unposted ) > 0 :
154
167
MPIFuture .waitall (posted )
155
168
for f in posted :
156
169
self .avail_workers .append ((f .worker , f .tag ))
@@ -167,18 +180,18 @@ def wait(self, futures, timeout=None, return_when='ANY_COMPLETED'):
167
180
one_completed = True
168
181
done .append (f )
169
182
# one request completed, try posting a new request
170
- if ( len (unposted ) > 0 ) :
183
+ if len (unposted ) > 0 :
171
184
p = unposted .pop (0 )
172
185
p .post (self .comm , worker = f .worker , tag = f .tag )
173
186
active .append (p )
174
187
else :
175
188
self .avail_workers .append ((f .worker , f .tag ))
176
189
else :
177
190
active .append (f )
178
- if not one_completed : # we need to call waitany
191
+ if not one_completed : # we need to call waitany
179
192
f = MPIFuture .waitany (posted )
180
193
done .append (f )
181
- if ( len (unposted ) > 0 ):
194
+ if len (unposted ) > 0 :
182
195
p = unposted .pop (0 )
183
196
p .post (self .comm , worker = f .worker , tag = f .tag )
184
197
active .append (p )
@@ -187,18 +200,13 @@ def wait(self, futures, timeout=None, return_when='ANY_COMPLETED'):
187
200
for f in unposted :
188
201
active .append (f )
189
202
190
- return WaitResult (
191
- active = active ,
192
- done = done ,
193
- failed = failed ,
194
- cancelled = cancelled
195
- )
203
+ return WaitResult (active = active , done = done , failed = failed , cancelled = cancelled )
196
204
197
205
def shutdown_workers (self ):
198
206
"""Shuts down all the MPIWorker instances."""
199
207
req = []
200
208
for k in range (1 , self .comm .Get_size ()):
201
- r = self .comm .isend ({' cmd' : ' exit' }, dest = k , tag = 0 )
209
+ r = self .comm .isend ({" cmd" : " exit" }, dest = k , tag = 0 )
202
210
req .append (r )
203
211
MPI .Request .waitall (req )
204
212
0 commit comments