# sweep_vit.yaml
# What is being swept: the script that is launched for every run.
program: train.py
# Optional: how the program is invoked for each run.
# Every flag listed before ${args} is a fixed default shared by all runs.
command:
  - ${env}
  - ${interpreter}
  - ${program}
  - '--wandb'
  - '--log_every_n_steps'
  - '50'
  - '--max_epochs_ss'
  - '5'
  - '--max_epochs_noisy'
  - '5'
  - '--max_epochs_clean'
  - '1000'
  - '--val_check_interval'
  - '1.0'
  - '--gradient_clip_val'
  - '0.5'
  - '--precision'
  - '16'
  - '--run_noisy'
  - '--run_ss'
  - '--run_clean'
  - ${args}  # arguments generated from the sweep parameters defined below
# The rest of the file says which parameters to sweep over, what is being
# optimized, and how the search is carried out.
# Search strategy: random search generally performs well; "grid" and "bayes"
# are the other accepted values.
method: random
# Objective of the sweep: the metric to track and the direction to push it.
# `name` and `goal` must be nested under `metric`; at column 0 they would be
# parsed as unrelated top-level keys and `metric` itself would be null.
# NOTE(review): `final_target` has to match a metric name logged by the
# training script — confirm against train.py.
metric:
  name: final_target
  goal: maximize
# Search space. Each entry is nested as parameters -> <name> -> values/value:
#   values: [...]  a list of candidates the sweep chooses from
#   value: x       a single fixed setting, passed unchanged to every run
# The nesting is required: without it every `values:` key collides at the top
# level and the parameter names carry no settings at all.
parameters:
  ###########################
  # pipeline hyperparameters
  ###########################
  ls_modifier:
    values: [0.0, 0.1, 0.2]
  ###########################
  # litmodel hyperparameters
  ###########################
  # suggested value: 0.0006 (original note; presumably for lr_ss — confirm)
  lr_ss:
    values: [0.00001]
  lr:
    values: [0.001]
  lr_ft:
    values: [0.0001]
  ###########################
  # MAE hyperparameters
  ###########################
  decoder_dim:
    values: [128]
  decoder_dim_head:
    values: [32]
  decoder_depth:
    values: [1]
  decoder_heads:
    values: [4]
  masking_ratio:
    values: [0.95]
  ###########################
  # ViT hyperparameters
  ###########################
  dim:
    values: [128]
  dim_head:
    values: [16, 32]
  patch_len:
    values: [5]
  mlp_dim:
    values: [16, 32]
  dropout:
    values: [0, 0.2]
  emb_dropout:
    values: [0]
  # Some settings can also be fixed to one value, just like the default
  # arguments set in `command`.
  # gpus:
  #   value: 1
  model_class:
    value: vit.simpleVIT
  data_class:
    value: xarray_module.XarrayDataModule
  ft_schedule:
    value: hyperiap/litmodels/LitClassifier_vit_ft_schedule.yaml