|
1 | 1 | from collections import defaultdict
|
| 2 | +import hashlib |
2 | 3 | import itertools
|
3 | 4 | from typing import Any
|
4 | 5 | from typing import DefaultDict
|
|
19 | 20 | # Define key names of `Trial.system_attrs`.
|
20 | 21 | _GENERATION_KEY = "multi_objective:nsga2:generation"
|
21 | 22 | _PARENTS_KEY = "multi_objective:nsga2:parents"
|
| 23 | +_POPULATION_CACHE_KEY_PREFIX = "multi_objective:nsga2:population" |
22 | 24 |
|
23 | 25 |
|
24 | 26 | @experimental("1.5.0")
|
@@ -158,42 +160,99 @@ def sample_independent(
|
158 | 160 | def _collect_parent_population(
|
159 | 161 | self, study: "multi_objective.study.MultiObjectiveStudy"
|
160 | 162 | ) -> Tuple[int, List["multi_objective.trial.FrozenMultiObjectiveTrial"]]:
|
161 |
| - # TODO(ohta): Optimize this method. |
| 163 | + trials = [ |
| 164 | + multi_objective.trial.FrozenMultiObjectiveTrial(study.n_objectives, t) |
| 165 | + for t in study._storage.get_all_trials(study._study_id, deepcopy=False) |
| 166 | + ] |
162 | 167 |
|
| 168 | + generation_to_runnings = defaultdict(list) |
163 | 169 | generation_to_population = defaultdict(list)
|
164 |
| - for trial in study.get_trials(deepcopy=False): |
| 170 | + for trial in trials: |
| 171 | + if _GENERATION_KEY not in trial.system_attrs: |
| 172 | + continue |
| 173 | + |
| 174 | + generation = trial.system_attrs[_GENERATION_KEY] |
165 | 175 | if trial.state != optuna.trial.TrialState.COMPLETE:
|
| 176 | + if trial.state == optuna.trial.TrialState.RUNNING: |
| 177 | + generation_to_runnings[generation].append(trial) |
166 | 178 | continue
|
167 | 179 |
|
168 |
| - generation = trial.system_attrs.get(_GENERATION_KEY, 0) |
169 | 180 | generation_to_population[generation].append(trial)
|
170 | 181 |
|
| 182 | + hasher = hashlib.sha256() |
171 | 183 | parent_population = [] # type: List[multi_objective.trial.FrozenMultiObjectiveTrial]
|
172 | 184 | parent_generation = -1
|
173 |
| - for generation in itertools.count(): |
| 185 | + while True: |
| 186 | + generation = parent_generation + 1 |
174 | 187 | population = generation_to_population[generation]
|
175 | 188 |
|
176 | 189 | # Under multi-worker settings, the population size might become larger than
|
177 | 190 | # `self._population_size`.
|
178 | 191 | if len(population) < self._population_size:
|
179 | 192 | break
|
180 | 193 |
|
181 |
| - population.extend(parent_population) |
182 |
| - parent_population = [] |
183 |
| - parent_generation = generation |
| 194 | + # [NOTE] |
| 195 | + # It's generally safe to assume that once the above condition is satisfied, |
| 196 | + # there are no additional individuals added to the generation (i.e., the members of |
| 197 | + # the generation have been fixed). |
| 198 | + # If the number of parallel workers is huge, this assumption can be broken, but |
| 199 | + # this is a very rare case and doesn't significantly impact optimization performance. |
| 200 | + # So we can ignore the case. |
| 201 | + |
| 202 | + # The cache key is calculated based on the key of the previous generation and |
| 203 | + # the remaining running trials in the current population. |
| 204 | + # If there are no running trials, the new cache key becomes exactly the same as |
| 205 | + # the previous one, and the cached content will be overwritten. This allows us to |
| 206 | + # skip redundant cache key calculations when this method is called for the subsequent |
| 207 | + # trials. |
| 208 | + for trial in generation_to_runnings[generation]: |
| 209 | + hasher.update(bytes(str(trial.number), "utf-8")) |
| 210 | + |
| 211 | + cache_key = "{}:{}".format(_POPULATION_CACHE_KEY_PREFIX, hasher.hexdigest()) |
| 212 | + cached_generation, cached_population_numbers = study.system_attrs.get( |
| 213 | + cache_key, (-1, []) |
| 214 | + ) |
| 215 | + if cached_generation >= generation: |
| 216 | + generation = cached_generation |
| 217 | + population = [trials[n] for n in cached_population_numbers] |
| 218 | + else: |
| 219 | + population.extend(parent_population) |
| 220 | + population = self._select_elite_population(study, population) |
| 221 | + |
| 222 | + # To reduce the number of system attribute entries, |
| 223 | + # we cache the population information only if there are no running trials |
| 224 | + # (i.e., the information of the population has been fixed). |
| 225 | + # Usually, if there are no too delayed running trials, the single entry |
| 226 | + # will be used. |
| 227 | + if len(generation_to_runnings[generation]) == 0: |
| 228 | + population_numbers = [t.number for t in population] |
| 229 | + study.set_system_attr( |
| 230 | + cache_key, (generation, population_numbers), |
| 231 | + ) |
184 | 232 |
|
185 |
| - population_per_rank = _fast_non_dominated_sort(population, study.directions) |
186 |
| - for population in population_per_rank: |
187 |
| - if len(parent_population) + len(population) < self._population_size: |
188 |
| - parent_population.extend(population) |
189 |
| - else: |
190 |
| - n = self._population_size - len(parent_population) |
191 |
| - _crowding_distance_sort(population) |
192 |
| - parent_population.extend(population[:n]) |
193 |
| - break |
| 233 | + parent_generation = generation |
| 234 | + parent_population = population |
194 | 235 |
|
195 | 236 | return parent_generation, parent_population
|
196 | 237 |
|
| 238 | + def _select_elite_population( |
| 239 | + self, |
| 240 | + study: "multi_objective.study.MultiObjectiveStudy", |
| 241 | + population: List["multi_objective.trial.FrozenMultiObjectiveTrial"], |
| 242 | + ) -> List["multi_objective.trial.FrozenMultiObjectiveTrial"]: |
| 243 | + elite_population = [] # type: List[multi_objective.trial.FrozenMultiObjectiveTrial] |
| 244 | + population_per_rank = _fast_non_dominated_sort(population, study.directions) |
| 245 | + for population in population_per_rank: |
| 246 | + if len(elite_population) + len(population) < self._population_size: |
| 247 | + elite_population.extend(population) |
| 248 | + else: |
| 249 | + n = self._population_size - len(elite_population) |
| 250 | + _crowding_distance_sort(population) |
| 251 | + elite_population.extend(population[:n]) |
| 252 | + break |
| 253 | + |
| 254 | + return elite_population |
| 255 | + |
197 | 256 | def _select_parent(
|
198 | 257 | self,
|
199 | 258 | study: "multi_objective.study.MultiObjectiveStudy",
|
|
0 commit comments