# eval_resnet.py
import os
import json
import logging

import torch
import torch.nn.functional as F
import hydra
from hydra.utils import to_absolute_path
from omegaconf import DictConfig

from model import ResNet18
from utils import ECE, get_dataloader

log = logging.getLogger(__name__)


def evaluate(model, dataset, data_dir, device, test_bsize=512, intensity=0,
             corrupt_types=None, ece_bins=15):
    """
    Evaluate the given model on the requested test set.

    Returns:
        acc (float): Accuracy of the model on the test set.
        ece (float): Expected calibration error of the model on the test set.
        nll (float): Negative log-likelihood of the model on the test set.
    """
    testloader = get_dataloader(data_dir=data_dir, dataset=dataset,
                                batch_size=test_bsize,
                                train=False,
                                intensity=intensity,
                                corrupt_types=corrupt_types)
    ece_eval = ECE(n_bins=ece_bins)
    pred_total, labels_total = [], []
    correct_count = 0
    nll_total = 0.
    size_testset = 0  # counted per batch; the last batch may be smaller than test_bsize
    model.eval()
    with torch.no_grad():
        for imgs, labels in testloader:
            imgs, labels = imgs.to(device), labels.to(device)
            log_probs = model(imgs)  # the model outputs log-probabilities
            nll = F.nll_loss(log_probs, labels)  # F.nll_loss expects log-probabilities
            nll_total += nll.item() * labels.size(0)
            probs = log_probs.exp()
            _, pred_id = torch.max(probs, dim=-1)
            correct_count += (pred_id == labels).sum().item()
            size_testset += labels.size(0)
            pred_total.append(probs)
            labels_total.append(labels)
    acc = correct_count / size_testset
    nll = nll_total / size_testset
    pred_total = torch.cat(pred_total, dim=0)
    labels_total = torch.cat(labels_total, dim=0)
    ece = ece_eval(pred_total, labels_total)  # ece_eval expects probabilities
    return acc, ece, nll
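

# For reference, a minimal self-contained sketch of the binned expected
# calibration error that the imported utils.ECE is assumed to compute
# (illustrative only; the actual utils.ECE implementation may differ):
def _ece_reference(probs, labels, n_bins=15):
    confidences, predictions = probs.max(dim=-1)  # top-1 confidence and class
    accuracies = predictions.eq(labels).float()
    bin_edges = torch.linspace(0, 1, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(bin_edges[:-1].tolist(), bin_edges[1:].tolist()):
        in_bin = (confidences > lo) & (confidences <= hi)
        prop = in_bin.float().mean().item()  # fraction of samples in this bin
        if prop > 0:
            # |mean accuracy - mean confidence| within the bin, weighted by bin mass
            gap = (accuracies[in_bin].mean() - confidences[in_bin].mean()).abs().item()
            ece += prop * gap
    return ece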


@hydra.main(config_path='configuration/conf_resnet18', config_name='eval_v0_config')
def main(cfg: DictConfig):
    dataset_name = cfg.dataset.name
    datadir_clean = to_absolute_path(cfg.dataset.dir_clean)
    datadir_corrupted = to_absolute_path(cfg.dataset.dir_corrupted)
    n_classes = cfg.dataset.n_classes
    in_channel = cfg.dataset.in_channel
    corrupt_types = cfg.dataset.corrupt_types
    experiment_name = cfg.experiment.name
    res_dir = to_absolute_path(cfg.experiment.res_dir)
    seed = cfg.experiment.seed
    model_name = f'{cfg.model.name}_v{cfg.model.version}'
    ck_dir = to_absolute_path(cfg.model.ck_dir)
    n_epochs = cfg.params.n_epoch  # iterable of checkpoint epochs to evaluate
    test_bsize = cfg.params.batch_size
    ece_bins = cfg.params.ece_bins

    os.makedirs(ck_dir, exist_ok=True)
    os.makedirs(res_dir, exist_ok=True)

    log.info(f'Experiment: {experiment_name}')
    log.info(f' - Seed: {seed}')
    log.info(f'Dataset: {dataset_name}')
    log.info(f'Model: {model_name}, epochs: {n_epochs}')

    torch.manual_seed(seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Resume from previously saved results, if any.
    results = {}
    results_file = os.path.join(res_dir, 'evaluation_results.json')
    if os.path.exists(results_file):
        with open(results_file, 'r') as f:
            results = json.load(f)
    for epoch in n_epochs:
        model = ResNet18(num_classes=n_classes, in_channels=in_channel).to(device)
        test_ck_path = os.path.join(ck_dir, f'resnet18_epoch{epoch}.pt')
        model.load_state_dict(torch.load(test_ck_path, map_location=device))
        log.info(f'Evaluating model at epoch {epoch}')

        # Key results per checkpoint; keying on model_name alone would let each
        # epoch overwrite the previous epoch's entries.
        epoch_key = f'{model_name}_epoch{epoch}'
        results[epoch_key] = {}

        # Intensity 0 is the clean (uncorrupted) test set.
        intensity = 0
        acc_clean, ece_clean, nll_clean = evaluate(model=model,
                                                   dataset=dataset_name,
                                                   data_dir=datadir_clean,
                                                   test_bsize=test_bsize,
                                                   device=device,
                                                   ece_bins=ece_bins)
        results[epoch_key][intensity] = [{
            'acc': acc_clean,
            'ece': float(ece_clean),  # float() in case the ECE metric returns a scalar tensor
            'nll': nll_clean
        }]
        # Corrupted test sets (CIFAR-10-C style naming) come at severity levels 1-5.
        for intensity in range(1, 6):
            acc_corrupted, ece_corrupted, nll_corrupted = evaluate(model=model,
                                                                   dataset=f'{dataset_name}-C',
                                                                   data_dir=datadir_corrupted,
                                                                   test_bsize=test_bsize,
                                                                   intensity=intensity,
                                                                   corrupt_types=corrupt_types,
                                                                   device=device,
                                                                   ece_bins=ece_bins)
            results[epoch_key][intensity] = [{
                'acc': acc_corrupted,
                'ece': float(ece_corrupted),
                'nll': nll_corrupted
            }]

        # Save after each checkpoint so partial progress survives interruptions.
        with open(results_file, 'w') as f:
            json.dump(results, f, indent=4)
        log.info(f'Results saved to {results_file}')


if __name__ == "__main__":
    main()
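
# Example invocation with Hydra command-line overrides (override values here
# are hypothetical; the config path and name come from the decorator above):
#
#   python eval_resnet.py params.batch_size=256 experiment.seed=1
#
# Expected shape of eval_v0_config, inferred from the cfg accesses in main()
# (all values are placeholders):
#
#   dataset:
#     name: ...
#     dir_clean: ...
#     dir_corrupted: ...
#     n_classes: ...
#     in_channel: ...
#     corrupt_types: [...]
#   experiment:
#     name: ...
#     res_dir: ...
#     seed: ...
#   model:
#     name: ...
#     version: ...
#     ck_dir: ...
#   params:
#     n_epoch: [...]
#     batch_size: ...
#     ece_bins: ...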