Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optuna resume study #2647

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions plugins/hydra_optuna_sweeper/example/conf/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ hydra:
n_trials: 20
n_jobs: 1
max_failure_rate: 0.0
load_if_exists: True
params:
x: range(-5.5, 5.5, step=0.5)
y: choice(-5 ,0 ,5)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def __init__(
n_trials: int,
n_jobs: int,
max_failure_rate: float,
load_if_exists: Optional[bool],
search_space: Optional[DictConfig],
custom_search_space: Optional[str],
params: Optional[DictConfig],
Expand All @@ -169,6 +170,7 @@ def __init__(
self.max_failure_rate = max_failure_rate
assert self.max_failure_rate >= 0.0
assert self.max_failure_rate <= 1.0
self.load_if_exists = load_if_exists
self.custom_search_space_extender: Optional[
Callable[[DictConfig, Trial], None]
] = None
Expand Down Expand Up @@ -330,15 +332,16 @@ def sweep(self, arguments: List[str]) -> None:
storage=self.storage,
sampler=self.sampler,
directions=directions,
load_if_exists=True,
load_if_exists=self.load_if_exists,
)
log.info(f"Study name: {study.study_name}")
log.info(f"Storage: {self.storage}")
log.info(f"Sampler: {type(self.sampler).__name__}")
log.info(f"Directions: {directions}")

batch_size = self.n_jobs
n_trials_to_go = self.n_trials
n_trials_to_go = self.n_trials - len(study.trials)
self.job_idx = len(study.trials)

while n_trials_to_go > 0:
batch_size = min(n_trials_to_go, batch_size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ class OptunaSweeperConf:
# Maximum authorized failure rate for a batch of parameters
max_failure_rate: float = 0.0

# Load an existing study and resume it.
load_if_exists: bool = True

search_space: Optional[Dict[str, Any]] = None

params: Optional[Dict[str, str]] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
n_trials: int,
n_jobs: int,
max_failure_rate: float,
load_if_exists: Optional[bool],
search_space: Optional[DictConfig],
custom_search_space: Optional[str],
params: Optional[DictConfig],
Expand All @@ -34,6 +35,7 @@ def __init__(
n_trials,
n_jobs,
max_failure_rate,
load_if_exists,
search_space,
custom_search_space,
params,
Expand Down
42 changes: 42 additions & 0 deletions plugins/hydra_optuna_sweeper/tests/test_optuna_sweeper_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, List, Optional

import optuna

from hydra.core.override_parser.overrides_parser import OverridesParser
from hydra.core.plugins import Plugins
from hydra.plugins.sweeper import Sweeper
Expand Down Expand Up @@ -331,6 +332,7 @@ def test_warnings(
n_trials=1,
n_jobs=1,
max_failure_rate=0.0,
load_if_exists=True,
custom_search_space=None,
)
if search_space is not None:
Expand Down Expand Up @@ -371,6 +373,46 @@ def test_failure_rate(max_failure_rate: float, tmpdir: Path) -> None:
assert error_string not in err


@mark.parametrize("load_if_exists", (True, False, None))
def test_load_if_exists(load_if_exists: Optional[bool], tmpdir: Path) -> None:
storage = "sqlite:///" + os.path.join(str(tmpdir), "test.db")
study_name = "test-optuna-example"

# Let the first run fail.
cmd = [
sys.executable,
"example/sphere.py",
"--multirun",
"hydra.sweep.dir=" + str(tmpdir),
"hydra.job.chdir=True",
"hydra.sweeper.n_trials=15",
"hydra.sweeper.n_jobs=1",
f"hydra.sweeper.storage={storage}",
f"hydra.sweeper.study_name={study_name}",
"hydra/sweeper/sampler=random",
"hydra.sweeper.sampler.seed=123",
]
_ = run_process(cmd, print_error=False, raise_exception=False)

cmd.pop(-1)
if load_if_exists is True:
cmd.append("hydra.sweeper.load_if_exists=True")

elif load_if_exists is False:
cmd.append("hydra.sweeper.load_if_exists=False")

_, err = run_process(cmd, print_error=False, raise_exception=False)

error_string = (
"optuna.exceptions.DuplicatedStudyError: "
"Another study with name 'test-optuna-example' already exists."
)
if load_if_exists is False:
assert error_string in err
else:
assert error_string not in err


def test_example_with_deprecated_search_space(
tmpdir: Path,
) -> None:
Expand Down
4 changes: 3 additions & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ filterwarnings =
; Remove when default changes
ignore:.*Future Hydra versions will no longer change working directory at job runtime by default.*:UserWarning
; Jupyter notebook test on Windows yield warnings
ignore:.*Proactor event loop does not implement add_reader family of methods required for zmq.*:RuntimeWarning
ignore:.*Proactor event loop does not implement add_reader family of methods required for zmq.*:RuntimeWarning
; setuptools 67.5.0+ emits this warning when setuptools.commands.develop is imported
ignore:pkg_resources is deprecated as an API:DeprecationWarning
1 change: 1 addition & 0 deletions website/docs/plugins/optuna_sweeper.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ study_name: sphere
n_trials: 20
n_jobs: 1
max_failure_rate: 0.0
load_if_exists: True
params:
x: range(-5.5,5.5,step=0.5)
y: choice(-5,0,5)
Expand Down