forked from kxz18/Research-Tools
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
70 lines (51 loc) · 2.27 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/python
# -*- coding:utf-8 -*-
import os
import argparse
import yaml
import torch
from utils.logger import print_log
from utils.random_seed import setup_seed, SEED
from utils.config_utils import overwrite_values
from utils import register as R
########### Import your packages below ##########
from trainer import create_trainer
from data import create_dataset, create_dataloader
from utils.nn_utils import count_parameters
import models
def parse():
    """Parse the command-line arguments of the training script.

    Recognised options:
        --gpus        one or more integer GPU ids (required).
        --local_rank  integer, defaults to -1.
        --config      path to the YAML configuration file (required).
        --seed        integer random seed, defaults to ``SEED``.

    Returns:
        A ``(namespace, extras)`` pair as produced by
        ``ArgumentParser.parse_known_args``: the namespace holds the options
        above, ``extras`` is the list of argument strings the parser did not
        recognise.
    """
    cli = argparse.ArgumentParser(description='training')

    # device selection
    cli.add_argument('--gpus', type=int, nargs='+', required=True, help='gpu to use, -1 for cpu')
    cli.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank. Necessary for using the torch.distributed.launch utility.",
    )

    # configuration file and reproducibility
    cli.add_argument('--config', type=str, required=True, help='Path to the yaml configure')
    cli.add_argument('--seed', type=int, default=SEED, help='Random seed')

    # Unknown arguments are tolerated and handed back to the caller instead
    # of triggering an argparse error.
    known, extras = cli.parse_known_args()
    return known, extras
def main(args, opt_args):
    """Build the model, data loaders and trainer from a config, then train.

    Args:
        args: Parsed command-line namespace. Reads ``args.config`` (path to
            the YAML config) and ``args.gpus`` (list of GPU ids); writes
            ``args.local_rank``.
        opt_args: Extra command-line arguments, forwarded to
            ``overwrite_values`` together with the loaded config.
            NOTE(review): presumably these override config entries; confirm
            against ``utils.config_utils``.

    Raises:
        KeyError: If more than one GPU is requested and the ``LOCAL_RANK``
            environment variable is not set, or if a required config section
            (``model``, ``dataset``, ``dataloader``) is missing.
    """
    # load config
    # The context manager closes the file handle as soon as parsing is done,
    # instead of leaving it open until garbage collection.
    with open(args.config, 'r') as config_file:
        config = yaml.safe_load(config_file)
    config = overwrite_values(config, opt_args)

    ########## define your model #########
    model = R.construct(config['model'])

    ########### load your train / valid set ###########
    # create_dataset yields three values; the third one is not used here.
    train_set, valid_set, _ = create_dataset(config['dataset'])

    ########## define your trainer/trainconfig #########
    if len(args.gpus) > 1:
        # Multi-GPU run: the rank of this process comes from the environment,
        # overriding whatever --local_rank was given on the command line.
        args.local_rank = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', world_size=len(args.gpus))
    else:
        args.local_rank = -1

    # Log the model size only once: in the single-process case (-1) or on
    # rank 0 of a distributed run.
    if args.local_rank <= 0:
        print_log(f'Number of parameters: {count_parameters(model) / 1e6} M')

    # Use the 'train' / 'valid' sub-section of the dataloader config when it
    # exists; otherwise fall back to the whole dataloader section.
    train_loader = create_dataloader(train_set, config['dataloader'].get('train', config['dataloader']), len(args.gpus))
    valid_loader = create_dataloader(valid_set, config['dataloader'].get('valid', config['dataloader']))

    trainer = create_trainer(config, model, train_loader, valid_loader)
    trainer.train(args.gpus, args.local_rank)
if __name__ == '__main__':
    # Script entry point: parse the command line, report the extra
    # (unrecognised) arguments, seed the RNGs, then start training.
    parsed = parse()
    args, opt_args = parsed[0], parsed[1]
    print_log(f'Overwritting args: {opt_args}')
    # Seeding happens before main() so that model construction and data
    # loading run after the seed is set.
    setup_seed(args.seed)
    main(args, opt_args)