From e119fcd94aa7168d3ef96784a12d53bdd1348ee1 Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Tue, 4 Apr 2023 08:33:40 -0700 Subject: [PATCH] Use `tasks_per_node` to split sweep across tasks --- .../submitit_launcher.py | 46 +++++++++++++++---- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/submitit_launcher.py b/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/submitit_launcher.py index 1efc8e4ce8..b98bc856ca 100644 --- a/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/submitit_launcher.py +++ b/plugins/hydra_submitit_launcher/hydra_plugins/hydra_submitit_launcher/submitit_launcher.py @@ -43,12 +43,8 @@ def setup( def __call__( self, - sweep_overrides: List[str], - job_dir_key: str, - job_num: int, - job_id: str, - singleton_state: Dict[type, Singleton], - ) -> JobReturn: + job_params, + ) -> Optional[JobReturn]: # lazy import to ensure plugin discovery remains fast import submitit @@ -56,6 +52,22 @@ def __call__( assert self.config is not None assert self.task_function is not None + job_env = submitit.JobEnvironment() + task_id = job_env.global_rank + if task_id >= len(job_params): + # May happen on the last job if the total number of tasks is not a multiple + # of `tasks_per_node`. + return None + + sweep_overrides: List[str] + job_dir_key: str + job_num: int + job_id: str + singleton_state: Dict[type, Singleton] + sweep_overrides, job_dir_key, job_num, job_id, singleton_state = job_params[ + task_id + ] + Singleton.set_state(singleton_state) setup_globals() sweep_config = self.hydra_context.config_loader.load_sweep_config( @@ -64,7 +76,7 @@ def __call__( with open_dict(sweep_config.hydra.job) as job: # Populate new job variables - job.id = submitit.JobEnvironment().job_id # type: ignore + job.id = f"{job_env.job_id}_{job_env.global_rank}" # type: ignore [attr-defined] sweep_config.hydra.job.num = job_num return run_job( @@ -141,8 +153,24 @@ def launch( ) ) - jobs = executor.map_array(self, *zip(*job_params)) - return [j.results()[0] for j in jobs] + # Create groups of parameters of size `tasks_per_node`, so that each task + # can get assigned its own set of parameters. + tasks_per_node = params.get("tasks_per_node", 1) + job_params = [ + job_params[start_idx : start_idx + tasks_per_node] + for start_idx in range(0, len(job_params), tasks_per_node) + ] + + # We need at least two jobs, otherwise submitit will create a single job instead + # of a job array, which will cause issues down the line. + # We create a new job with empty parameters (=> will terminate immediately). + if len(job_params) == 1: + job_params.append([]) + + jobs = executor.map_array(self, job_params) + return [ + result for job in jobs for result in job.results() if result is not None + ] class LocalLauncher(BaseSubmititLauncher):