-
Notifications
You must be signed in to change notification settings - Fork 0
/
statistics_utils_store.py
64 lines (39 loc) · 1.68 KB
/
statistics_utils_store.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
from collections import OrderedDict
class StatisticsSaver(object):
def __init__(self):
# the length of the bellow dictionaries is assumed to be the batch size
'''
key:epoch nubmer
value:list of losses
'''
self.discriminator_loss_real = OrderedDict()
'''
key:epoch nubmer
value:number of batches
'''
self.number_batches = OrderedDict()
self.mlp_model = None
self.mlp_optimizer = None
self.losses_total = []
self.accuracy_total = []
self.total_traininng_time = []
def set_trainned_model(self, mlp_model):
self.mlp_model = mlp_model
def set_trainned_optimizer(self, mlp_optimizer):
self.mlp_optimizer = mlp_optimizer
def save_trainned_model(self, save_model_file_name, last_epoch):
torch.save(self.mlp_model,
save_model_file_name + "_whole"+'.pth')
def save_model_params(self, save_model_file_name, last_epoch):
torch.save({'epoch' : last_epoch,
'model_state_dict' : self.mlp_model.state_dict(),
'optimizer_state_dict' : self.mlp_optimizer.state_dict()
},
save_model_file_name + "_params"+'.pth')
def save_losses(self, losses):
self.losses_total.append(losses)
def save_accuracy(self, accuracy):
self.accuracy_total.append(accuracy)
def save_total_training_time(self,total_time):
self.total_traininng_time.append(total_time)