Skip to content

Commit

Permalink
Use tasks_per_node to split sweep across tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
odelalleau committed Apr 4, 2023
1 parent 744318b commit 016dc28
Showing 1 changed file with 33 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence
from typing import Any, Dict, List, Optional, Sequence, Tuple

from hydra.core.singleton import Singleton
from hydra.core.utils import JobReturn, filter_overrides, run_job, setup_globals
Expand Down Expand Up @@ -43,19 +43,26 @@ def setup(

def __call__(
self,
sweep_overrides: List[str],
job_dir_key: str,
job_num: int,
job_id: str,
singleton_state: Dict[type, Singleton],
) -> JobReturn:
job_params: List[Tuple[List[str], str, int, str, Dict[type, Singleton]]],
) -> Optional[JobReturn]:
# lazy import to ensure plugin discovery remains fast
import submitit

assert self.hydra_context is not None
assert self.config is not None
assert self.task_function is not None

job_env = submitit.JobEnvironment()
task_id = job_env.global_rank
if task_id >= len(job_params):
# May happen on the last job if the total number of tasks is not a multiple
# of `tasks_per_node`.
return None

sweep_overrides, job_dir_key, job_num, job_id, singleton_state = job_params[
task_id
]

Singleton.set_state(singleton_state)
setup_globals()
sweep_config = self.hydra_context.config_loader.load_sweep_config(
Expand All @@ -64,7 +71,7 @@ def __call__(

with open_dict(sweep_config.hydra.job) as job:
# Populate new job variables
job.id = submitit.JobEnvironment().job_id # type: ignore
job.id = f"{job_env.job_id}_{job_env.global_rank}" # type: ignore [attr-defined]
sweep_config.hydra.job.num = job_num

return run_job(
Expand Down Expand Up @@ -141,8 +148,24 @@ def launch(
)
)

jobs = executor.map_array(self, *zip(*job_params))
return [j.results()[0] for j in jobs]
# Create groups of parameters of size `tasks_per_node`, so that each task
# can get assigned its own set of parameters.
tasks_per_node = params.get("tasks_per_node", 1)
job_params = [
job_params[start_idx : start_idx + tasks_per_node]
for start_idx in range(0, len(job_params), tasks_per_node)
]

# We need at least two jobs, otherwise submitit will create a single job instead
# of a job array, which will cause issues down the line.
# We create a new job with empty parameters (=> will terminate immediately).
if len(job_params) == 1:
job_params.append([])

jobs = executor.map_array(self, job_params)
return [
result for job in jobs for result in job.results() if result is not None
]


class LocalLauncher(BaseSubmititLauncher):
Expand Down

0 comments on commit 016dc28

Please sign in to comment.