hp_ranges.py
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import Tuple, List, Iterable, Dict, Optional
import numpy as np
from numpy.random import RandomState
from syne_tune.config_space import (
    non_constant_hyperparameter_keys,
    is_log_space,
    config_to_match_string,
    is_reverse_log_space,
)
from syne_tune.optimizer.schedulers.searchers.utils.common import (
    Hyperparameter,
    Configuration,
)

__all__ = ["HyperparameterRanges"]


def _filter_constant_hyperparameters(config_space: Dict) -> Dict:
    nonconst_keys = set(non_constant_hyperparameter_keys(config_space))
    return {k: v for k, v in config_space.items() if k in nonconst_keys}
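

# A minimal usage sketch of the filtering above, kept as a comment so the
# module stays importable. The config space is an illustrative assumption;
# ``uniform`` is a ``syne_tune.config_space`` domain, while the plain integer
# is a constant attribute and hence filtered out:
#
#     from syne_tune.config_space import uniform
#
#     config_space = {"lr": uniform(1e-4, 1e-1), "epochs": 10}
#     _filter_constant_hyperparameters(config_space)
#     # -> {"lr": uniform(1e-4, 1e-1)}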


class HyperparameterRanges:
    def __init__(
        self,
        config_space: Dict,
        name_last_pos: Optional[str] = None,
        value_for_last_pos=None,
        active_config_space: Optional[Dict] = None,
        prefix_keys: Optional[List[str]] = None,
    ):
        """
        If `name_last_pos` is given, the hyperparameter of that name is
        assigned the final position in the vector returned by `to_ndarray`.
        This can be used to single out the (time) resource for a GP model,
        where that component has to come last.

        If, in this case (`name_last_pos` given), `value_for_last_pos` is
        also given, some methods are modified:

        - `random_config` samples a config as normal, but then overwrites
          the `name_last_pos` component by `value_for_last_pos`
        - `get_ndarray_bounds` works as normal, but returns bound `(a, a)`
          for the `name_last_pos` component, where `a` is the internal value
          corresponding to `value_for_last_pos`

        The use case is HPO with a resource attribute. This attribute should
        be fixed when optimizing the acquisition function, but can take
        different values in the evaluation data (coming from all previous
        searches).

        If `active_config_space` is given, it contains a subset of the
        non-constant hyperparameters in `config_space`, and the range of
        each entry is a subset of the range of the corresponding
        `config_space` entry. These active ranges affect the choice of new
        configs (by sampling). While the internal encoding is based on the
        original ranges, search is restricted to the active ranges (e.g.,
        optimization of the surrogate model).

        :param config_space: Configuration space. Constant hyperparameters
            are filtered out here
        :param name_last_pos: See above
        :param value_for_last_pos: See above
        :param active_config_space: See above
        :param prefix_keys: If given, these keys into `config_space` come
            first in the internal ordering, which determines the internal
            encoding
        """
        self.config_space = _filter_constant_hyperparameters(config_space)
        self.name_last_pos = name_last_pos
        self.value_for_last_pos = value_for_last_pos
        self._set_internal_keys(prefix_keys)
        self._set_active_config_space(active_config_space)
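
    # A hedged sketch of the resource-attribute use case described in the
    # docstring above, kept as a comment. ``MyRanges`` stands in for some
    # concrete subclass (an assumption, not defined in this file); the config
    # space uses ``loguniform`` / ``randint`` from ``syne_tune.config_space``:
    #
    #     hp_ranges = MyRanges(
    #         config_space={"lr": loguniform(1e-5, 1e-1), "epoch": randint(1, 81)},
    #         name_last_pos="epoch",
    #         value_for_last_pos=81,
    #     )
    #     config = hp_ranges.random_config(RandomState(0))
    #     assert config["epoch"] == 81  # overwritten by value_for_last_pos
    #     # get_ndarray_bounds() returns a degenerate bound for the last
    #     # (encoded) dimension, pinning it to the value encoding 81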

    def _set_internal_keys(self, prefix_keys: Optional[List[str]]):
        keys = sorted(self.config_space.keys())
        if prefix_keys is not None:
            pk_set = set(prefix_keys)
            assert pk_set.issubset(
                set(keys)
            ), f"prefix_keys = {prefix_keys} is not a subset of {keys}"
            keys = prefix_keys + [key for key in keys if key not in pk_set]
        if self.name_last_pos is not None:
            assert self.name_last_pos in keys, (
                f"name_last_pos = '{self.name_last_pos}' not among "
                + f"hyperparameter names [{keys}]"
            )
            pos = keys.index(self.name_last_pos)
            keys = keys[:pos] + keys[(pos + 1) :] + [self.name_last_pos]
        self._internal_keys = keys
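
    # Ordering example (hypothetical names): with hyperparameters
    # {"a", "b", "c", "epoch"}, prefix_keys=["c"] and name_last_pos="epoch",
    # the internal key order becomes ["c", "a", "b", "epoch"]: prefix keys
    # first, then the remaining keys in sorted order, with the name_last_pos
    # key moved to the end.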

    def _set_active_config_space(self, active_config_space: Optional[Dict]):
        if active_config_space is None:
            self.active_config_space = dict()
            self._config_space_for_sampling = self.config_space
        else:
            self._assert_sub_config_space(active_config_space)
            self.active_config_space = active_config_space
            self._config_space_for_sampling = dict(
                self.config_space, **active_config_space
            )

    def _assert_sub_config_space(self, active_config_space: Dict):
        for k, v in active_config_space.items():
            assert (
                k in self.config_space
            ), f"active_config_space[{k}] not in config_space"
            same_value_type = v.value_type == self.config_space[k].value_type
            same_log_type = is_log_space(v) == is_log_space(
                self.config_space[k]
            ) and is_reverse_log_space(v) == is_reverse_log_space(self.config_space[k])
            same_domain_type = isinstance(v, type(self.config_space[k]))
            assert (
                same_value_type and same_log_type and same_domain_type
            ), f"active_config_space[{k}] has different type"

    @property
    def internal_keys(self) -> List[str]:
        return self._internal_keys

    @property
    def config_space_for_sampling(self) -> Dict:
        return self._config_space_for_sampling

    def to_ndarray(self, config: Configuration) -> np.ndarray:
        """
        Maps a configuration to its encoded ndarray representation.
        Categorical values are one-hot encoded.

        :param config: Config to encode
        :return: Encoded HP vector
        """
        raise NotImplementedError()

    def to_ndarray_matrix(self, configs: Iterable[Configuration]) -> np.ndarray:
        return np.vstack([self.to_ndarray(config) for config in configs])
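
    # Encoding sketch (hypothetical config space; the exact encoding is
    # defined by a concrete subclass): for
    # {"lr": uniform(1e-4, 1e-1), "opt": choice(["sgd", "adam", "rmsprop"])},
    # the categorical "opt" contributes a block of 0/1 one-hot columns, the
    # numerical "lr" one column, and `ndarray_size` / `encoded_ranges` report
    # the resulting layout. `to_ndarray_matrix` stacks one such row per config.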

    @property
    def ndarray_size(self) -> int:
        """
        :return: Dimensionality of encoded HP vector returned by `to_ndarray`
        """
        raise NotImplementedError()

    def from_ndarray(self, enc_config: np.ndarray) -> Configuration:
        """
        Converts a config from its internal ndarray representation (fed to
        the GP) to its external (dict) representation. This typically
        involves rounding.
        """
        raise NotImplementedError()

    @property
    def encoded_ranges(self) -> Dict[str, Tuple[int, int]]:
        """
        :return: Ranges of HPs in the encoded ndarray representation
        """
        raise NotImplementedError()

    def is_attribute_fixed(self) -> bool:
        return (self.name_last_pos is not None) and (
            self.value_for_last_pos is not None
        )

    def _fix_attribute_value(self, name: str) -> bool:
        return self.is_attribute_fixed() and name == self.name_last_pos

    def _transform_config(self, config: Configuration) -> Configuration:
        if self.is_attribute_fixed():
            config[self.name_last_pos] = self.value_for_last_pos
        return config

    def _random_config(self, random_state: RandomState) -> Configuration:
        return {
            k: v.sample(random_state=random_state)
            for k, v in self._config_space_for_sampling.items()
        }

    def random_config(self, random_state: RandomState) -> Configuration:
        return self._transform_config(self._random_config(random_state))

    def _random_configs(
        self, random_state: RandomState, num_configs: int
    ) -> List[Configuration]:
        return [self._random_config(random_state) for _ in range(num_configs)]

    def random_configs(
        self, random_state: RandomState, num_configs: int
    ) -> List[Configuration]:
        return [
            self._transform_config(config)
            for config in self._random_configs(random_state, num_configs)
        ]

    def get_ndarray_bounds(self) -> List[Tuple[float, float]]:
        """
        Returns `(lower, upper)` bounds for each dimension in the ndarray
        vector representation.
        """
        raise NotImplementedError()

    def __repr__(self) -> str:
        raise NotImplementedError()

    def __eq__(self, other: object) -> bool:
        raise NotImplementedError()

    def __len__(self) -> int:
        return len(self.config_space)

    def filter_for_last_pos_value(
        self, configs: List[Configuration]
    ) -> List[Configuration]:
        """
        If `is_attribute_fixed`, `configs` is filtered by removing entries
        whose `name_last_pos` attribute value is different from
        `value_for_last_pos`. Otherwise, it is returned unchanged.
        """
        if self.is_attribute_fixed():
            configs = [
                config
                for config in configs
                if config[self.name_last_pos] == self.value_for_last_pos
            ]
        return configs

    def config_to_tuple(
        self, config: Configuration, keys=None, skip_last: bool = False
    ) -> Tuple[Hyperparameter, ...]:
        """
        :param config: Configuration
        :param keys: Overrides `_internal_keys`
        :param skip_last: If True and `name_last_pos` is used, the
            corresponding attribute is skipped, so that config and tuple
            are non-extended
        :return: Tuple representation
        """
        if keys is None:
            keys = self.internal_keys
        if skip_last and self.name_last_pos is not None:
            keys = keys[:-1]  # Skip last pos
        return tuple(config[k] for k in keys)

    def tuple_to_config(
        self, config_tpl: Tuple[Hyperparameter, ...], keys=None, skip_last: bool = False
    ) -> Configuration:
        if keys is None:
            keys = self.internal_keys
        if skip_last and self.name_last_pos is not None:
            keys = keys[:-1]  # Skip last pos
        return dict(zip(keys, config_tpl))
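
    # Round-trip sketch (hypothetical values): with
    # internal_keys == ["batch_size", "lr"],
    #
    #     config_to_tuple({"lr": 0.01, "batch_size": 32})  # -> (32, 0.01)
    #     tuple_to_config((32, 0.01))  # -> {"batch_size": 32, "lr": 0.01}
    #
    # i.e. the tuple follows the internal key ordering, and the two methods
    # invert each other for the same `keys` / `skip_last` arguments.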

    def config_to_match_string(
        self, config: Configuration, keys=None, skip_last: bool = False
    ) -> str:
        """
        Maps a configuration to a match string, used to compare for
        approximate equality. Two configurations are considered to be
        different if their match strings are not the same.

        :param config: Configuration
        :param keys: Overrides `_internal_keys`
        :param skip_last: If True and `name_last_pos` is used, the
            corresponding attribute is skipped, so that config and tuple
            are non-extended
        :return: Match string
        """
        if keys is None:
            keys = self.internal_keys
        if skip_last and self.name_last_pos is not None:
            keys = keys[:-1]  # Skip last pos
        return config_to_match_string(config, self.config_space, keys)
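

# A hedged end-to-end usage sketch, kept as a comment so that importing this
# module has no side effects. It assumes a concrete implementation is obtained
# via ``make_hyperparameter_ranges`` from
# ``syne_tune.optimizer.schedulers.searchers.utils`` (that factory name is an
# assumption, not something defined in this file):
#
#     from numpy.random import RandomState
#     from syne_tune.config_space import uniform, choice
#
#     config_space = {"lr": uniform(1e-4, 1e-1), "opt": choice(["sgd", "adam"])}
#     hp_ranges = make_hyperparameter_ranges(config_space)
#     config = hp_ranges.random_config(RandomState(0))
#     x = hp_ranges.to_ndarray(config)   # encoded vector (categoricals one-hot)
#     back = hp_ranges.from_ndarray(x)   # back to a dict, up to rounding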