-
Notifications
You must be signed in to change notification settings - Fork 57
/
train.py
224 lines (193 loc) · 7.02 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
import datetime
import json
import logging
import os
import sys
import time
from pathlib import Path
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torchvision.datasets as datasets
from models.model_configs import instantiate_model
from train_arg_parser import get_args_parser
from training import distributed_mode
from training.data_transform import get_train_transform
from training.eval_loop import eval_model
from training.grad_scaler import NativeScalerWithGradNormCount as NativeScaler
from training.load_and_save import load_model, save_model
from training.train_loop import train_one_epoch
logger = logging.getLogger(__name__)


def _save_args(args):
    """Dump ``vars(args)`` as JSON to ``<args.output_dir>/args.json``.

    Only the main process writes. Nothing is written when ``args.output_dir``
    is empty/None, which matches how the training loop and ``__main__`` treat
    the output directory as optional.
    """
    if not (args.output_dir and distributed_mode.is_main_process()):
        return
    args_filepath = Path(args.output_dir) / "args.json"
    logger.info(f"Saving args to {args_filepath}")
    with open(args_filepath, "w", encoding="utf-8") as f:
        json.dump(vars(args), f)


def _build_train_dataset(args, transform):
    """Instantiate the training dataset selected by ``args.dataset``.

    Args:
        args: parsed CLI namespace; reads ``dataset`` and ``data_path``.
        transform: transform applied to every training sample.

    Raises:
        NotImplementedError: if ``args.dataset`` is neither "imagenet" nor "cifar10".
    """
    if args.dataset == "imagenet":
        return datasets.ImageFolder(args.data_path, transform=transform)
    if args.dataset == "cifar10":
        # CIFAR-10 is downloaded into data_path on first use.
        return datasets.CIFAR10(
            root=args.data_path,
            train=True,
            download=True,
            transform=transform,
        )
    raise NotImplementedError(f"Unsupported dataset {args.dataset}")


def _build_lr_schedule(args, optimizer):
    """Return the LR scheduler for ``optimizer``.

    Returns one of two schedulers, both with ``total_iters=args.epochs``:
    - With ``args.decay_lr``, the LR falls linearly from ``args.lr`` to 1e-8.
    - Otherwise, the LR is held constant.

    The scheduler is stepped inside ``train_one_epoch`` (not visible here).
    """
    if args.decay_lr:
        return torch.optim.lr_scheduler.LinearLR(
            optimizer,
            total_iters=args.epochs,
            start_factor=1.0,
            # end LR = args.lr * end_factor = 1e-8
            end_factor=1e-8 / args.lr,
        )
    return torch.optim.lr_scheduler.ConstantLR(
        optimizer, total_iters=args.epochs, factor=1.0
    )


def _fid_samples_for_rank(total, num_tasks):
    """Return how many of ``total`` FID samples this rank should generate.

    Every non-main rank gets ``total // num_tasks``. The main process gets
    whatever is left, so the per-rank counts sum to exactly ``total``.
    """
    if distributed_mode.is_main_process():
        return total - (num_tasks - 1) * (total // num_tasks)
    return total // num_tasks


def main(args):
    """Train and periodically evaluate a model configured by parsed CLI ``args``.

    Setup:
    - Configure stdout logging and initialize (possibly distributed) mode.
    - Persist the args and seed the RNGs.
    - Build the training dataset/loader, the model (DDP-wrapped when
      ``args.distributed``), an AdamW optimizer and an LR schedule.
    - Call ``load_model``, which presumably restores a checkpoint and may
      adjust ``args.start_epoch`` (TODO confirm).

    Then loop over epochs ``args.start_epoch .. args.epochs - 1``:
    - Train for one epoch, unless ``args.eval_only`` is set.
    - When ``args.output_dir`` is set, save a checkpoint (skipped under
      ``eval_only``) and run ``eval_model``. This happens every
      ``args.eval_frequency`` epochs, or immediately under ``eval_only`` /
      ``test_run``.
    - Append the epoch's stats as one JSON line to
      ``<output_dir>/log.txt`` on the main process.

    The loop stops after the first epoch under ``test_run`` or ``eval_only``.
    """
    logging.basicConfig(
        level=logging.INFO,
        stream=sys.stdout,
        format="%(asctime)s %(levelname)-8s %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    distributed_mode.init_distributed_mode(args)
    logger.info("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))
    logger.info("{}".format(args).replace(", ", ",\n"))
    _save_args(args)

    device = torch.device(args.device)

    # Fix the seed for reproducibility; offset by rank so ranks draw different numbers.
    seed = args.seed + distributed_mode.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    logger.info(f"Initializing Dataset: {args.dataset}")
    dataset_train = _build_train_dataset(args, get_train_transform())
    logger.info(dataset_train)

    logger.info("Intializing DataLoader")
    num_tasks = distributed_mode.get_world_size()
    global_rank = distributed_mode.get_rank()
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,  # every batch has exactly args.batch_size samples
    )
    logger.info(str(sampler_train))

    logger.info("Initializing Model")
    model = instantiate_model(
        architechture=args.dataset,
        is_discrete=args.discrete_flow_matching,
        use_ema=args.use_ema,
    )
    model.to(device)
    model_without_ddp = model
    logger.info(str(model_without_ddp))

    eff_batch_size = args.batch_size * args.accum_iter * num_tasks
    logger.info(f"Learning rate: {args.lr:.2e}")
    logger.info(f"Accumulate grad iterations: {args.accum_iter}")
    logger.info(f"Effective batch size: {eff_batch_size}")

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu], find_unused_parameters=True
        )
        # Keep a handle on the unwrapped module for the optimizer and checkpoints.
        model_without_ddp = model.module

    optimizer = torch.optim.AdamW(
        model_without_ddp.parameters(), lr=args.lr, betas=args.optimizer_betas
    )
    lr_schedule = _build_lr_schedule(args, optimizer)
    logger.info(f"Optimizer: {optimizer}")
    logger.info(f"Learning-Rate Schedule: {lr_schedule}")

    loss_scaler = NativeScaler()
    load_model(
        args=args,
        model_without_ddp=model_without_ddp,
        optimizer=optimizer,
        loss_scaler=loss_scaler,
        lr_schedule=lr_schedule,
    )

    logger.info(f"Start from {args.start_epoch} to {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # DistributedSampler derives its shuffle order from the epoch number.
            data_loader_train.sampler.set_epoch(epoch)

        if not args.eval_only:
            train_stats = train_one_epoch(
                model=model,
                data_loader=data_loader_train,
                optimizer=optimizer,
                lr_schedule=lr_schedule,
                device=device,
                epoch=epoch,
                loss_scaler=loss_scaler,
                args=args,
            )
            log_stats = {
                **{f"train_{k}": v for k, v in train_stats.items()},
                "epoch": epoch,
            }
        else:
            log_stats = {
                "epoch": epoch,
            }

        if args.output_dir and (
            (args.eval_frequency > 0 and (epoch + 1) % args.eval_frequency == 0)
            or args.eval_only
            or args.test_run
        ):
            if not args.eval_only:
                save_model(
                    args=args,
                    model=model,
                    model_without_ddp=model_without_ddp,
                    optimizer=optimizer,
                    lr_schedule=lr_schedule,
                    loss_scaler=loss_scaler,
                    epoch=epoch,
                )
            if args.distributed:
                # Presumably gives eval_model the same data order on every
                # evaluation -- confirm against eval_model.
                data_loader_train.sampler.set_epoch(0)
            fid_samples = _fid_samples_for_rank(args.fid_samples, num_tasks)
            eval_stats = eval_model(
                model,
                data_loader_train,
                device,
                epoch=epoch,
                fid_samples=fid_samples,
                args=args,
            )
            log_stats.update({f"eval_{k}": v for k, v in eval_stats.items()})

        if args.output_dir and distributed_mode.is_main_process():
            with open(
                os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8"
            ) as f:
                f.write(json.dumps(log_stats) + "\n")

        if args.test_run or args.eval_only:
            break

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f"Training time {total_time_str}")
if __name__ == "__main__":
    # Script entry point: parse the CLI flags, make sure the output directory
    # exists (all parent directories included), then hand off to main().
    args = get_args_parser().parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(exist_ok=True, parents=True)
    main(args)