-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrainer.py
114 lines (89 loc) · 3.18 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import sys
import logging
import copy
import torch
from utils import factory
from utils.data_manager import DataManager
from utils.toolkit import count_parameters
import os
def train(args):
    """Run one full ``_train`` session for every seed listed in ``args["seed"]``.

    ``args`` is mutated in place. Before each run, ``args["seed"]`` is set to
    the current seed, and ``args["device"]`` is reset to the device spec
    captured before the first run. The reset is needed because ``_train``
    overwrites ``args["device"]`` on each call.

    Args:
        args: Experiment configuration dict. ``args["seed"]`` must be an
            iterable of seeds.
    """
    seeds = copy.deepcopy(args["seed"])
    raw_device_spec = copy.deepcopy(args["device"])
    for current_seed in seeds:
        # Hand every run the untouched device spec, not the previous run's result.
        args["device"] = raw_device_spec
        args["seed"] = current_seed
        _train(args)
def _train(args):
    """Train and evaluate one model across all incremental tasks for a single seed.

    Steps:
    - Create ``logs/<model_name>/<dataset>/<init_cls>/<increment>/``.
    - Send INFO logging to ``<prefix>_<seed>_<convnet_type>.log`` in that
      directory and to stdout.
    - Seed the RNGs, convert the device spec and log the configuration.
    - Build the ``DataManager`` and the model.
    - For every task, call ``incremental_train`` / ``eval_task`` /
      ``after_task`` and log the accuracy curves.

    Args:
        args: Experiment configuration dict. Keys read here: ``init_cls``,
            ``increment``, ``model_name``, ``dataset``, ``prefix``, ``seed``,
            ``convnet_type``, ``shuffle``, ``device``. ``args["device"]`` is
            overwritten in place by ``_set_device``.
    """
    # When the first task is the same size as the increment, the log path
    # records 0 for init_cls.
    init_cls = 0 if args["init_cls"] == args["increment"] else args["init_cls"]
    logs_name = "logs/{}/{}/{}/{}".format(
        args["model_name"], args["dataset"], init_cls, args["increment"]
    )
    # exist_ok avoids the exists()/makedirs() race when runs start concurrently.
    os.makedirs(logs_name, exist_ok=True)
    logfilename = "{}/{}_{}_{}".format(
        logs_name, args["prefix"], args["seed"], args["convnet_type"]
    )
    _configure_logging(logfilename + ".log")

    _set_random(args["seed"])
    _set_device(args)
    print_args(args)

    data_manager = DataManager(
        args["dataset"],
        args["shuffle"],
        args["seed"],
        args["init_cls"],
        args["increment"],
    )
    model = factory.get_model(args["model_name"], args)

    # Only the top-1 curves are filled in and logged below.
    cnn_curve = {"top1": [], "top5": []}
    nme_curve = {"top1": [], "top5": []}
    for _ in range(data_manager.nb_tasks):
        logging.info("All params: {}".format(count_parameters(model._network)))
        logging.info(
            "Trainable params: {}".format(count_parameters(model._network, True))
        )
        model.incremental_train(data_manager)
        cnn_accy, nme_accy = model.eval_task()
        model.after_task()

        if nme_accy is not None:
            logging.info("CNN: {}".format(cnn_accy["grouped"]))
            logging.info("NME: {}".format(nme_accy["grouped"]))
            cnn_curve["top1"].append(cnn_accy["top1"])
            nme_curve["top1"].append(nme_accy["top1"])
            logging.info("CNN top1 curve: {}".format(cnn_curve["top1"]))
            logging.info("NME top1 curve: {}".format(nme_curve["top1"]))
        else:
            # eval_task returned no NME result; only the CNN numbers are tracked.
            logging.info("No NME accuracy.")
            logging.info("CNN: {}".format(cnn_accy["grouped"]))
            cnn_curve["top1"].append(cnn_accy["top1"])
            logging.info("CNN top1 curve: {}".format(cnn_curve["top1"]))


def _configure_logging(log_path):
    """Send root-logger INFO output to ``log_path`` and stdout, replacing existing root handlers."""
    root = logging.getLogger()
    # logging.basicConfig does nothing once the root logger has handlers.
    # Without this, every seed after the first would keep writing into the
    # first seed's log file. Remove and close the previous handlers first;
    # this is what basicConfig(force=True) does.
    for handler in list(root.handlers):
        root.removeHandler(handler)
        handler.close()
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(filename)s] => %(message)s",
        handlers=[
            logging.FileHandler(filename=log_path),
            logging.StreamHandler(sys.stdout),
        ],
    )
def _set_device(args):
device_type = args["device"]
gpus = []
for device in device_type:
if device_type == -1:
device = torch.device("cpu")
else:
device = torch.device("cuda:{}".format(device))
gpus.append(device)
args["device"] = gpus
def _set_random(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def print_args(args):
    """Log each configuration entry as ``key: value`` at INFO level, in iteration order.

    Args:
        args: Mapping of configuration names to values.
    """
    for line in ("{}: {}".format(name, setting) for name, setting in args.items()):
        logging.info(line)