# @package _global_
# A config template for running Weasel on your own data and end-model
# ??? means that the field must be defined.
seed: 7
ignore_warnings: True
datamodule:
_target_: ???
  # The following three params are ignored by the datamodule itself, but they are semantically tied to the dataset,
  # which is why they are defined here. In practice only Weasel needs to know them; see the Weasel params below,
  # where they are all automatically inferred from these datamodule params via interpolation.
num_LFs: ???
n_classes: ???
class_balance: null # will assume uniform classes if null
# Actual params used by DataModule
batch_size: 64
val_test_split: [250, -1]
num_workers: 0
pin_memory: True
seed: 3
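  # A minimal sketch of a filled-in datamodule block, for illustration only
  # (all names and values below are hypothetical placeholders for your own project, not part of Weasel):
  # datamodule:
  #   _target_: my_project.data.MyDataModule
  #   num_LFs: 10        # number of labeling functions you apply to the data
  #   n_classes: 2
  #   class_balance: null
  #   batch_size: 64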
trainer:
_target_: pytorch_lightning.Trainer
gpus: 0 # set to 0 to train on CPU, and >= 1 for GPU(s)
min_epochs: 1
max_epochs: 100
# resume_from_checkpoint: <ckpt_dir>/best.ckpt
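  # Note (assumption about your environment): on PyTorch Lightning 2.x the `gpus` argument was removed;
  # there you would use something like the following instead:
  # accelerator: "gpu"
  # devices: 1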
end_model:
_target_: ???
# What comes here really depends on the end-model!
# dropout: 0.3
# net_norm: "none"
# activation_func: "GeLU"
# input_dim: 300
# hidden_dims: [50, 50, 25]
output_dim: ${datamodule.n_classes}
# adjust_thresh: True
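  # A hypothetical filled-in end_model block, just to illustrate the shape
  # (module path and hyperparameters are placeholders, not part of Weasel's API):
  # end_model:
  #   _target_: my_project.models.MyEndModel
  #   input_dim: 300
  #   hidden_dims: [50, 50, 25]
  #   dropout: 0.3
  #   output_dim: ${datamodule.n_classes}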
Weasel:
_target_: weasel.models.weasel.Weasel
  # Convenient Hydra/OmegaConf interpolation that infers these params from the datamodule/dataset details above
num_LFs: ${datamodule.num_LFs}
n_classes: ${datamodule.n_classes}
class_balance: ${datamodule.class_balance}
monitor: "Val/auc" # adjust for multi-class case
loss_function: "cross_entropy"
temperature: 2.0
accuracy_scaler: "sqrt"
use_aux_input_for_encoder: False
class_conditional_accuracies: True
encoder:
_target_: weasel.models.encoder_models.encoder_MLP.MLPEncoder
# example_extra_input: [1, 300] # OR (1, ${end_model.input_dim})
# |--> Commented out, since it is better to specify the 'example_input_array' attribute in your end-model
dropout: 0.3
net_norm: "batch_norm"
activation_func: "GeLU"
    hidden_dims: [70, 70] # although you probably want to try different hidden-dim configurations!
optim_end_model:
_target_: torch.optim.Adam
lr: 1e-4
weight_decay: 7e-7
optim_encoder:
_target_: torch.optim.Adam
lr: 1e-4
weight_decay: 7e-7
scheduler:
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
mode: "max"
factor: 0.5
patience: 10
verbose: True
callbacks:
model_checkpoint:
_target_: pytorch_lightning.callbacks.ModelCheckpoint
monitor: ${Weasel.monitor}
save_top_k: 2
save_last: True
mode: "max"
dirpath: "checkpoints/"
filename: "{epoch:02d}"
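  # You could register additional Lightning callbacks next to model_checkpoint, e.g. early stopping
  # (a sketch; the patience value is an arbitrary assumption, tune it to your setup):
  # early_stopping:
  #   _target_: pytorch_lightning.callbacks.EarlyStopping
  #   monitor: ${Weasel.monitor}
  #   patience: 20
  #   mode: "max"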
logger:
wandb:
_target_: pytorch_lightning.loggers.wandb.WandbLogger
    # entity: "" # optionally set to the name of your wandb team
name: null
tags: []
notes: "..."
project: "Weasel" # give it a better name ;)
group: ""
    mode: online # set to "disabled" for no wandb logging
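
# Example usage with Hydra command-line overrides (a sketch; the entry-point name `train.py` and the
# `my_project.*` import paths are assumptions about your own code, not part of Weasel):
#   python train.py datamodule._target_=my_project.data.MyDataModule \
#     datamodule.num_LFs=10 datamodule.n_classes=2 \
#     end_model._target_=my_project.models.MyEndModel \
#     trainer.max_epochs=50 logger.wandb.mode=disabled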