-
Notifications
You must be signed in to change notification settings - Fork 4
/
train_obj.py
382 lines (329 loc) · 13.1 KB
/
train_obj.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
import argparse
import os
import logging
import torch
import torch.distributed
import torch.nn as nn
from torch import optim
from source.training_utils import save_checkpoint, save_model, LinearWarmupScheduler, add_gradient_histograms
from source.data.datasets.objs.load_data import load_data
from source.utils import str2bool
from torch.utils.tensorboard import SummaryWriter
import accelerate
from accelerate import Accelerator
# Visualization
import torch.nn.functional as F
from tqdm import tqdm
# for distributed training
from torch.distributed.nn.functional import all_gather
def create_logger(logging_dir):
    """
    Create a logger that writes to ``<logging_dir>/log.txt`` and to the console.

    Configures the root logger with ``logging.basicConfig`` at INFO level using a
    timestamped format. Two handlers are attached:

    - a ``StreamHandler``, which writes to ``sys.stderr`` (its default stream);
    - a ``FileHandler`` on ``log.txt``, opened in append mode.

    Note: ``logging.basicConfig`` does nothing if the root logger already has
    handlers, so only the first call in a process configures output.

    Args:
        logging_dir: directory that must already exist. ``log.txt`` is opened
            inside it, and the directory is not created here.

    Returns:
        The ``logging.Logger`` named after this module (``__name__``).
    """
    log_path = os.path.join(logging_dir, "log.txt")
    logging.basicConfig(
        level=logging.INFO,
        # ANSI escape codes colour the timestamp blue in terminals.
        format="[\033[34m%(asctime)s\033[0m] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(log_path),
        ],
    )
    return logging.getLogger(__name__)
def simclr(zs, temperature=1.0, normalize=True, loss_type="ip"):
    """Compute a SimCLR-style contrastive (cross-entropy) loss over views.

    Args:
        zs: list of ``m`` tensors, each of shape ``(n, d)`` or ``(n, k, d)``.
            Row ``i`` of every tensor is expected to come from the same sample.
            When ``zs[0]`` is 3-D, every tensor is flattened to ``(n, k * d)``.
            Flattening happens after the optional normalization, so each
            length-``d`` vector is normalized separately.
        temperature: divisor applied to the similarity logits.
        normalize: if True, L2-normalize every tensor along its last dimension.
        loss_type: similarity used as logits:
            ``"ip"`` (inner product), ``"euclid"`` (negative Euclidean
            distance) or ``"sq"`` (negative squared Euclidean distance).

    Returns:
        Scalar mean cross-entropy over the ``n * m`` concatenated rows. The
        positive for row ``j`` is row ``(j + n) % (n * m)``, i.e. the same
        sample in the next view, wrapping around. Self-similarity is excluded.

    Raises:
        NotImplementedError: if ``loss_type`` is not one of the values above.
    """
    if normalize:
        zs = [F.normalize(z, p=2, dim=-1) for z in zs]
    if zs[0].dim() == 3:
        zs = [z.flatten(1, 2) for z in zs]

    num_views = len(zs)
    batch = zs[0].shape[0]
    total = batch * num_views
    z = torch.cat(zs, 0)

    # Target for row j is the same sample in the next view.
    labels = torch.remainder(torch.arange(total, device=z.device) + batch, total)

    if loss_type == "euclid":  # euclidean distance
        sim = -torch.cdist(z, z)
    elif loss_type == "sq":  # squared euclidean distance
        sim = -(torch.cdist(z, z) ** 2)
    elif loss_type == "ip":  # inner product
        sim = torch.matmul(z, z.transpose(0, 1))
    else:
        raise NotImplementedError(
            f"Unknown loss_type {loss_type!r}; expected 'ip', 'euclid' or 'sq'"
        )

    logits = sim / temperature
    # Remove each row's similarity with itself from the softmax. The fill value
    # is the dtype's most negative finite number, so it cannot overflow in
    # reduced precision (a -1e8 scalar cannot be represented in float16).
    self_mask = torch.eye(total, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, torch.finfo(logits.dtype).min)
    return F.cross_entropy(logits, labels)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Training options
    parser.add_argument("--exp_name", type=str, help="expname")
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--beta", type=float, default=0.998, help="ema decay")
    parser.add_argument("--epochs", type=int, default=500, help="num of epochs")
    parser.add_argument(
        "--checkpoint_every",
        type=int,
        default=50,
        help="save checkpoint every specified epochs",
    )
    parser.add_argument("--lr", type=float, default=1e-3, help="lr")
    parser.add_argument("--warmup_iters", type=int, default=0)
    parser.add_argument(
        "--finetune",
        type=str,
        default=None,
        help="path to the checkpoint. Training starts from that checkpoint",
    )

    # Data loading
    parser.add_argument("--limit_cores_used", type=str2bool, default=False)
    parser.add_argument("--cpu_core_start", type=int, default=0, help="start core")
    parser.add_argument("--cpu_core_end", type=int, default=32, help="end core")
    parser.add_argument("--data", type=str, default="clevrtex")
    parser.add_argument(
        "--data_root",
        type=str,
        default=None,
        help="Optional. Specify the root dir of the dataset. If None, use a default path set for each dataset",
    )
    parser.add_argument("--batchsize", type=int, default=256)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument(
        "--data_imsize",
        type=int,
        default=None,
        help="Image size. If None, use the default size of each dataset",
    )

    # Simclr options
    parser.add_argument("--normalize", type=str2bool, default=True)
    parser.add_argument("--temp", type=float, default=0.1, help="simclr temperature.")

    # General model options
    parser.add_argument("--model", type=str, default="akorn", help="model")
    parser.add_argument("--L", type=int, default=1, help="num of layers")
    parser.add_argument("--ch", type=int, default=256, help="num of channels")
    parser.add_argument(
        "--model_imsize",
        type=int,
        default=None,
        # argparse collapses whitespace in help text, so this renders the same
        # as the original multi-line string.
        help=(
            "Model's imsize. This is used when you want finetune a pretrained model "
            "that was trained on images with different resolution than the finetune "
            "image dataset."
        ),
    )
    parser.add_argument("--autorescale", type=str2bool, default=False)
    parser.add_argument("--psize", type=int, default=8, help="patch size")
    parser.add_argument("--ksize", type=int, default=1, help="kernel size")
    parser.add_argument("--T", type=int, default=8, help="num of recurrence")
    parser.add_argument(
        "--maxpool", type=str2bool, default=True, help="max pooling or avg pooling"
    )
    parser.add_argument(
        "--heads", type=int, default=8, help="num of heads in self-attention"
    )
    parser.add_argument(
        "--gta",
        type=str2bool,
        default=True,
        help=(
            "use Geometric Transform Attention (https://github.com/autonomousvision/gta) "
            "as positional encoding. Note that, different from the original GTA, the "
            "rotating matrices are learnable. If False, use standard absolute positional "
            "encoding used in the original transformer paper."
        ),
    )

    # AKOrN options
    parser.add_argument("--N", type=int, default=4, help="num of rotating dimensions")
    parser.add_argument("--gamma", type=float, default=1.0, help="step size")
    parser.add_argument("--J", type=str, default="conv", help="connectivity")
    parser.add_argument("--use_omega", type=str2bool, default=False)
    parser.add_argument("--global_omg", type=str2bool, default=False)
    parser.add_argument(
        "--c_norm",
        type=str,
        default="gn",
        help="normalization. gn(GroupNorm), sandb(scale and bias), or none",
    )
    parser.add_argument(
        "--init_omg", type=float, default=0.01, help="initial omega length"
    )
    parser.add_argument("--learn_omg", type=str2bool, default=False)
    parser.add_argument(
        "--use_ro_x",
        type=str2bool,
        default=False,
        help="apply linear transform to oscillators between consecutive layers",
    )

    # ablation of some components in the AKOrN's block
    parser.add_argument(
        "--no_ro", type=str2bool, default=False, help="ablation: no use readout module"
    )
    parser.add_argument(
        "--project",
        type=str2bool,
        default=True,
        help="use projection or not in the Kuramoto layer",
    )

    args = parser.parse_args()

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.enable_flash_sdp(enabled=True)

    # Setup accelerator
    accelerator = Accelerator()
    device = accelerator.device
    is_main = accelerator.is_main_process
    accelerate.utils.set_seed(args.seed + accelerator.process_index)

    import random
    import numpy as np

    # NOTE(review): these calls replace the per-process seed set just above with
    # the same seed on every process. The behaviour is kept to preserve existing
    # runs; confirm whether per-process seeding was intended.
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # Create job directory and logger
    jobdir = f"runs/{args.exp_name}/"
    if is_main:
        os.makedirs(jobdir, exist_ok=True)  # holds all outputs of this experiment
        logger = create_logger(jobdir)
        logger.info(f"Experiment directory created at {jobdir}")
    else:
        # Every logger call below is guarded by the main-process check.
        # Non-main ranks therefore get a silent logger instead of opening
        # jobdir/log.txt, which rank 0 may not have created yet.
        logger = logging.getLogger(__name__)
        logger.addHandler(logging.NullHandler())

    if args.limit_cores_used:

        def worker_init_fn(worker_id):
            # Pin each data-loader worker to the requested CPU cores.
            os.sched_setaffinity(0, range(args.cpu_core_start, args.cpu_core_end))

    else:
        worker_init_fn = None

    sstrainset, imsize, _ = load_data(args.data, args.data_root, args.data_imsize, False)
    if is_main:
        logger.info(f"Dataset contains {len(sstrainset):,} images")

    # --batchsize is the global batch; each process loads its share.
    ssloader = torch.utils.data.DataLoader(
        sstrainset,
        batch_size=int(args.batchsize // accelerator.num_processes),
        shuffle=True,
        num_workers=args.num_workers,
        worker_init_fn=worker_init_fn,
    )

    writer = SummaryWriter(jobdir) if is_main else None

    def train(net, ema, opt, scheduler, loader, epoch):
        """Run one training epoch and return the sample-weighted mean loss.

        Each batch does the following:
        - forward the images;
        - on multi-process runs, gather outputs from all processes;
        - pair consecutive outputs into two views;
        - apply the SimCLR loss and take an optimizer step, a scheduler step
          and an EMA update.

        At the end of the epoch the main process writes these to TensorBoard,
        using the epoch as the step:
        - gradient histograms;
        - per-parameter change over the epoch;
        - the epoch loss.

        Raises:
            RuntimeError: if ``loader`` yields no batches.
        """
        # Snapshot parameters on the main process only, where they are used to
        # plot how far each parameter moved during the epoch.
        initial_params = (
            {name: p.detach().clone() for name, p in net.named_parameters()}
            if is_main
            else {}
        )
        running_loss = 0.0
        n = 0
        num_batches = 0

        net.train()
        for i, data in tqdm(enumerate(loader, 0), disable=not is_main):
            # Each sample appears to carry two views (hence "2x batchsize").
            # NOTE(review): confirm the view layout against load_data.
            inputs = data.view(-1, 3, imsize, imsize).to(device)

            outputs = net(inputs)

            # The SimCLR loss needs the outputs of all processes. all_gather
            # from torch.distributed.nn is differentiable.
            if accelerator.num_processes > 1:
                outputs = torch.cat(all_gather(outputs), 0)
            # Consecutive rows are treated as the two views of one image.
            outputs = outputs.unflatten(0, (outputs.shape[0] // 2, 2))

            loss = simclr(
                [outputs[:, 0], outputs[:, 1]],
                temperature=args.temp,
                normalize=args.normalize,
                loss_type="ip",
            )

            opt.zero_grad()
            accelerator.backward(loss)
            opt.step()
            scheduler.step()

            running_loss += loss.item() * inputs.shape[0]
            n += inputs.shape[0]
            num_batches = i + 1

            ema.update()

        if n == 0:
            raise RuntimeError(
                "Data loader yielded no samples; check the dataset and --batchsize."
            )

        total_loss = running_loss / n
        if is_main:
            add_gradient_histograms(writer, net, epoch)
            for name, param in net.named_parameters():
                diff = param.detach() - initial_params[name]
                writer.add_histogram(f"{name}_diff", diff, epoch)
            logger.info(
                f"[Epoch {epoch + 1}, Batch {num_batches}] loss: {total_loss:.3f}"
            )
            writer.add_scalar("training loss", total_loss, epoch)

        return total_loss

    model_imsize = imsize if args.model_imsize is None else args.model_imsize
    if args.model == "akorn":
        from source.models.objs.knet import AKOrN

        net = AKOrN(
            args.N,
            ch=args.ch,
            L=args.L,
            T=args.T,
            gamma=args.gamma,
            J=args.J,  # "conv" or "attn",
            use_omega=args.use_omega,
            global_omg=args.global_omg,
            c_norm=args.c_norm,
            psize=args.psize,
            imsize=model_imsize,
            autorescale=args.autorescale,
            init_omg=args.init_omg,
            learn_omg=args.learn_omg,
            maxpool=args.maxpool,
            project=args.project,
            heads=args.heads,
            use_ro_x=args.use_ro_x,
            no_ro=args.no_ro,
            gta=args.gta,
        ).to(device)
    elif args.model == "vit":
        from source.models.objs.vit import ViT

        # T=1: ViT. T > 1: ItrSA.
        net = ViT(
            psize=args.psize,
            imsize=model_imsize,
            autorescale=args.autorescale,
            ch=args.ch,
            blocks=args.L,
            heads=args.heads,
            mlp_dim=2 * args.ch,
            T=args.T,
            maxpool=args.maxpool,
            gta=args.gta,
        ).to(device)
    else:
        raise NotImplementedError(f"Unknown model: {args.model!r}")

    total_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    accelerator.print(f"Total number of basemodel parameters: {total_params}")

    # Load the checkpoint once, onto the CPU. Without map_location, every rank
    # would materialise it on the GPU it was saved from.
    checkpoint = None
    if args.finetune:
        if is_main:
            logger.info("Loading checkpoint...")
        checkpoint = torch.load(args.finetune, map_location="cpu")
        net.load_state_dict(checkpoint["model_state_dict"])

    optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=0.0)
    if checkpoint is not None:
        if is_main:
            logger.info("Loading optimizer state...")
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        # Use the learning rate given on the command line, not the saved one.
        for param_group in optimizer.param_groups:
            param_group["lr"] = args.lr
        del checkpoint

    from ema_pytorch import EMA

    ema = EMA(net, beta=args.beta, update_every=10, update_after_step=200)
    if args.finetune:
        if is_main:
            logger.info("Loading EMA checkpoint...")
        # The EMA weights live next to the checkpoint, with "checkpoint" in the
        # file name replaced by "ema".
        dir_name, file_name = os.path.split(args.finetune)
        ema_path = os.path.join(dir_name, file_name.replace("checkpoint", "ema"))
        ema.load_state_dict(torch.load(ema_path, map_location="cpu")["model_state_dict"])

    if is_main:
        logger.info(f"Training for {args.epochs} epochs...")

    net, optimizer, ssloader = accelerator.prepare(net, optimizer, ssloader)
    scheduler = LinearWarmupScheduler(optimizer, warmup_iters=args.warmup_iters)

    for epoch in range(args.epochs):
        total_loss = train(net, ema, optimizer, scheduler, ssloader, epoch)
        if (epoch + 1) % args.checkpoint_every == 0 and is_main:
            save_checkpoint(
                accelerator.unwrap_model(net),
                optimizer,
                epoch,
                total_loss,
                checkpoint_dir=jobdir,
            )
            save_model(ema, epoch, checkpoint_dir=jobdir, prefix="ema")

    if is_main:
        torch.save(
            accelerator.unwrap_model(net).state_dict(),
            os.path.join(jobdir, "model.pth"),
        )
        torch.save(ema.state_dict(), os.path.join(jobdir, "ema_model.pth"))
        writer.close()