Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use tasks_per_node to split sweep across tasks #2633

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence
from typing import Any, Dict, List, Optional, Sequence, Tuple

from hydra.core.singleton import Singleton
from hydra.core.utils import JobReturn, filter_overrides, run_job, setup_globals
Expand Down Expand Up @@ -43,19 +43,26 @@ def setup(

def __call__(
self,
sweep_overrides: List[str],
job_dir_key: str,
job_num: int,
job_id: str,
singleton_state: Dict[type, Singleton],
) -> JobReturn:
job_params: List[Tuple[List[str], str, int, str, Dict[type, Singleton]]],
) -> Optional[JobReturn]:
# lazy import to ensure plugin discovery remains fast
import submitit

assert self.hydra_context is not None
assert self.config is not None
assert self.task_function is not None

job_env = submitit.JobEnvironment()
task_id = job_env.global_rank
if task_id >= len(job_params):
# May happen on the last job if the total number of tasks is not a multiple
# of `tasks_per_node`.
return None

sweep_overrides, job_dir_key, job_num, job_id, singleton_state = job_params[
task_id
]

Singleton.set_state(singleton_state)
setup_globals()
sweep_config = self.hydra_context.config_loader.load_sweep_config(
Expand All @@ -64,7 +71,7 @@ def __call__(

with open_dict(sweep_config.hydra.job) as job:
# Populate new job variables
job.id = submitit.JobEnvironment().job_id # type: ignore
job.id = f"{job_env.job_id}_{job_env.global_rank}" # type: ignore [attr-defined]
sweep_config.hydra.job.num = job_num

return run_job(
Expand Down Expand Up @@ -141,8 +148,24 @@ def launch(
)
)

jobs = executor.map_array(self, *zip(*job_params))
return [j.results()[0] for j in jobs]
# Create groups of parameters of size `tasks_per_node`, so that each task
# can get assigned its own set of parameters.
tasks_per_node = params.get("tasks_per_node", 1)
job_params = [
job_params[start_idx : start_idx + tasks_per_node]
for start_idx in range(0, len(job_params), tasks_per_node)
]

# We need at least two jobs, otherwise submitit will create a single job instead
# of a job array, which will cause issues down the line.
# We create a new job with empty parameters (=> will terminate immediately).
if len(job_params) == 1:
job_params.append([])

jobs = executor.map_array(self, job_params)
return [
result for job in jobs for result in job.results() if result is not None
]


class LocalLauncher(BaseSubmititLauncher):
Expand Down