Skip to content

Commit bd0d1c6

Browse files
committed
🔧 Update CIFAR10 experiments (script + config files)
1 parent aa27f6c commit bd0d1c6

25 files changed

+553
-354
lines changed

experiments/classification/cifar10/configs/resnet.yaml

Lines changed: 0 additions & 30 deletions
This file was deleted.

experiments/classification/cifar10/configs/resnet18/batched.yaml

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# lightning.pytorch==2.1.3
21
seed_everything: false
32
eval_after_fit: true
43
trainer:
@@ -23,22 +22,34 @@ trainer:
2322
patience: 1000
2423
check_finite: true
2524
model:
25+
model:
26+
class_path: torch_uncertainty.models.classification.batched_resnet
27+
init_args:
28+
in_channels: 3
29+
num_classes: 10
30+
arch: 18
31+
num_estimators: 4
32+
style: cifar
2633
num_classes: 10
27-
in_channels: 3
2834
loss: CrossEntropyLoss
29-
version: batched
30-
arch: 18
31-
style: cifar
32-
num_estimators: 4
35+
is_ensemble: true
36+
format_batch_fn:
37+
class_path: torch_uncertainty.transforms.RepeatTarget
38+
init_args:
39+
num_repeats: 4
3340
data:
3441
root: ./data
3542
batch_size: 128
3643
optimizer:
37-
lr: 0.05
38-
momentum: 0.9
39-
weight_decay: 5e-4
44+
class_path: torch.optim.SGD
45+
init_args:
46+
lr: 0.05
47+
momentum: 0.9
48+
weight_decay: 5e-4
4049
lr_scheduler:
41-
milestones:
42-
- 25
43-
- 50
44-
gamma: 0.1
50+
class_path: torch.optim.lr_scheduler.MultiStepLR
51+
init_args:
52+
milestones:
53+
- 25
54+
- 50
55+
gamma: 0.1
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# lightning.pytorch==2.1.3
2+
seed_everything: false
3+
eval_after_fit: true
4+
trainer:
5+
accelerator: gpu
6+
devices: 1
7+
precision: 16-mixed
8+
max_epochs: 75
9+
logger:
10+
class_path: lightning.pytorch.loggers.TensorBoardLogger
11+
init_args:
12+
save_dir: logs/resnet18
13+
name: deep_ensembles
14+
default_hp_metric: false
15+
callbacks:
16+
- class_path: torch_uncertainty.callbacks.TUClsCheckpoint
17+
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
18+
init_args:
19+
logging_interval: step
20+
- class_path: lightning.pytorch.callbacks.EarlyStopping
21+
init_args:
22+
monitor: val/cls/Acc
23+
patience: 1000
24+
check_finite: true
25+
model:
26+
model:
27+
class_path: torch_uncertainty.models.deep_ensembles
28+
init_args:
29+
core_models:
30+
class_path: torch_uncertainty.models.classification.resnet
31+
init_args:
32+
in_channels: 3
33+
num_classes: 10
34+
arch: 18
35+
style: cifar
36+
num_estimators: 4
37+
task: classification
38+
# eventually you can pass the checkpoints of standard resnet18 models here
39+
# ckpt_paths: [path/to/ckpt1, path/to/ckpt2, path/to/ckpt3, path/to/ckpt4]
40+
num_classes: 10
41+
loss: CrossEntropyLoss
42+
is_ensemble: true
43+
format_batch_fn:
44+
class_path: torch_uncertainty.transforms.RepeatTarget
45+
init_args:
46+
num_repeats: 4
47+
data:
48+
root: ./data
49+
batch_size: 128
50+
optimizer:
51+
class_path: torch.optim.SGD
52+
init_args:
53+
lr: 0.2 # initial learning rate times 4 (num_estimators)
54+
momentum: 0.9
55+
weight_decay: 5e-4
56+
lr_scheduler:
57+
class_path: torch.optim.lr_scheduler.MultiStepLR
58+
init_args:
59+
milestones:
60+
- 25
61+
- 50
62+
gamma: 0.1

experiments/classification/cifar10/configs/resnet18/masked.yaml

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,34 @@ trainer:
2323
patience: 1000
2424
check_finite: true
2525
model:
26+
model:
27+
class_path: torch_uncertainty.models.classification.masked_resnet
28+
init_args:
29+
in_channels: 3
30+
num_classes: 10
31+
arch: 18
32+
num_estimators: 4
33+
scale: 2
34+
style: cifar
2635
num_classes: 10
27-
in_channels: 3
2836
loss: CrossEntropyLoss
29-
version: masked
30-
arch: 18
31-
style: cifar
32-
num_estimators: 4
33-
scale: 2
37+
format_batch_fn:
38+
class_path: torch_uncertainty.transforms.RepeatTarget
39+
init_args:
40+
num_repeats: 4
3441
data:
3542
root: ./data
3643
batch_size: 128
3744
optimizer:
38-
lr: 0.05
39-
momentum: 0.9
40-
weight_decay: 5e-4
45+
class_path: torch.optim.SGD
46+
init_args:
47+
lr: 0.05
48+
momentum: 0.9
49+
weight_decay: 5e-4
4150
lr_scheduler:
42-
milestones:
43-
- 25
44-
- 50
45-
gamma: 0.1
51+
class_path: torch.optim.lr_scheduler.MultiStepLR
52+
init_args:
53+
milestones:
54+
- 25
55+
- 50
56+
gamma: 0.1

experiments/classification/cifar10/configs/resnet18/mimo.yaml

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,36 @@ trainer:
2323
patience: 1000
2424
check_finite: true
2525
model:
26+
model:
27+
class_path: torch_uncertainty.models.classification.mimo_resnet
28+
init_args:
29+
in_channels: 3
30+
num_classes: 10
31+
arch: 18
32+
num_estimators: 4
33+
style: cifar
2634
num_classes: 10
27-
in_channels: 3
2835
loss: CrossEntropyLoss
29-
version: mimo
30-
arch: 18
31-
style: cifar
32-
num_estimators: 4
33-
rho: 1.0
36+
is_ensemble: true
37+
format_batch_fn:
38+
class_path: torch_uncertainty.transforms.MIMOBatchFormat
39+
init_args:
40+
num_estimators: 4
41+
rho: 1.0
42+
batch_repeat: 1
3443
data:
3544
root: ./data
3645
batch_size: 128
3746
optimizer:
38-
lr: 0.05
39-
momentum: 0.9
40-
weight_decay: 5e-4
47+
class_path: torch.optim.SGD
48+
init_args:
49+
lr: 0.05
50+
momentum: 0.9
51+
weight_decay: 5e-4
4152
lr_scheduler:
42-
milestones:
43-
- 25
44-
- 50
45-
gamma: 0.1
53+
class_path: torch.optim.lr_scheduler.MultiStepLR
54+
init_args:
55+
milestones:
56+
- 25
57+
- 50
58+
gamma: 0.1

experiments/classification/cifar10/configs/resnet18/packed.yaml

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,36 @@ trainer:
2323
patience: 1000
2424
check_finite: true
2525
model:
26+
model:
27+
class_path: torch_uncertainty.models.classification.packed_resnet
28+
init_args:
29+
in_channels: 3
30+
num_classes: 10
31+
arch: 18
32+
style: cifar
33+
num_estimators: 4
34+
alpha: 2
35+
gamma: 2
2636
num_classes: 10
27-
in_channels: 3
2837
loss: CrossEntropyLoss
29-
version: packed
30-
arch: 18
31-
style: cifar
32-
num_estimators: 4
33-
alpha: 2
34-
gamma: 2
38+
is_ensemble: true
39+
format_batch_fn:
40+
class_path: torch_uncertainty.transforms.RepeatTarget
41+
init_args:
42+
num_repeats: 4
3543
data:
3644
root: ./data
3745
batch_size: 128
3846
optimizer:
39-
lr: 0.05
40-
momentum: 0.9
41-
weight_decay: 5e-4
47+
class_path: torch.optim.SGD
48+
init_args:
49+
lr: 0.05
50+
momentum: 0.9
51+
weight_decay: 5e-4
4252
lr_scheduler:
43-
milestones:
44-
- 25
45-
- 50
46-
gamma: 0.1
53+
class_path: torch.optim.lr_scheduler.MultiStepLR
54+
init_args:
55+
milestones:
56+
- 25
57+
- 50
58+
gamma: 0.1

experiments/classification/cifar10/configs/resnet18/standard.yaml

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,28 @@ trainer:
2323
patience: 1000
2424
check_finite: true
2525
model:
26+
model:
27+
class_path: torch_uncertainty.models.classification.resnet
28+
init_args:
29+
in_channels: 3
30+
num_classes: 10
31+
arch: 18
32+
style: cifar
2633
num_classes: 10
27-
in_channels: 3
2834
loss: CrossEntropyLoss
29-
version: std
30-
arch: 18
31-
style: cifar
3235
data:
3336
root: ./data
3437
batch_size: 128
3538
optimizer:
36-
lr: 0.05
37-
momentum: 0.9
38-
weight_decay: 5e-4
39+
class_path: torch.optim.SGD
40+
init_args:
41+
lr: 0.05
42+
momentum: 0.9
43+
weight_decay: 5e-4
3944
lr_scheduler:
40-
milestones:
41-
- 25
42-
- 50
43-
gamma: 0.1
45+
class_path: torch.optim.lr_scheduler.MultiStepLR
46+
init_args:
47+
milestones:
48+
- 25
49+
- 50
50+
gamma: 0.1

experiments/classification/cifar10/configs/resnet50/batched.yaml

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,36 @@ trainer:
2323
patience: 1000
2424
check_finite: true
2525
model:
26+
model:
27+
class_path: torch_uncertainty.models.classification.batched_resnet
28+
init_args:
29+
in_channels: 3
30+
num_classes: 10
31+
arch: 50
32+
num_estimators: 4
33+
style: cifar
2634
num_classes: 10
27-
in_channels: 3
2835
loss: CrossEntropyLoss
29-
version: batched
30-
arch: 50
31-
style: cifar
32-
num_estimators: 4
36+
is_ensemble: true
37+
format_batch_fn:
38+
class_path: torch_uncertainty.transforms.RepeatTarget
39+
init_args:
40+
num_repeats: 4
3341
data:
3442
root: ./data
3543
batch_size: 128
3644
optimizer:
37-
lr: 0.08
38-
momentum: 0.9
39-
weight_decay: 5e-4
40-
nesterov: true
45+
class_path: torch.optim.SGD
46+
init_args:
47+
lr: 0.08
48+
momentum: 0.9
49+
weight_decay: 5e-4
50+
nesterov: true
4151
lr_scheduler:
42-
milestones:
43-
- 60
44-
- 120
45-
- 160
46-
gamma: 0.2
52+
class_path: torch.optim.lr_scheduler.MultiStepLR
53+
init_args:
54+
milestones:
55+
- 60
56+
- 120
57+
- 160
58+
gamma: 0.2

0 commit comments

Comments
 (0)