-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
210 lines (184 loc) · 9.48 KB
/
main.py
File metadata and controls
210 lines (184 loc) · 9.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import numpy as np
from torch import nn, optim
from torch.utils.tensorboard import SummaryWriter
from learn2learn.algorithms import MAML
from learn2learn.vision.models import CNN4
from tqdm import tqdm
import datetime
from utils import *
from dataset_loaders import *
from encoders import *
from partition_generators import generate_unsupervised_partitions
from task_generator import TaskGenerator
from baselines.pretraining_baseline import contrastive_pretrain, test_pretrain
from baselines.metagmvae_baseline import metagmvae_train, metagmvae_test
from analyze_results.compute_partition_overlap import compute_partition_overlap, compute_partition_overlap_to_metatest
def fast_adapt(batch, inner_learner, loss_fn, num_adaptation_steps, meta_split, args):
if meta_split == "meta_train":
K, K_te = args.KShot, args.KQuery
else:
K, K_te = args.KShotTest, args.KQueryTest
train_data, train_labels, _, test_data, test_labels, _ = batch
assert train_data.size(0) == K * args.NWay, f"{train_data.size(0)} VS {K * args.NWay}"
assert test_data.size(0) == K_te * args.NWay, f"{test_data.size(0)} VS {K_te * args.NWay}"
assert torch.numel(torch.unique(train_labels)) == torch.numel(torch.unique(test_labels)) == args.NWay
train_data, train_labels, test_data, test_labels = \
train_data.to(DEVICE), train_labels.to(DEVICE), test_data.to(DEVICE), test_labels.to(DEVICE)
# inner training (no early stopping)
for step in range(num_adaptation_steps):
train_loss = loss_fn(inner_learner(train_data), train_labels)
inner_learner.adapt(train_loss)
# inner testing
test_preds = inner_learner(test_data)
test_loss, test_accur = loss_fn(test_preds, test_labels), \
accuracy_fn(test_preds, test_labels)
return test_loss, test_accur
def train(meta_model, task_generator, optimizer, loss_fn, descriptor, args):
for i in tqdm(range(METATRAIN_OUTER_EPISODES), desc='Training Epochs'):
optimizer.zero_grad()
meta_train_loss, meta_train_accur = 0.0, 0.0
for _ in range(NUM_TASKS_METATRAIN):
# start a inner-loop learner from the current initialization parameters
inner_learner = meta_model.clone()
task_batch = task_generator.sample_task("meta_train", args)
# inner training
inner_test_loss, inner_test_accur = fast_adapt(task_batch,
inner_learner,
loss_fn,
METATRAIN_INNER_UPDATES,
"meta_train",
args)
# meta training update step
inner_test_loss.backward()
meta_train_loss += inner_test_loss.item()
meta_train_accur += inner_test_accur.item()
meta_train_loss /= NUM_TASKS_METATRAIN
meta_train_accur /= NUM_TASKS_METATRAIN
# update the meta-parameters (the initialization parameters for MAML)
for p in meta_model.parameters():
p.grad.data.mul_(1.0/NUM_TASKS_METATRAIN)
optimizer.step()
print(f"[{descriptor}] Training function completed!")
return meta_model
def test(meta_model, task_generator, loss_fn, args):
meta_test_losses, meta_test_accurs = [], []
for _ in tqdm(range(NUM_TASKS_METATEST), desc='Testing tasks'):
inner_learner = meta_model.clone()
task_batch = task_generator.sample_task("meta_test", args)
inner_test_loss, inner_test_accur = fast_adapt(task_batch,
inner_learner,
loss_fn,
METATEST_INNER_UPDATES,
"meta_test",
args)
meta_test_losses.append(inner_test_loss.item())
meta_test_accurs.append(inner_test_accur.item())
return meta_test_accurs
if __name__ == "__main__":
parser = get_args_parser()
args = parser.parse_args()
fix_seed(args.seed)
# Load meta-train datasets
(
meta_train_set,
meta_test_set,
meta_train_partitions_supervised,
meta_train_partitions_supervised_all,
meta_train_partitions_supervised_oracle,
meta_test_partitions
) = LOAD_DATASET[args.dsName](args)
if args.dsNameTest is not None:
assert args.dsNameTest != args.dsName
print(f"Using different dataset for meta-testing: {args.dsNameTest}")
(
_,
meta_test_set,
_,
_,
_,
meta_test_partitions
) = LOAD_DATASET[args.dsNameTest](args)
encoder = get_encoder(args, DEVICE)
descriptor = get_descriptor(encoder, args)
print(f"<<<<<<<<<<Running {descriptor} on {args.dsName}...>>>>>>>>>")
# Supervised meta-learning
if args.encoder in ['sup', 'scratch']:
meta_train_partitions = meta_train_partitions_supervised
elif args.encoder == 'supall':
meta_train_partitions = meta_train_partitions_supervised_all
elif args.encoder == 'supora':
meta_train_partitions = meta_train_partitions_supervised_oracle
elif args.encoder in ['simclrpretrain', 'metagmvae']:
# not using metatraining
# don't let it go through unsupervised partitions generation, as it takes quite a bit of time
meta_train_partitions = meta_train_partitions_supervised
# Unsupervised meta-learning
else:
meta_train_partitions_unsupervised = generate_unsupervised_partitions(
meta_train_set,
encoder,
descriptor,
args)
meta_train_partitions = meta_train_partitions_unsupervised
assert meta_train_partitions
task_generator = TaskGenerator(meta_train_set,
meta_test_set,
meta_train_partitions,
meta_test_partitions,
args)
if args.visualizeTasks:
assert args.encoder not in ["simclrpretrain", "metagmvae"]
visualize_constructed_tasks(task_generator, descriptor, args, n_imgs=50)
exit(0)
elif args.computePartitionOverlap:
compute_partition_overlap(meta_train_partitions, descriptor, args)
exit(0)
else:
pass
base_model = CNN4(output_size=args.NWay,
hidden_size=32,
layers=4).to(DEVICE)
meta_model = MAML(model=base_model, lr=METATRAIN_INNER_LR, first_order=False)
opt = optim.Adam(meta_model.parameters(), METATRAIN_OUTER_LR)
loss_fn = nn.CrossEntropyLoss(reduction='mean')
if args.encoder != "scratch":
model_path = os.path.join(MODELDIR, f"seed_{args.seed}", f"{descriptor}.ckpt")
os.makedirs(os.path.dirname(model_path), exist_ok=True)
try:
if args.encoder in ["simclrpretrain", "metagmvae"]:
encoder.load_state_dict(torch.load(model_path))
else:
meta_model.load_state_dict(torch.load(model_path))
print(f"[{descriptor}]: Loaded model from {model_path}!")
except FileNotFoundError:
print(f"[{descriptor}]: No model at {model_path}. Training from scratch...")
if args.encoder == "simclrpretrain":
opt = optim.Adam(encoder.parameters(), lr=PRETRAIN_LR)
encoder = contrastive_pretrain(encoder, opt, meta_train_set, descriptor, args)
torch.save(encoder.state_dict(), model_path)
elif args.encoder == "metagmvae":
opt = optim.Adam(encoder.parameters(), lr=GMVAE_METATRAIN_LR)
encoder = metagmvae_train(encoder, opt, meta_train_set, descriptor, args)
torch.save(encoder.state_dict(), model_path)
else:
meta_model = train(meta_model,
task_generator,
opt,
loss_fn,
descriptor,
args)
torch.save(meta_model.state_dict(), model_path)
print(f"Model saved at {model_path}!")
if args.encoder == "simclrpretrain":
meta_test_accurs = test_pretrain(encoder, task_generator, descriptor, args)
elif args.encoder == "metagmvae":
meta_test_accurs = metagmvae_test(encoder, task_generator, loss_fn, descriptor, args)
else:
meta_test_accurs = test(meta_model, task_generator, loss_fn, args)
with open("res.txt", "a") as f:
f.write(str(datetime.datetime.now())+f' under seed {args.seed}'+'\n')
f.write(f"[{descriptor} trained on {args.dsName} {args.NWay}-way {args.KShot}-shot {args.KQuery}-query meteTrain \n" + \
f"on {args.dsNameTest if args.dsNameTest else args.dsName} {args.NWay}-way {args.KShotTest}-shot {args.KQueryTest}-query metaTest]: \n" + \
f"Mean meta test accuracy: {np.mean(meta_test_accurs)*100:.2f}%\n")
print(f"[{descriptor} on {args.dsName}] testing completed!")
print("<<<<<<<<<<<<<<<Main script finished successfully!>>>>>>>>>>>>")