# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""
Example for running ASHA with 4 workers with the simulator back-end based on three Yahpo surrogate benchmarks.
"""
import logging
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np

from syne_tune.blackbox_repository import BlackboxRepositoryBackend
from syne_tune.backend.simulator_backend.simulator_callback import SimulatorCallback
from syne_tune.experiments import load_experiment
from syne_tune.optimizer.baselines import ASHA
from syne_tune import Tuner, StoppingCriterion
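

# Helper: sample 10 random configurations from the surrogate blackbox and plot
# metric versus elapsed time, to inspect the learning curves before tuning.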
def plot_yahpo_learning_curves(
    trial_backend, benchmark: str, time_col: str, metric_col: str
):
    bb = trial_backend.blackbox
    plt.figure()
    plt.title(
        f"Learning curves from Yahpo {benchmark} for 10 different hyperparameter configurations."
    )
    for _ in range(10):
        config = {k: v.sample() for k, v in bb.configuration_space.items()}
        evals = bb(config)
        time_index = next(
            i for i, name in enumerate(bb.objectives_names) if name == time_col
        )
        accuracy_index = next(
            i for i, name in enumerate(bb.objectives_names) if name == metric_col
        )
        if np.diff(evals[:, time_index]).min() < 0:
            print("negative time between two consecutive steps...")
        plt.plot(evals[:, time_index], evals[:, accuracy_index])
    plt.xlabel(time_col)
    plt.ylabel(metric_col)
    plt.show()
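

# Metadata that defines one tuning problem on a YAHPO surrogate benchmark.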
@dataclass
class BenchmarkInfo:
    blackbox_name: str
    elapsed_time_attr: str
    metric: str
    dataset: str
    mode: str
    max_t: int
    resource_attr: str
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
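    # The three YAHPO surrogate benchmarks used in this example; each one
    # specifies its own elapsed-time attribute, metric, optimization mode,
    # and maximum resource level (max_t).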
    benchmark_infos = {
        "nb301": BenchmarkInfo(
            elapsed_time_attr="runtime",
            metric="val_accuracy",
            blackbox_name="yahpo-nb301",
            dataset="CIFAR10",
            mode="max",
            max_t=97,
            resource_attr="epoch",
        ),
        "lcbench": BenchmarkInfo(
            elapsed_time_attr="time",
            metric="val_accuracy",
            blackbox_name="yahpo-lcbench",
            dataset="3945",
            mode="max",
            max_t=51,
            resource_attr="epoch",
        ),
        "fcnet": BenchmarkInfo(
            elapsed_time_attr="runtime",
            metric="valid_mse",
            blackbox_name="yahpo-fcnet",
            dataset="fcnet_naval_propulsion",
            mode="min",
            max_t=99,
            resource_attr="epoch",
        ),
    }
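
    # Tune each benchmark in turn, using the simulator backend so the
    # experiment runs in simulated rather than real time.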
    for benchmark in ["nb301", "lcbench", "fcnet"]:
        benchmark_info = benchmark_infos[benchmark]
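
        # Backend that "runs" trials by querying the surrogate blackbox from
        # the blackbox repository instead of training any models.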
        trial_backend = BlackboxRepositoryBackend(
            blackbox_name=benchmark_info.blackbox_name,
            elapsed_time_attr=benchmark_info.elapsed_time_attr,
            dataset=benchmark_info.dataset,
        )
        plot_yahpo_learning_curves(
            trial_backend,
            benchmark=benchmark,
            time_col=benchmark_info.elapsed_time_attr,
            metric_col=benchmark_info.metric,
        )
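
        # ASHA (asynchronous successive halving): promising trials are continued
        # to higher resource levels (epochs here), poor ones are stopped early.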
        scheduler = ASHA(
            config_space=trial_backend.blackbox.configuration_space,
            max_t=benchmark_info.max_t,
            resource_attr=benchmark_info.resource_attr,
            mode=benchmark_info.mode,
            metric=benchmark_info.metric,
        )
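
        # Stop the experiment once 200 trials have been started.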
        stop_criterion = StoppingCriterion(max_num_trials_started=200)
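
        # SimulatorCallback advances time according to the blackbox's
        # elapsed-time attribute, so the experiment completes in a fraction of
        # real time; sleep_time=0 because there is no real training to wait for.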
        tuner = Tuner(
            trial_backend=trial_backend,
            scheduler=scheduler,
            stop_criterion=stop_criterion,
            n_workers=4,
            sleep_time=0,
            print_update_interval=10,
            callbacks=[SimulatorCallback()],
            tuner_name=f"ASHA-Yahpo-{benchmark}",
        )
        tuner.run()
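
        # Load the results logged by the tuner and plot the tuning curve.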
        tuning_experiment = load_experiment(tuner.name)
        tuning_experiment.plot()