forked from syne-tune/syne-tune
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path launch_fashionmnist_costaware.py
72 lines (63 loc) · 2.5 KB
/
launch_fashionmnist_costaware.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""
Example for cost-aware promotion-based Hyperband
"""
import logging
from benchmarking.definitions.definition_mlp_on_fashion_mnist import (
mlp_fashionmnist_default_params,
mlp_fashionmnist_benchmark,
)
from syne_tune.backend import LocalBackend
from syne_tune.optimizer.schedulers import HyperbandScheduler
from syne_tune import Tuner, StoppingCriterion
def main(
    *,
    random_seed: int = 31415927,
    n_workers: int = 4,
    max_wallclock_time: float = 120,
    log_level: int = logging.DEBUG,
) -> "Tuner":
    """Run cost-aware promotion-based Hyperband on the MLP/FashionMNIST benchmark.

    All arguments are keyword-only and default to the values this example
    originally hard-coded, so running the file as a script is unchanged.

    :param random_seed: Seed passed to ``HyperbandScheduler``.
    :param n_workers: Number of workers passed to the ``Tuner``.
    :param max_wallclock_time: Wall-clock budget passed to
        ``StoppingCriterion(max_wallclock_time=...)``.
    :param log_level: Level set on the root logger. ``logging.DEBUG`` is
        verbose; pass ``logging.INFO`` for quieter output.
    :return: The ``Tuner`` after ``run()`` has returned.
    """
    logging.getLogger().setLevel(log_level)

    # We pick the MLP on FashionMNIST benchmark. The ``benchmark`` dict holds
    # arguments needed by scheduler and searcher (e.g. ``mode``, ``metric``),
    # along with suggested default values for other arguments, which you are
    # free to override.
    default_params = mlp_fashionmnist_default_params()
    benchmark = mlp_fashionmnist_benchmark(default_params)
    mode = benchmark["mode"]
    metric = benchmark["metric"]
    # Use the benchmark's default configuration space; replace it here if you
    # want to search a different one.
    config_space = benchmark["config_space"]

    # Local back-end that launches the benchmark's training script.
    trial_backend = LocalBackend(entry_point=benchmark["script"])

    # Cost-aware variant of ASHA, using a random searcher. With
    # ``type="cost_promotion"``, the rung system's ``cost_attr`` is set to the
    # benchmark's elapsed-time attribute.
    scheduler = HyperbandScheduler(
        config_space,
        searcher="random",
        max_t=default_params["max_resource_level"],
        grace_period=default_params["grace_period"],
        reduction_factor=default_params["reduction_factor"],
        resource_attr=benchmark["resource_attr"],
        mode=mode,
        metric=metric,
        type="cost_promotion",
        rung_system_kwargs={"cost_attr": benchmark["elapsed_time_attr"]},
        random_seed=random_seed,
    )

    stop_criterion = StoppingCriterion(max_wallclock_time=max_wallclock_time)
    tuner = Tuner(
        trial_backend=trial_backend,
        scheduler=scheduler,
        stop_criterion=stop_criterion,
        n_workers=n_workers,
    )
    tuner.run()
    return tuner


if __name__ == "__main__":
    main()