Skip to content

Commit ac7da77

Browse files
authored
Merge pull request #1279 from automl/1264-bug-reset-after-initialization
1264 bug reset after initialization
2 parents 12d0945 + ceb5235 commit ac7da77

File tree

4 files changed

+43
-7
lines changed

4 files changed

+43
-7
lines changed
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
# Reproducibility
22

3-
Reproducibility can only be ensured if one worker is used and no time (wallclock or CPU time) is involved.
3+
Reproducibility can only be ensured if one worker is used and no time (wallclock or CPU time) is involved.
4+
5+
!!! warning
6+
SMBO.reset() will not seed smac with the original seed. If you want to have a full reset, please set the seed again after calling reset.

examples/1_basics/3_ask_and_tell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def train(self, config: Configuration, seed: int = 0) -> float:
6262
)
6363

6464
# We can ask SMAC which trials should be evaluated next
65-
for _ in range(10):
65+
for _ in range(30):
6666
info = smac.ask()
6767
assert info.seed is not None
6868

smac/main/smbo.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,8 @@ def optimize(self, *, data_to_scatter: dict[str, Any] | None = None) -> Configur
325325

326326
# Some statistics
327327
logger.debug(
328-
f"Remaining wallclock time: {self.remaining_walltime}; "
329-
f"Remaining cpu time: {self.remaining_cputime}; "
328+
f"Remaining wallclock time: {self.remaining_walltime}, "
329+
f"Remaining cpu time: {self.remaining_cputime}, "
330330
f"Remaining trials: {self.remaining_trials}"
331331
)
332332

@@ -375,6 +375,7 @@ def reset(self) -> None:
375375
# We also reset runhistory and intensifier here
376376
self._runhistory.reset()
377377
self._intensifier.reset()
378+
self._trial_generator = iter(self._intensifier)
378379

379380
def exists(self, filename: str | Path) -> bool:
380381
"""Checks if the files associated with the run already exist.
@@ -538,7 +539,7 @@ def _initialize_state(self) -> None:
538539
)
539540
logger.info(
540541
f"Found old run in `{self._scenario.output_directory}` but it is not the same as the current "
541-
f"one:\n{diff}"
542+
f"one: \n{diff}"
542543
)
543544

544545
feedback = input(

tests/test_main/test_smbo.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,17 @@
55

66
def test_termination_cost_threshold(rosenbrock):
77
termination_cost_threshold = 100
8-
scenario = Scenario(rosenbrock.configspace, n_trials=200, termination_cost_threshold=termination_cost_threshold)
8+
scenario = Scenario(
9+
rosenbrock.configspace,
10+
n_trials=200,
11+
termination_cost_threshold=termination_cost_threshold,
12+
)
913
smac = HyperparameterOptimizationFacade(
1014
scenario,
1115
rosenbrock.train,
12-
intensifier=HyperparameterOptimizationFacade.get_intensifier(scenario, max_config_calls=1),
16+
intensifier=HyperparameterOptimizationFacade.get_intensifier(
17+
scenario, max_config_calls=1
18+
),
1319
overwrite=True,
1420
)
1521
i = smac.optimize()
@@ -55,3 +61,29 @@ def test_termination_cost_threshold_with_fidelities(rosenbrock):
5561
assert config == i
5662
assert counter == 1
5763
assert smac.validate(i) < termination_cost_threshold
64+
65+
66+
def test_smbo_reset(rosenbrock):
67+
scenario = Scenario(rosenbrock.configspace, n_trials=3)
68+
smac = HyperparameterOptimizationFacade(
69+
scenario,
70+
rosenbrock.train,
71+
intensifier=HyperparameterOptimizationFacade.get_intensifier(
72+
scenario, max_config_calls=1
73+
),
74+
overwrite=True,
75+
)
76+
77+
smac.optimize()
78+
runhistory_len_before_reset = len(smac.runhistory)
79+
assert runhistory_len_before_reset == 3
80+
81+
smac.optimizer.reset()
82+
83+
assert smac.optimizer._used_target_function_walltime == 0.0
84+
assert smac.optimizer._used_target_function_cputime == 0.0
85+
assert len(smac.runhistory) == 0
86+
87+
smac.optimize()
88+
runhistory_len_after_reset = len(smac.runhistory)
89+
assert runhistory_len_after_reset == 3

0 commit comments

Comments
 (0)