Skip to content

Commit 943fd25

Browse files
authored
Merge pull request optuna#1257 from sile/make-nsga-ii-sampler-faster
Add a caching mechanism to make `NSGAIIMultiObjectiveSampler` faster.
2 parents e2d11e8 + cd5d04f commit 943fd25

File tree

3 files changed

+110
-16
lines changed

3 files changed

+110
-16
lines changed

Diff for: optuna/multi_objective/samplers/_nsga2.py

+75-16
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections import defaultdict
2+
import hashlib
23
import itertools
34
from typing import Any
45
from typing import DefaultDict
@@ -19,6 +20,7 @@
1920
# Define key names of `Trial.system_attrs`.
2021
_GENERATION_KEY = "multi_objective:nsga2:generation"
2122
_PARENTS_KEY = "multi_objective:nsga2:parents"
23+
_POPULATION_CACHE_KEY_PREFIX = "multi_objective:nsga2:population"
2224

2325

2426
@experimental("1.5.0")
@@ -158,42 +160,99 @@ def sample_independent(
158160
def _collect_parent_population(
159161
self, study: "multi_objective.study.MultiObjectiveStudy"
160162
) -> Tuple[int, List["multi_objective.trial.FrozenMultiObjectiveTrial"]]:
161-
# TODO(ohta): Optimize this method.
163+
trials = [
164+
multi_objective.trial.FrozenMultiObjectiveTrial(study.n_objectives, t)
165+
for t in study._storage.get_all_trials(study._study_id, deepcopy=False)
166+
]
162167

168+
generation_to_runnings = defaultdict(list)
163169
generation_to_population = defaultdict(list)
164-
for trial in study.get_trials(deepcopy=False):
170+
for trial in trials:
171+
if _GENERATION_KEY not in trial.system_attrs:
172+
continue
173+
174+
generation = trial.system_attrs[_GENERATION_KEY]
165175
if trial.state != optuna.trial.TrialState.COMPLETE:
176+
if trial.state == optuna.trial.TrialState.RUNNING:
177+
generation_to_runnings[generation].append(trial)
166178
continue
167179

168-
generation = trial.system_attrs.get(_GENERATION_KEY, 0)
169180
generation_to_population[generation].append(trial)
170181

182+
hasher = hashlib.sha256()
171183
parent_population = [] # type: List[multi_objective.trial.FrozenMultiObjectiveTrial]
172184
parent_generation = -1
173-
for generation in itertools.count():
185+
while True:
186+
generation = parent_generation + 1
174187
population = generation_to_population[generation]
175188

176189
# Under multi-worker settings, the population size might become larger than
177190
# `self._population_size`.
178191
if len(population) < self._population_size:
179192
break
180193

181-
population.extend(parent_population)
182-
parent_population = []
183-
parent_generation = generation
194+
# [NOTE]
195+
# It's generally safe to assume that once the above condition is satisfied,
196+
# there are no additional individuals added to the generation (i.e., the members of
197+
# the generation have been fixed).
198+
# If the number of parallel workers is huge, this assumption can be broken, but
199+
# this is a very rare case and doesn't significantly impact optimization performance.
200+
# So we can ignore the case.
201+
202+
# The cache key is calculated based on the key of the previous generation and
203+
# the remaining running trials in the current population.
204+
# If there are no running trials, the new cache key becomes exactly the same as
205+
# the previous one, and the cached content will be overwritten. This allows us to
206+
# skip redundant cache key calculations when this method is called for the subsequent
207+
# trials.
208+
for trial in generation_to_runnings[generation]:
209+
hasher.update(bytes(str(trial.number), "utf-8"))
210+
211+
cache_key = "{}:{}".format(_POPULATION_CACHE_KEY_PREFIX, hasher.hexdigest())
212+
cached_generation, cached_population_numbers = study.system_attrs.get(
213+
cache_key, (-1, [])
214+
)
215+
if cached_generation >= generation:
216+
generation = cached_generation
217+
population = [trials[n] for n in cached_population_numbers]
218+
else:
219+
population.extend(parent_population)
220+
population = self._select_elite_population(study, population)
221+
222+
# To reduce the number of system attribute entries,
223+
# we cache the population information only if there are no running trials
224+
# (i.e., the information of the population has been fixed).
225+
# Usually, unless some running trials are severely delayed, only a single entry
226+
# will be used.
227+
if len(generation_to_runnings[generation]) == 0:
228+
population_numbers = [t.number for t in population]
229+
study.set_system_attr(
230+
cache_key, (generation, population_numbers),
231+
)
184232

185-
population_per_rank = _fast_non_dominated_sort(population, study.directions)
186-
for population in population_per_rank:
187-
if len(parent_population) + len(population) < self._population_size:
188-
parent_population.extend(population)
189-
else:
190-
n = self._population_size - len(parent_population)
191-
_crowding_distance_sort(population)
192-
parent_population.extend(population[:n])
193-
break
233+
parent_generation = generation
234+
parent_population = population
194235

195236
return parent_generation, parent_population
196237

238+
def _select_elite_population(
239+
self,
240+
study: "multi_objective.study.MultiObjectiveStudy",
241+
population: List["multi_objective.trial.FrozenMultiObjectiveTrial"],
242+
) -> List["multi_objective.trial.FrozenMultiObjectiveTrial"]:
243+
elite_population = [] # type: List[multi_objective.trial.FrozenMultiObjectiveTrial]
244+
population_per_rank = _fast_non_dominated_sort(population, study.directions)
245+
for population in population_per_rank:
246+
if len(elite_population) + len(population) < self._population_size:
247+
elite_population.extend(population)
248+
else:
249+
n = self._population_size - len(elite_population)
250+
_crowding_distance_sort(population)
251+
elite_population.extend(population[:n])
252+
break
253+
254+
return elite_population
255+
197256
def _select_parent(
198257
self,
199258
study: "multi_objective.study.MultiObjectiveStudy",

Diff for: optuna/multi_objective/study.py

+4
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,10 @@ def get_pareto_front_trials(self) -> List["multi_objective.trial.FrozenMultiObje
392392
def _storage(self) -> BaseStorage:
393393
return self._study._storage
394394

395+
@property
396+
def _study_id(self) -> int:
397+
return self._study._study_id
398+
395399

396400
def _log_completed_trial(self: Study, trial: Trial, result: float) -> None:
397401
if not _logger.isEnabledFor(logging.INFO):

Diff for: tests/multi_objective/samplers_tests/test_nsga2.py

+31
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections import Counter
22
from typing import List
3+
from typing import Tuple
34

45
import pytest
56

@@ -155,6 +156,36 @@ def test_crowding_distance_sort() -> None:
155156
assert [t.number for t in trials] == [2, 3, 0, 1]
156157

157158

159+
def test_study_system_attr_for_population_cache() -> None:
160+
sampler = multi_objective.samplers.NSGAIIMultiObjectiveSampler(population_size=10)
161+
study = multi_objective.create_study(["minimize"], sampler=sampler)
162+
163+
def get_cached_entries(
164+
study: multi_objective.study.MultiObjectiveStudy,
165+
) -> List[Tuple[int, List[int]]]:
166+
return [
167+
v
168+
for k, v in study.system_attrs.items()
169+
if k.startswith(multi_objective.samplers._nsga2._POPULATION_CACHE_KEY_PREFIX)
170+
]
171+
172+
study.optimize(lambda t: [t.suggest_uniform("x", 0, 9)], n_trials=10)
173+
cached_entries = get_cached_entries(study)
174+
assert len(cached_entries) == 0
175+
176+
study.optimize(lambda t: [t.suggest_uniform("x", 0, 9)], n_trials=1)
177+
cached_entries = get_cached_entries(study)
178+
assert len(cached_entries) == 1
179+
assert cached_entries[0][0] == 0 # Cached generation.
180+
assert len(cached_entries[0][1]) == 10 # Population size.
181+
182+
study.optimize(lambda t: [t.suggest_uniform("x", 0, 9)], n_trials=10)
183+
cached_entries = get_cached_entries(study)
184+
assert len(cached_entries) == 1
185+
assert cached_entries[0][0] == 1 # Cached generation.
186+
assert len(cached_entries[0][1]) == 10 # Population size.
187+
188+
158189
# TODO(ohta): Consider moving this utility function to the `optuna.testing` module.
159190
def _create_frozen_trial(
160191
number: int, values: List[float]

0 commit comments

Comments
 (0)