-
Notifications
You must be signed in to change notification settings - Fork 3
/
test_all_imagenet.py
133 lines (111 loc) · 5.47 KB
/
test_all_imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import argparse
import torch
from tqdm import tqdm
import data_loader.data_loaders as module_data
from data_loader.imagenet_lt_data_loaders import ImageNetLTDataLoader
import model.loss as module_loss
import model.metric as module_metric
import model.model as module_arch
import numpy as np
from parse_config import ConfigParser
import torch.nn.functional as F
from utils import adjusted_model_wrapper
def main(config, posthoc_bias_correction=False):
    """Evaluate a checkpoint on every ImageNet-LT test split and print the results.

    The model described by ``config['arch']`` is built, the weights stored under
    ``'state_dict'`` in the checkpoint at ``config.resume`` are loaded into it,
    and ``validation`` is run once per entry of the hard-coded list of test
    distributions. A summary of all splits is printed at the end.

    Args:
        config: Parsed configuration object. It is used through
            ``get_logger``, ``init_obj``, ``resume``, item access
            (``config['arch']``, ``config['n_gpu']``, ``config['data_loader']``)
            and ``config._config``.
        posthoc_bias_correction: When true, the log of the normalized
            ``data_loader.cls_num_list`` of each test split is handed to
            ``adjusted_model_wrapper`` as ``test_bias``; otherwise ``None`` is
            passed.
    """
    logger = config.get_logger('test')

    # build model architecture
    if 'returns_feat' in config['arch']['args']:
        model = config.init_obj('arch', module_arch, allow_override=True, returns_feat=False)
    else:
        model = config.init_obj('arch', module_arch)

    logger.info('Loading checkpoint: {} ...'.format(config.resume))
    # The model is still on the CPU at this point, so map the stored tensors to
    # the CPU as well. Without map_location a checkpoint saved on a GPU cannot
    # be deserialized on a CPU-only machine.
    checkpoint = torch.load(config.resume, map_location='cpu')
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        # Wrap before loading the weights, as the original code does.
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # prepare model for testing
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    num_classes = config._config["arch"]["args"]["num_classes"]

    record_list = []
    test_distribution_set = ["forward50", "forward25", "forward10", "forward5", "forward2", "uniform", "backward2", "backward5", "backward10", "backward25", "backward50"]
    for test_distribution in test_distribution_set:
        test_txt = '/ImageNet_LT_%s.txt' % (test_distribution)
        print(test_txt)
        data_loader = ImageNetLTDataLoader(
            config['data_loader']['args']['data_dir'],
            batch_size=128,
            shuffle=False,
            training=False,
            num_workers=2,
            test_txt=test_txt
        )

        if posthoc_bias_correction:
            # Log of the normalized per-class sample counts of this test split.
            test_prior = torch.tensor(data_loader.cls_num_list).float().to(device)
            test_prior = test_prior / test_prior.sum()
            test_bias = test_prior.log()
        else:
            test_bias = None

        # NOTE(review): presumably the wrapper shifts the logits by test_bias;
        # confirm against utils.adjusted_model_wrapper.
        adjusted_model = adjusted_model_wrapper(model, test_bias=test_bias)
        record = validation(data_loader, adjusted_model, num_classes, device)
        record_list.append(record)

    print('='*25, ' Final results ', '='*25)
    # One record was appended per test distribution, in the same order.
    for distribution_name, record in zip(test_distribution_set, record_list):
        print(distribution_name + '\t')
        print(*record)
def mic_acc_cal(preds, labels):
    """Return the top-1 accuracy of ``preds`` against ``labels``.

    Args:
        preds: Tensor of predicted class indices.
        labels: Either a tensor of target class indices, or a 3-tuple
            ``(targets_a, targets_b, lam)`` in which case the accuracies
            against both target tensors are blended with weight ``lam``.

    Returns:
        A Python float for plain tensor labels; for the tuple form, the result
        of the tensor arithmetic divided by ``len(preds)``.
    """
    if not isinstance(labels, tuple):
        # Plain labels: fraction of exact matches.
        num_correct = (preds == labels).sum().item()
        return num_correct / len(labels)

    assert len(labels) == 3
    targets_a, targets_b, lam = labels
    hits_a = preds.eq(targets_a.data).cpu().sum().float()
    hits_b = preds.eq(targets_b.data).cpu().sum().float()
    return (lam * hits_a + (1 - lam) * hits_b) / len(preds)
def validation(data_loader, model, num_classes, device):
    """Run ``model`` over ``data_loader`` and report group-wise and overall top-1 accuracy.

    Args:
        data_loader: Iterable yielding ``(data, target)`` batches.
        model: Callable mapping a batch of inputs to per-class scores; it is
            indexed with ``argmax(dim=1)``.
        num_classes: Number of classes; sets the confusion-matrix size and the
            width of the collected logits.
        device: Device that the batches and all accumulation buffers are
            placed on.

    Returns:
        Tuple ``(many_shot_acc, medium_shot_acc, few_shot_acc, overall_acc)``,
        each a percentage rounded to two decimals.
    """
    # The file holds three index arrays used to select entries of the per-class
    # accuracy vector (presumably many-/medium-/few-shot class groups; verify
    # against the script that generates the file).
    shot_groups = np.load("./data/imagenet_lt_shot_list.npy")
    many_shot = shot_groups[0]
    medium_shot = shot_groups[1]
    few_shot = shot_groups[2]

    # Allocate on the requested device rather than hard-coding .cuda(), so the
    # function also runs on CPU and always matches the device of the batches.
    confusion_matrix = torch.zeros(num_classes, num_classes, device=device)
    # Seed the chunk lists with empty tensors so that an empty loader still
    # concatenates to empty results.
    logit_chunks = [torch.empty((0, num_classes), device=device)]
    label_chunks = [torch.empty(0, dtype=torch.long, device=device)]

    with torch.no_grad():
        for data, target in tqdm(data_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)

            # Add one to confusion_matrix[target, prediction] for every sample
            # of the batch in a single call; accumulate=True sums duplicates.
            rows = target.view(-1).long()
            cols = output.argmax(dim=1).view(-1).long()
            confusion_matrix.index_put_(
                (rows, cols),
                torch.ones_like(rows, dtype=confusion_matrix.dtype),
                accumulate=True,
            )

            logit_chunks.append(output)
            label_chunks.append(target)

    # Concatenate once after the loop instead of re-copying every iteration.
    total_logits = torch.cat(logit_chunks)
    total_labels = torch.cat(label_chunks)

    _, preds = F.softmax(total_logits.detach(), dim=1).max(dim=1)

    # Calculate the overall accuracy, ignoring samples labelled -1
    valid = total_labels != -1
    eval_acc_mic_top1 = mic_acc_cal(preds[valid], total_labels[valid])
    overall_pct = np.round(eval_acc_mic_top1 * 100, decimals=2)
    print('All top-1 Acc:', overall_pct)

    # NOTE: a class with no samples in this split has a zero row sum, so its
    # accuracy is 0/0 = NaN and propagates into the mean of its group.
    acc_per_class = confusion_matrix.diag() / confusion_matrix.sum(1)
    acc = acc_per_class.cpu().numpy()
    many_pct = np.round(acc[many_shot].mean() * 100, decimals=2)
    medium_pct = np.round(acc[medium_shot].mean() * 100, decimals=2)
    few_pct = np.round(acc[few_shot].mean() * 100, decimals=2)
    print("{}, {}, {}".format(many_pct, medium_pct, few_pct))
    return many_pct, medium_pct, few_pct, overall_pct
if __name__ == '__main__':
    # Script entry point: declare the command-line options, let ConfigParser
    # turn them into a config object, then run the evaluation.
    parser = argparse.ArgumentParser(description='PyTorch Template')

    # (flags, keyword arguments) for each option, in registration order.
    cli_options = (
        (('-c', '--config'),
         dict(default=None, type=str,
              help='config file path (default: None)')),
        (('-r', '--resume'),
         dict(default=None, type=str,
              help='path to latest checkpoint (default: None)')),
        (('-d', '--device'),
         dict(default=None, type=str,
              help='indices of GPUs to enable (default: all)')),
        (('-l', '--log-config'),
         dict(default='logger/logger_config.json', type=str,
              help='logging config file path (default: logger/logger_config.json)')),
        (("--posthoc_bias_correction",),
         dict(dest="posthoc_bias_correction", action="store_true", default=False)),
        # dummy arguments used during training time
        (("--validate",), {}),
        (("--use-wandb",), {}),
    )
    for flags, options in cli_options:
        parser.add_argument(*flags, **options)

    config, parsed_args = ConfigParser.from_args(parser)
    main(config, parsed_args.posthoc_bias_correction)