-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_utils.py
54 lines (50 loc) · 1.99 KB
/
train_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import torch
import shutil
import numpy as np
import os
def save_checkpoint(state, checkpoint='checkpoint', filename='checkpoint.pth.tar' , AUC_BEST = False, ACC_BEST = False):
    """Serialize ``state`` to ``<checkpoint>/<filename>`` and optionally tag it as a best model.

    Args:
        state: Any object accepted by ``torch.save``.
        checkpoint: Directory the checkpoint is written to. It is created
            (including parents) when it does not exist yet.
        filename: File name of the checkpoint inside ``checkpoint``.
        AUC_BEST: When truthy, the saved file is also copied to
            ``model_best.pth.tar`` in the same directory.
        ACC_BEST: When truthy, the saved file is also copied to
            ``model_best_accuracy.pth.tar`` in the same directory.

    Returns:
        None.
    """
    # An empty directory string means "current directory"; os.makedirs('')
    # would raise, so only create the directory when a name was given.
    if checkpoint:
        os.makedirs(checkpoint, exist_ok=True)

    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)

    # The two flags are independent: one call may refresh both "best" copies.
    if AUC_BEST:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))
    if ACC_BEST:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best_accuracy.pth.tar'))
def save_checkpoint_for_unlearning(state, checkpoint='checkpoint', filename='checkpoint.pth.tar' , isLoss=False, isAcc=False):
    """Serialize ``state`` to ``<checkpoint>/<filename>`` and optionally tag it.

    Args:
        state: Any object accepted by ``torch.save``.
        checkpoint: Directory the checkpoint is written to. It is created
            (including parents) when it does not exist yet.
        filename: File name of the checkpoint inside ``checkpoint``.
        isLoss: When truthy, the saved file is also copied to
            ``model_lowest_loss.pth.tar`` in the same directory.
        isAcc: When truthy, the saved file is also copied to
            ``model_best_acc.pth.tar`` in the same directory.

    Returns:
        None.
    """
    # An empty directory string means "current directory"; os.makedirs('')
    # would raise, so only create the directory when a name was given.
    if checkpoint:
        os.makedirs(checkpoint, exist_ok=True)

    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)

    # The two flags are independent: one call may refresh both tagged copies.
    if isLoss:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_lowest_loss.pth.tar'))
    if isAcc:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best_acc.pth.tar'))
def adjust_learning_rate(optimizer, epoch, opt):
    """Scale ``opt.lr`` by a stage-dependent factor and push it to the optimizer.

    The stage is the position ``epoch`` would take when merged into the
    sorted milestones of ``opt.schedule`` (i.e. the number of milestones
    strictly smaller than ``epoch``). That stage selects a factor from
    ``1, 1e-1, ..., 1e-5``; ``opt.lr`` is multiplied by it in place and the
    result is written to ``param_group['lr']`` of every group in
    ``optimizer.param_groups``. ``opt.schedule`` itself is not modified.

    NOTE(review): the factor is applied to the *current* ``opt.lr`` on every
    call, so repeated calls past a milestone compound the decay -- confirm
    this is what callers expect.
    NOTE(review): only six factors exist, so more than five milestones below
    ``epoch`` raises ``IndexError``.
    """
    decay_factors = [1, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5]

    # Merging the epoch into a sorted copy of the milestones and looking up
    # its first occurrence yields the current stage index.
    ordered = sorted(opt.schedule + [epoch])
    stage = ordered.index(epoch)

    opt.lr *= decay_factors[stage]
    for group in optimizer.param_groups:
        group['lr'] = opt.lr
#mhkim add-temp
def save_arr_acc_loss(list_acc,list_real_acc,list_fake_acc,list_loss,
                      list_val_acc,list_val_real_acc,list_val_fake_acc,list_val_loss,
                      path):
    """Persist training and validation metric histories with ``np.save``.

    The four training sequences (acc, real acc, fake acc, loss -- in that
    order) are stacked into one object saved under ``path + '_train'``; the
    four validation sequences are stacked in the same order and saved under
    ``path + '_val'``. ``np.save`` appends the ``.npy`` extension when the
    target name does not already end with it.
    """
    # Row order matters to whoever reads the files back: acc, real, fake, loss.
    train_history = [list_acc, list_real_acc, list_fake_acc, list_loss]
    val_history = [list_val_acc, list_val_real_acc, list_val_fake_acc, list_val_loss]

    np.save(path + '_train', train_history)
    np.save(path + '_val', val_history)