@@ -1848,7 +1848,7 @@ def add_task(self, task_pool_name, task_name, nproc, working_dir,
1848
1848
* args , keywords = keywords )
1849
1849
1850
1850
def submit_tasks (self , task_pool_name , block = True , use_dask = False , dask_nodes = 1 ,
1851
- dask_ppn = None , launch_interval = 0.0 , use_shifter = False ):
1851
+ dask_ppn = None , launch_interval = 0.0 , use_shifter = False , dask_worker_plugin = None ):
1852
1852
"""
1853
1853
Launch all unfinished tasks in task pool *task_pool_name*. If *block* is ``True``,
1854
1854
return when all tasks have been launched. If *block* is ``False``, return when all
@@ -1860,7 +1860,7 @@ def submit_tasks(self, task_pool_name, block=True, use_dask=False, dask_nodes=1,
1860
1860
start_time = time .time ()
1861
1861
self ._send_monitor_event ('IPS_TASK_POOL_BEGIN' , 'task_pool = %s ' % task_pool_name )
1862
1862
task_pool : TaskPool = self .task_pools [task_pool_name ]
1863
- retval = task_pool .submit_tasks (block , use_dask , dask_nodes , dask_ppn , launch_interval , use_shifter )
1863
+ retval = task_pool .submit_tasks (block , use_dask , dask_nodes , dask_ppn , launch_interval , use_shifter , dask_worker_plugin )
1864
1864
elapsed_time = time .time () - start_time
1865
1865
self ._send_monitor_event ('IPS_TASK_POOL_END' , 'task_pool = %s elapsed time = %.2f S' %
1866
1866
(task_pool_name , elapsed_time ),
@@ -2066,7 +2066,7 @@ def add_task(self, task_name, nproc, working_dir, binary, *args, **keywords):
2066
2066
self .queued_tasks [task_name ] = Task (task_name , nproc , working_dir , binary_fullpath , * args ,
2067
2067
** keywords ["keywords" ])
2068
2068
2069
- def submit_dask_tasks (self , block = True , dask_nodes = 1 , dask_ppn = None , use_shifter = False ):
2069
+ def submit_dask_tasks (self , block = True , dask_nodes = 1 , dask_ppn = None , use_shifter = False , dask_worker_plugin = None ):
2070
2070
"""Launch tasks in *queued_tasks* using dask.
2071
2071
2072
2072
:param block: Unused, this will always return after tasks are submitted
@@ -2077,6 +2077,8 @@ def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter
2077
2077
:type dask_ppn: int
2078
2078
:param use_shifter: Option to launch dask scheduler and workers in shifter container
2079
2079
:type use_shifter: bool
2080
+ :param dask_worker_plugin: If provided this will be registered as a worker plugin with the dask client
2081
+ :type dask_worker_plugin: distributed.diagnostics.plugin.WorkerPlugin
2080
2082
"""
2081
2083
services : ServicesProxy = self .services
2082
2084
self .dask_file_name = os .path .join (os .getcwd (),
@@ -2115,6 +2117,9 @@ def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter
2115
2117
2116
2118
self .dask_client = self .dask .distributed .Client (scheduler_file = self .dask_file_name )
2117
2119
2120
+ if dask_worker_plugin is not None :
2121
+ self .dask_client .register_worker_plugin (dask_worker_plugin )
2122
+
2118
2123
try :
2119
2124
self .worker_event_logfile = services .sim_name + '_' + services .get_config_param ("PORTAL_RUNID" ) + '_' + self .name + '_{}.json'
2120
2125
except KeyError :
@@ -2135,7 +2140,7 @@ def submit_dask_tasks(self, block=True, dask_nodes=1, dask_ppn=None, use_shifter
2135
2140
self .queued_tasks = {}
2136
2141
return len (self .futures )
2137
2142
2138
- def submit_tasks (self , block = True , use_dask = False , dask_nodes = 1 , dask_ppn = None , launch_interval = 0.0 , use_shifter = False ):
2143
+ def submit_tasks (self , block = True , use_dask = False , dask_nodes = 1 , dask_ppn = None , launch_interval = 0.0 , use_shifter = False , dask_worker_plugin = None ):
2139
2144
"""Launch tasks in *queued_tasks*. Finished tasks are handled before
2140
2145
launching new ones. If *block* is ``True``, the number of
2141
2146
tasks submitted is returned after all tasks have been launched
@@ -2157,7 +2162,8 @@ def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None,
2157
2162
:type launch_internal: float
2158
2163
:param use_shifter: Option to launch dask scheduler and workers in shifter container
2159
2164
:type use_shifter: bool
2160
-
2165
+ :param dask_worker_plugin: If provided this will be registered as a worker plugin with the dask client
2166
+ :type dask_worker_plugin: distributed.diagnostics.plugin.WorkerPlugin
2161
2167
"""
2162
2168
2163
2169
if use_dask :
@@ -2167,7 +2173,7 @@ def submit_tasks(self, block=True, use_dask=False, dask_nodes=1, dask_ppn=None,
2167
2173
self .services .error ("Requested to run dask within shifter but shifter not available" )
2168
2174
raise Exception ("shifter not found" )
2169
2175
else :
2170
- return self .submit_dask_tasks (block , dask_nodes , dask_ppn , use_shifter )
2176
+ return self .submit_dask_tasks (block , dask_nodes , dask_ppn , use_shifter , dask_worker_plugin )
2171
2177
elif not TaskPool .dask :
2172
2178
self .services .warning ("Requested use_dask but cannot because import dask failed" )
2173
2179
elif not self .serial_pool :
0 commit comments