Skip to content

Commit d3e3885

Browse files
Merge pull request optuna#1430 from bigbird555/fix-pruner-bug
Experimental commit to fix pruner bug.
2 parents 5abc36a + 2c4e0ce commit d3e3885

File tree

4 files changed

+21
-21
lines changed

4 files changed

+21
-21
lines changed

Diff for: optuna/pruners/_percentile.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def _is_first_in_interval_step(step, intermediate_steps, n_warmup_steps, interva
5555
# type: (int, KeysView[int], int, int) -> bool
5656

5757
nearest_lower_pruning_step = (
58-
(step - n_warmup_steps - 1) // interval_steps * interval_steps + n_warmup_steps + 1
59-
)
58+
step - n_warmup_steps
59+
) // interval_steps * interval_steps + n_warmup_steps
6060
assert nearest_lower_pruning_step >= 0
6161

6262
# `intermediate_steps` may not be sorted so we must go through all elements.
@@ -167,7 +167,7 @@ def prune(self, study, trial):
167167
return False
168168

169169
n_warmup_steps = self._n_warmup_steps
170-
if step <= n_warmup_steps:
170+
if step < n_warmup_steps:
171171
return False
172172

173173
if not _is_first_in_interval_step(

Diff for: optuna/pruners/_threshold.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def prune(self, study: "optuna.study.Study", trial: "optuna.trial.FrozenTrial")
118118
return False
119119

120120
n_warmup_steps = self._n_warmup_steps
121-
if step <= n_warmup_steps:
121+
if step < n_warmup_steps:
122122
return False
123123

124124
if not _is_first_in_interval_step(

Diff for: tests/pruners_tests/test_median.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -95,28 +95,28 @@ def test_median_pruner_n_warmup_steps():
9595
study = optuna.study.create_study()
9696

9797
trial = optuna.trial.Trial(study, study._storage.create_new_trial(study._study_id))
98+
trial.report(1, 0)
9899
trial.report(1, 1)
99-
trial.report(1, 2)
100100
study._storage.set_trial_state(trial._trial_id, TrialState.COMPLETE)
101101

102102
trial = optuna.trial.Trial(study, study._storage.create_new_trial(study._study_id))
103-
trial.report(2, 1)
103+
trial.report(2, 0)
104104
# A pruner is not activated during warm-up steps.
105105
assert not pruner.prune(study=study, trial=study._storage.get_trial(trial._trial_id))
106106

107-
trial.report(2, 2)
107+
trial.report(2, 1)
108108
# A pruner is activated after warm-up steps.
109109
assert pruner.prune(study=study, trial=study._storage.get_trial(trial._trial_id))
110110

111111

112112
@pytest.mark.parametrize(
113113
"n_warmup_steps,interval_steps,report_steps,expected_prune_steps",
114114
[
115-
(1, 2, 1, [2, 4]),
116-
(0, 3, 10, list(range(1, 30))),
117-
(2, 3, 10, list(range(11, 30))),
118-
(0, 10, 3, [1, 2, 3, 13, 14, 15, 22, 23, 24]),
119-
(2, 10, 3, [4, 5, 6, 13, 14, 15, 25, 26, 27]),
115+
(1, 2, 1, [1, 3]),
116+
(0, 3, 10, list(range(29))),
117+
(2, 3, 10, list(range(10, 29))),
118+
(0, 10, 3, [0, 1, 2, 12, 13, 14, 21, 22, 23]),
119+
(2, 10, 3, [3, 4, 5, 12, 13, 14, 24, 25, 26]),
120120
],
121121
)
122122
def test_median_pruner_interval_steps(
@@ -129,7 +129,7 @@ def test_median_pruner_interval_steps(
129129

130130
trial = optuna.trial.Trial(study, study._storage.create_new_trial(study._study_id))
131131
n_steps = max(expected_prune_steps)
132-
base_index = 1
132+
base_index = 0
133133
for i in range(base_index, base_index + n_steps):
134134
trial.report(base_index, i)
135135
study._storage.set_trial_state(trial._trial_id, TrialState.COMPLETE)
@@ -139,5 +139,5 @@ def test_median_pruner_interval_steps(
139139
if (i - base_index) % report_steps == 0:
140140
trial.report(2, i)
141141
assert pruner.prune(study=study, trial=study._storage.get_trial(trial._trial_id)) == (
142-
i > n_warmup_steps and i in expected_prune_steps
142+
i >= n_warmup_steps and i in expected_prune_steps
143143
)

Diff for: tests/pruners_tests/test_threshold.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,10 @@ def test_threshold_pruner_n_warmup_steps() -> None:
8282
trial = optuna.trial.Trial(study, study._storage.create_new_trial(study._study_id))
8383
pruner = optuna.pruners.ThresholdPruner(lower=0.0, upper=1.0, n_warmup_steps=2)
8484

85-
trial.report(-10.0, 1)
85+
trial.report(-10.0, 0)
8686
assert not pruner.prune(study=study, trial=study._storage.get_trial(trial._trial_id))
8787

88-
trial.report(100.0, 2)
88+
trial.report(100.0, 1)
8989
assert not pruner.prune(study=study, trial=study._storage.get_trial(trial._trial_id))
9090

9191
trial.report(-100.0, 3)
@@ -103,17 +103,17 @@ def test_threshold_pruner_interval_steps() -> None:
103103
trial = optuna.trial.Trial(study, study._storage.create_new_trial(study._study_id))
104104
pruner = optuna.pruners.ThresholdPruner(lower=0.0, upper=1.0, interval_steps=2)
105105

106-
trial.report(-10.0, 1)
106+
trial.report(-10.0, 0)
107107
assert pruner.prune(study=study, trial=study._storage.get_trial(trial._trial_id))
108108

109-
trial.report(100.0, 2)
109+
trial.report(100.0, 1)
110110
assert not pruner.prune(study=study, trial=study._storage.get_trial(trial._trial_id))
111111

112-
trial.report(-100.0, 3)
112+
trial.report(-100.0, 2)
113113
assert pruner.prune(study=study, trial=study._storage.get_trial(trial._trial_id))
114114

115-
trial.report(10.0, 4)
115+
trial.report(10.0, 3)
116116
assert not pruner.prune(study=study, trial=study._storage.get_trial(trial._trial_id))
117117

118-
trial.report(1000.0, 5)
118+
trial.report(1000.0, 4)
119119
assert pruner.prune(study=study, trial=study._storage.get_trial(trial._trial_id))

0 commit comments

Comments
 (0)