Skip to content

Commit 4221c69

Browse files
Merge pull request #156 from rosswhitfield/dask_worker_plugin
Add ability to register Dask WorkerPlugin
2 parents 8fb6efa + 3d7b3e1 commit 4221c69

File tree

5 files changed

+76
-8
lines changed

5 files changed

+76
-8
lines changed

doc/conf.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -233,4 +233,5 @@
233233
[u'UT-Battelle, LLC'], 1)
234234
]
235235

236-
intersphinx_mapping = {'python': ('https://docs.python.org/3', None)}
236+
intersphinx_mapping = {'python': ('https://docs.python.org/3', None),
237+
'distributed': ('http://distributed.dask.org/en/stable', None)}

doc/user_guides/dask.rst

+45
Original file line numberDiff line numberDiff line change
@@ -109,4 +109,49 @@ Your batch script should then look like:
109109
ips.py --config=ips.conf --platform=platform.conf
110110
111111
112+
Running with worker plugin
113+
--------------------------
112114

115+
There is the ability to set a
116+
:class:`~distributed.diagnostics.plugin.WorkerPlugin` on the dask
117+
worker using the `dask_worker_plugin` option in
118+
:meth:`~ipsframework.services.ServicesProxy.submit_tasks`.
119+
120+
Using a WorkerPlugin in combination with shifter allows you to do
121+
things like copying files out of the `Temporary XFS
122+
<https://docs.nersc.gov/development/shifter/how-to-use/#temporary-xfs-files-for-optimizing-io>`_
123+
file system. An example of that is
124+
125+
.. code-block:: python
126+
127+
from distributed.diagnostics.plugin import WorkerPlugin
128+
129+
class DaskWorkerPlugin(WorkerPlugin):
130+
def __init__(self, tmp_dir, target_dir):
131+
self.tmp_dir = tmp_dir
132+
self.target_dir = target_dir
133+
134+
def teardown(self, worker):
135+
os.system(f"cp {self.tmp_dir}/* {self.target_dir}")
136+
137+
class Worker(Component):
138+
def step(self, timestamp=0.0):
139+
cwd = self.services.get_working_dir()
140+
141+
self.services.create_task_pool('pool')
142+
self.services.add_task('pool', 'task_1', 1, '/tmp/', 'executable')
143+
144+
worker_plugin = DaskWorkerPlugin('/tmp', cwd)
145+
146+
ret_val = self.services.submit_tasks('pool',
147+
use_dask=True, use_shifter=True,
148+
dask_worker_plugin=worker_plugin)
149+
150+
exit_status = self.services.get_finished_tasks('pool')
151+
152+
153+
where the batch script has the temporary XFS filesystem mounted as
154+
155+
.. code-block:: bash
156+
157+
#SBATCH --volume="/global/cscratch1/sd/$USER/tmpfiles:/tmp:perNodeCache=size=1G"

ipsframework/services.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -1848,7 +1848,7 @@ def add_task(self, task_pool_name, task_name, nproc, working_dir,
18481848
*args, keywords=keywords)
18491849

18501850
def submit_tasks(self, task_pool_name, block=True, use_dask=False, dask_nodes=1,
1851-
dask_ppn=None, launch_interval=0.0, use_shifter=False):
1851+
dask_ppn=None, launch_interval=0.0, use_shifter=False, dask_worker_plugin=None):
18521852
"""
18531853
Launch all unfinished tasks in task pool *task_pool_name*. If *block* is ``True``,
18541854
return when all tasks have been launched. If *block* is ``False``, return when all
@@ -1860,7 +1860,7 @@ def submit_tasks(self, task_pool_name, block=True, use_dask=False, dask_nodes=1,
18601860
start_time = time.time()
18611861
self._send_monitor_event('IPS_TASK_POOL_BEGIN', 'task_pool = %s ' % task_pool_name)
18621862
task_pool: TaskPool = self.task_pools[task_pool_name]
1863-
retval = task_pool.submit_tasks(block, use_dask, dask_nodes, dask_ppn, launch_interval, use_shifter)
1863+
retval = task_pool.submit_tasks(block, use_dask, dask_nodes, dask_ppn, launch_interval, use_shifter, dask_worker_plugin)
18641864
elapsed_time = time.time() - start_time
18651865
self._send_monitor_event('IPS_TASK_POOL_END', 'task_pool = %s elapsed time = %.2f S' %
18661866
(task_pool_name, elapsed_time),
@@ -2066,7 +2066,7 @@ def add_task(self, task_name, nproc, working_dir, binary, *args, **keywords):
20662066
self.queued_tasks[task_name] = Task(task_name, nproc, working_dir, binary_fullpath, *args,
20672067
**keywords["keywords"])
20682068

2069-
def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter=False):
2069+
def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter=False, dask_worker_plugin=None):
20702070
"""Launch tasks in *queued_tasks* using dask.
20712071
20722072
:param block: Unused, this will always return after tasks are submitted
@@ -2077,6 +2077,8 @@ def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter
20772077
:type dask_ppn: int
20782078
:param use_shifter: Option to launch dask scheduler and workers in shifter container
20792079
:type use_shifter: bool
2080+
:param dask_worker_plugin: If provided this will be registered as a worker plugin with the dask client
2081+
:type dask_worker_plugin: distributed.diagnostics.plugin.WorkerPlugin
20802082
"""
20812083
services: ServicesProxy = self.services
20822084
self.dask_file_name = os.path.join(os.getcwd(),
@@ -2115,6 +2117,9 @@ def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter
21152117

21162118
self.dask_client = self.dask.distributed.Client(scheduler_file=self.dask_file_name)
21172119

2120+
if dask_worker_plugin is not None:
2121+
self.dask_client.register_worker_plugin(dask_worker_plugin)
2122+
21182123
try:
21192124
self.worker_event_logfile = services.sim_name + '_' + services.get_config_param("PORTAL_RUNID") + '_' + self.name + '_{}.json'
21202125
except KeyError:
@@ -2135,7 +2140,7 @@ def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter
21352140
self.queued_tasks = {}
21362141
return len(self.futures)
21372142

2138-
def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None, launch_interval=0.0, use_shifter=False):
2143+
def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None, launch_interval=0.0, use_shifter=False, dask_worker_plugin=None):
21392144
"""Launch tasks in *queued_tasks*. Finished tasks are handled before
21402145
launching new ones. If *block* is ``True``, the number of
21412146
tasks submitted is returned after all tasks have been launched
@@ -2157,7 +2162,8 @@ def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None,
21572162
:type launch_interval: float
21582163
:param use_shifter: Option to launch dask scheduler and workers in shifter container
21592164
:type use_shifter: bool
2160-
2165+
:param dask_worker_plugin: If provided this will be registered as a worker plugin with the dask client
2166+
:type dask_worker_plugin: distributed.diagnostics.plugin.WorkerPlugin
21612167
"""
21622168

21632169
if use_dask:
@@ -2167,7 +2173,7 @@ def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None,
21672173
self.services.error("Requested to run dask within shifter but shifter not available")
21682174
raise Exception("shifter not found")
21692175
else:
2170-
return self.submit_dask_tasks(block, dask_nodes, dask_ppn, use_shifter)
2176+
return self.submit_dask_tasks(block, dask_nodes, dask_ppn, use_shifter, dask_worker_plugin)
21712177
elif not TaskPool.dask:
21722178
self.services.warning("Requested use_dask but cannot because import dask failed")
21732179
elif not self.serial_pool:

tests/helloworld/hello_worker_task_pool_dask.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# -------------------------------------------------------------------------------
44
from time import sleep
55
import copy
6+
from distributed.diagnostics.plugin import WorkerPlugin
67
from ipsframework import Component
78

89

@@ -12,6 +13,14 @@ def myFun(*args):
1213
return 0
1314

1415

16+
class DaskWorkerPlugin(WorkerPlugin):
17+
def setup(self, worker):
18+
print("Running setup of worker")
19+
20+
def teardown(self, worker):
21+
print("Running teardown of worker")
22+
23+
1524
class HelloWorker(Component):
1625
def __init__(self, services, config):
1726
super().__init__(services, config)
@@ -32,7 +41,10 @@ def step(self, timestamp=0.0, **keywords):
3241
self.services.add_task('pool', 'func_' + str(i), 1,
3342
cwd, myFun, duration)
3443

35-
ret_val = self.services.submit_tasks('pool', use_dask=True, dask_nodes=1, dask_ppn=10)
44+
worker_plugin = DaskWorkerPlugin()
45+
46+
ret_val = self.services.submit_tasks('pool', use_dask=True, dask_nodes=1, dask_ppn=10,
47+
dask_worker_plugin=worker_plugin)
3648
print('ret_val = ', ret_val)
3749
exit_status = self.services.get_finished_tasks('pool')
3850
print(exit_status)

tests/helloworld/test_helloworld.py

+4
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,10 @@ def test_helloworld_task_pool_dask(tmpdir, capfd):
245245
assert captured_out[4] == 'HelloDriver: finished worker init call'
246246
assert captured_out[5] == 'HelloDriver: beginning step call'
247247
assert captured_out[6] == 'Hello from HelloWorker'
248+
249+
assert "Running setup of worker" in captured_out
250+
assert "Running teardown of worker" in captured_out
251+
248252
assert 'ret_val = 9' in captured_out
249253

250254
for duration in ("0.2", "0.4", "0.6"):

0 commit comments

Comments
 (0)