diff --git a/plugins/hydra_optuna_sweeper/example/conf/config.yaml b/plugins/hydra_optuna_sweeper/example/conf/config.yaml index e3b86c71042..b80dae6ad84 100644 --- a/plugins/hydra_optuna_sweeper/example/conf/config.yaml +++ b/plugins/hydra_optuna_sweeper/example/conf/config.yaml @@ -12,6 +12,7 @@ hydra: n_trials: 20 n_jobs: 1 max_failure_rate: 0.0 + load_if_exists: True params: x: range(-5.5, 5.5, step=0.5) y: choice(-5 ,0 ,5) diff --git a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py index d17bbb1f70a..b501f29021a 100644 --- a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py +++ b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py @@ -156,6 +156,7 @@ def __init__( n_trials: int, n_jobs: int, max_failure_rate: float, + load_if_exists: Optional[bool], search_space: Optional[DictConfig], custom_search_space: Optional[str], params: Optional[DictConfig], @@ -169,6 +170,7 @@ def __init__( self.max_failure_rate = max_failure_rate assert self.max_failure_rate >= 0.0 assert self.max_failure_rate <= 1.0 + self.load_if_exists = load_if_exists self.custom_search_space_extender: Optional[ Callable[[DictConfig, Trial], None] ] = None @@ -330,7 +332,7 @@ def sweep(self, arguments: List[str]) -> None: storage=self.storage, sampler=self.sampler, directions=directions, - load_if_exists=True, + load_if_exists=self.load_if_exists, ) log.info(f"Study name: {study.study_name}") log.info(f"Storage: {self.storage}") @@ -338,7 +340,8 @@ def sweep(self, arguments: List[str]) -> None: log.info(f"Directions: {directions}") batch_size = self.n_jobs - n_trials_to_go = self.n_trials + n_trials_to_go = self.n_trials - len(study.trials) + self.job_idx = len(study.trials) while n_trials_to_go > 0: batch_size = min(n_trials_to_go, batch_size) diff --git a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/config.py b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/config.py index 03fde5a12d0..5bd97633b56 100644 --- a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/config.py +++ b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/config.py @@ -175,6 +175,9 @@ class OptunaSweeperConf: # Maximum authorized failure rate for a batch of parameters max_failure_rate: float = 0.0 + # Load an existing study and resume it. + load_if_exists: bool = True + search_space: Optional[Dict[str, Any]] = None params: Optional[Dict[str, str]] = None diff --git a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/optuna_sweeper.py b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/optuna_sweeper.py index c53d5d7b5d0..fd09c1f249c 100644 --- a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/optuna_sweeper.py +++ b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/optuna_sweeper.py @@ -20,6 +20,7 @@ def __init__( n_trials: int, n_jobs: int, max_failure_rate: float, + load_if_exists: Optional[bool], search_space: Optional[DictConfig], custom_search_space: Optional[str], params: Optional[DictConfig], @@ -34,6 +35,7 @@ def __init__( n_trials, n_jobs, max_failure_rate, + load_if_exists, search_space, custom_search_space, params, diff --git a/plugins/hydra_optuna_sweeper/tests/test_optuna_sweeper_plugin.py b/plugins/hydra_optuna_sweeper/tests/test_optuna_sweeper_plugin.py index f042937a8fd..252b7036900 100644 --- a/plugins/hydra_optuna_sweeper/tests/test_optuna_sweeper_plugin.py +++ b/plugins/hydra_optuna_sweeper/tests/test_optuna_sweeper_plugin.py @@ -6,6 +6,7 @@ from typing import Any, List, Optional import optuna + from hydra.core.override_parser.overrides_parser import OverridesParser from hydra.core.plugins import Plugins from hydra.plugins.sweeper import Sweeper @@ -331,6 +332,7 @@ def test_warnings( n_trials=1, n_jobs=1, max_failure_rate=0.0, + load_if_exists=True, custom_search_space=None, ) if search_space is not None: @@ -371,6 +373,46 @@ def test_failure_rate(max_failure_rate: float, tmpdir: Path) -> None: assert error_string not in err +@mark.parametrize("load_if_exists", (True, False, None)) +def test_load_if_exists(load_if_exists: Optional[bool], tmpdir: Path) -> None: + storage = "sqlite:///" + os.path.join(str(tmpdir), "test.db") + study_name = "test-optuna-example" + + # Let the first run fail. + cmd = [ + sys.executable, + "example/sphere.py", + "--multirun", + "hydra.sweep.dir=" + str(tmpdir), + "hydra.job.chdir=True", + "hydra.sweeper.n_trials=15", + "hydra.sweeper.n_jobs=1", + f"hydra.sweeper.storage={storage}", + f"hydra.sweeper.study_name={study_name}", + "hydra/sweeper/sampler=random", + "hydra.sweeper.sampler.seed=123", + ] + _ = run_process(cmd, print_error=False, raise_exception=False) + + cmd.pop(-1) + if load_if_exists is True: + cmd.append("hydra.sweeper.load_if_exists=True") + + elif load_if_exists is False: + cmd.append("hydra.sweeper.load_if_exists=False") + + _, err = run_process(cmd, print_error=False, raise_exception=False) + + error_string = ( + "optuna.exceptions.DuplicatedStudyError: " + "Another study with name 'test-optuna-example' already exists." + ) + if load_if_exists is False: + assert error_string in err + else: + assert error_string not in err + + def test_example_with_deprecated_search_space( tmpdir: Path, ) -> None: diff --git a/pytest.ini b/pytest.ini index bfbef0b9979..04a5d665968 100644 --- a/pytest.ini +++ b/pytest.ini @@ -15,4 +15,6 @@ filterwarnings = ; Remove when default changes ignore:.*Future Hydra versions will no longer change working directory at job runtime by default.*:UserWarning ; Jupyter notebook test on Windows yield warnings - ignore:.*Proactor event loop does not implement add_reader family of methods required for zmq.*:RuntimeWarning \ No newline at end of file + ignore:.*Proactor event loop does not implement add_reader family of methods required for zmq.*:RuntimeWarning + ; setuptools 67.5.0+ emits this warning when setuptools.commands.develop is imported + ignore:pkg_resources is deprecated as an API:DeprecationWarning \ No newline at end of file diff --git a/website/docs/plugins/optuna_sweeper.md b/website/docs/plugins/optuna_sweeper.md index b8364809679..83ecbabafb7 100644 --- a/website/docs/plugins/optuna_sweeper.md +++ b/website/docs/plugins/optuna_sweeper.md @@ -71,6 +71,7 @@ study_name: sphere n_trials: 20 n_jobs: 1 max_failure_rate: 0.0 +load_if_exists: True params: x: range(-5.5,5.5,step=0.5) y: choice(-5,0,5)