# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import logging

from syne_tune.backend.trial_status import Trial
from syne_tune.tuner_callback import StoreResultsCallback
from syne_tune.backend.simulator_backend.simulator_callback import SimulatorCallback
from syne_tune.optimizer.schedulers.fifo import FIFOScheduler
from syne_tune.optimizer.schedulers.searchers.gp_fifo_searcher import ModelBasedSearcher
from syne_tune.optimizer.schedulers.searchers.bayesopt.models.model_transformer import (
    ModelStateTransformer,
)

logger = logging.getLogger(__name__)

def _get_model_based_searcher(tuner):
    """
    Returns the scheduler's searcher if it is model-based and uses a
    single-model :class:`ModelStateTransformer`. Otherwise, returns ``None``.
    """
    searcher = None
    scheduler = tuner.scheduler
    if isinstance(scheduler, FIFOScheduler):
        if isinstance(scheduler.searcher, ModelBasedSearcher):
            searcher = scheduler.searcher
            state_transformer = searcher.state_transformer
            if state_transformer is not None:
                if not isinstance(state_transformer, ModelStateTransformer):
                    searcher = None
                elif not state_transformer.use_single_model:
                    logger.warning(
                        "StoreResultsAndModelParamsCallback does not currently "
                        "support multi-model setups. Model parameters are "
                        "not logged."
                    )
                    searcher = None
            else:
                searcher = None
    return searcher

def _extended_result(searcher, result):
    """
    If ``searcher`` is given, returns ``result`` extended by the current
    surrogate model parameters (keys prefixed with ``model_``) and the
    cumulative time spent in ``get_config``.
    """
    if searcher is not None:
        kwargs = dict()
        # Append surrogate model parameters to `result`
        params = searcher.state_transformer.get_params()
        if params:
            prefix = "model_"
            kwargs = {prefix + k: v for k, v in params.items()}
        kwargs["cumulative_get_config_time"] = searcher.cumulative_get_config_time
        result = dict(result, **kwargs)
    return result
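
# Illustrative sketch (not part of the original module): assuming the searcher
# reports a hypothetical surrogate parameter {"kernel_inv_bw0": 1.5} and has
# spent 12.7 seconds in `get_config` so far, a result such as
# {"epoch": 3, "loss": 0.2} would be extended by `_extended_result` to:
# {"epoch": 3, "loss": 0.2, "model_kernel_inv_bw0": 1.5,
#  "cumulative_get_config_time": 12.7}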

class StoreResultsAndModelParamsCallback(StoreResultsCallback):
    """
    Extends :class:`StoreResultsCallback` by also storing the current
    surrogate model parameters in `on_trial_result`. This works for
    schedulers with model-based searchers. For other schedulers, this
    callback behaves the same as the superclass.
    """

    def __init__(
        self,
        add_wallclock_time: bool = True,
    ):
        super().__init__(add_wallclock_time)
        self._searcher = None

    def on_tuning_start(self, tuner):
        super().on_tuning_start(tuner)
        self._searcher = _get_model_based_searcher(tuner)

    def on_trial_result(self, trial: Trial, status: str, result: dict, decision: str):
        result = _extended_result(self._searcher, result)
        super().on_trial_result(trial, status, result, decision)
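
# Usage sketch (assumptions: the standard Syne Tune `Tuner` accepts a
# `callbacks` list, and `backend`, `scheduler`, `stop_criterion` are set up
# elsewhere; import paths and argument names such as `trial_backend` may
# differ across Syne Tune versions):
#
#     from syne_tune.tuner import Tuner
#
#     tuner = Tuner(
#         trial_backend=backend,
#         scheduler=scheduler,
#         stop_criterion=stop_criterion,
#         n_workers=4,
#         callbacks=[StoreResultsAndModelParamsCallback()],
#     )
#     tuner.run()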

class SimulatorAndModelParamsCallback(SimulatorCallback):
    """
    Extends :class:`SimulatorCallback` by also storing the current
    surrogate model parameters in `on_trial_result`. This works for
    schedulers with model-based searchers. For other schedulers, this
    callback behaves the same as the superclass.
    """

    def __init__(self):
        super().__init__()
        self._searcher = None

    def on_tuning_start(self, tuner):
        super().on_tuning_start(tuner)
        self._searcher = _get_model_based_searcher(tuner)

    def on_trial_result(self, trial: Trial, status: str, result: dict, decision: str):
        result = _extended_result(self._searcher, result)
        super().on_trial_result(trial, status, result, decision)
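
# Usage sketch for simulated experiments (assumptions: a simulator backend
# such as `BlackboxRepositoryBackend` is set up elsewhere as
# `simulated_backend`, and `Tuner` accepts `sleep_time` and `callbacks`;
# names may differ across Syne Tune versions). This callback stands in for
# the plain `SimulatorCallback` that simulator runs normally use:
#
#     tuner = Tuner(
#         trial_backend=simulated_backend,
#         scheduler=scheduler,
#         stop_criterion=stop_criterion,
#         n_workers=4,
#         sleep_time=0,  # no real waiting when time is simulated
#         callbacks=[SimulatorAndModelParamsCallback()],
#     )
#     tuner.run()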