
Commit 8ec13c1

Included evaluation metrics based on cosine similarities - see metrics folder and README.md
1 parent 4df1292 commit 8ec13c1

11 files changed: +313 -9 lines changed

README.md

+10 -2

@@ -240,11 +240,19 @@ For a detailed explanation of everything, please refer to the supplementary of o
 ### Evaluation Metrics
-
+**Metrics based on Euclidean Distances**
 * **Recall@k**: Include R@1 e.g. with `e_recall@1` into the list of evaluation metrics `--evaluation_metrics`.
 * **Normalized Mutual Information (NMI)**: Include with `nmi`.
-* **F1**: include with `nmi`.
+* **F1**: Include with `f1`.
 * **mAP (class-averaged)**: Include standard mAP at Recall with `mAP_lim`. You may also include `mAP_1000` for mAP limited to Recall@1000, and `mAP_c` limited to mAP at Recall@Max_Num_Samples_Per_Class. Note that all of these are heavily correlated.
+
+**Metrics based on Cosine Similarities** *(not included by default)*
+* **Cosine Recall@k**: Cosine-similarity variant of Recall@k. Include with `c_recall@k` in `--evaluation_metrics`.
+* **Cosine Normalized Mutual Information (NMI)**: Include with `c_nmi`.
+* **Cosine F1**: Include with `c_f1`.
+* **Cosine mAP (class-averaged)**: Include cosine-similarity mAP at Recall variants with `c_mAP_lim`. You may also include `c_mAP_1000` for mAP limited to Recall@1000, and `c_mAP_c` limited to mAP at Recall@Max_Num_Samples_Per_Class.
+
+**Embedding Space Metrics**
 * **Spectral Variance**: This metric refers to the spectral decay metric used in our ICML paper. Include it with `rho_spectrum@1`. To exclude the `k` largest spectral values for a more robust estimate, simply include `rho_spectrum@k+1`. Adding `rho_spectrum@0` logs the whole singular value distribution, and `rho_spectrum@-1` computes KL(q,p) instead of KL(p,q).
 * **Mean Intraclass Distance**: Include the mean intraclass distance via `dists@intra`.
 * **Mean Interclass Distance**: Include the mean interclass distance via `dists@inter`.
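For context, a typical call that mixes the default Euclidean metrics with the new cosine variants could look like the line below. Only `--evaluation_metrics` and the metric names come from the README above; the script name `main.py` and the `--dataset` flag are assumptions about the surrounding repository, not part of this commit.

python main.py --dataset cub200 --evaluation_metrics e_recall@1 nmi f1 mAP_lim c_recall@1 c_nmi c_f1 c_mAP_lim dists@intra dists@inter rho_spectrum@1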

metrics/__init__.py

+72 -4

@@ -1,13 +1,16 @@
-from metrics import e_recall, dists, rho_spectrum
-from metrics import nmi, f1, mAP, mAP_c, mAP_1000, mAP_lim
+from metrics import e_recall, nmi, f1, mAP, mAP_c, mAP_1000, mAP_lim
+from metrics import dists, rho_spectrum
+from metrics import c_recall, c_nmi, c_f1, c_mAP_c, c_mAP_1000, c_mAP_lim
 import numpy as np
 import faiss
 import torch
+from sklearn.preprocessing import normalize
 from tqdm import tqdm
 import copy
 
 
 def select(metricname, opt):
+    #### Metrics based on euclidean distances
     if 'e_recall' in metricname:
         k = int(metricname.split('@')[-1])
         return e_recall.Metric(k)
@@ -23,6 +26,25 @@ def select(metricname, opt):
         return mAP_1000.Metric()
     elif metricname=='f1':
         return f1.Metric()
+
+    #### Metrics based on cosine similarity
+    elif 'c_recall' in metricname:
+        k = int(metricname.split('@')[-1])
+        return c_recall.Metric(k)
+    elif metricname=='c_nmi':
+        return c_nmi.Metric()
+    elif metricname=='c_mAP':
+        return c_mAP.Metric()
+    elif metricname=='c_mAP_c':
+        return c_mAP_c.Metric()
+    elif metricname=='c_mAP_lim':
+        return c_mAP_lim.Metric()
+    elif metricname=='c_mAP_1000':
+        return c_mAP_1000.Metric()
+    elif metricname=='c_f1':
+        return c_f1.Metric()
+
+    #### Generic Embedding space metrics
     elif 'dists' in metricname:
         mode = metricname.split('@')[-1]
         return dists.Metric(mode)
@@ -91,9 +113,12 @@ def compute_standard(self, opt, model, dataloader, evaltypes, device, **kwargs):
 
         import time
         for evaltype in evaltypes:
-            features = np.vstack(feature_colls[evaltype]).astype('float32')
+            features = np.vstack(feature_colls[evaltype]).astype('float32')
+            features_cosine = normalize(features, axis=1)
 
             start = time.time()
+
+            """============ Compute k-Means ==============="""
             if 'kmeans' in self.requires:
                 ### Set CPU Cluster index
                 cluster_idx = faiss.IndexFlatL2(features.shape[-1])
@@ -106,13 +131,36 @@ def compute_standard(self, opt, model, dataloader, evaltypes, device, **kwargs):
                 kmeans.train(features, cluster_idx)
                 centroids = faiss.vector_float_to_array(kmeans.centroids).reshape(n_classes, features.shape[-1])
 
+            if 'kmeans_cosine' in self.requires:
+                ### Set CPU Cluster index
+                cluster_idx = faiss.IndexFlatL2(features_cosine.shape[-1])
+                if res is not None: cluster_idx = faiss.index_cpu_to_gpu(res, 0, cluster_idx)
+                kmeans = faiss.Clustering(features_cosine.shape[-1], n_classes)
+                kmeans.niter = 20
+                kmeans.min_points_per_centroid = 1
+                kmeans.max_points_per_centroid = 1000000000
+                ### Train Kmeans
+                kmeans.train(features_cosine, cluster_idx)
+                centroids_cosine = faiss.vector_float_to_array(kmeans.centroids).reshape(n_classes, features_cosine.shape[-1])
+                ### Project the cosine-space centroids back onto the unit sphere
+                centroids_cosine = normalize(centroids_cosine, axis=1)
 
+            """============ Compute Cluster Labels ==============="""
             if 'kmeans_nearest' in self.requires:
                 faiss_search_index = faiss.IndexFlatL2(centroids.shape[-1])
                 if res is not None: faiss_search_index = faiss.index_cpu_to_gpu(res, 0, faiss_search_index)
                 faiss_search_index.add(centroids)
                 _, computed_cluster_labels = faiss_search_index.search(features, 1)
 
+            if 'kmeans_nearest_cosine' in self.requires:
+                faiss_search_index = faiss.IndexFlatIP(centroids_cosine.shape[-1])
+                if res is not None: faiss_search_index = faiss.index_cpu_to_gpu(res, 0, faiss_search_index)
+                faiss_search_index.add(centroids_cosine)
+                _, computed_cluster_labels_cosine = faiss_search_index.search(features_cosine, 1)
+
+            """============ Compute Nearest Neighbours ==============="""
             if 'nearest_features' in self.requires:
                 faiss_search_index = faiss.IndexFlatL2(features.shape[-1])
                 if res is not None: faiss_search_index = faiss.index_cpu_to_gpu(res, 0, faiss_search_index)
@@ -122,18 +170,38 @@ def compute_standard(self, opt, model, dataloader, evaltypes, device, **kwargs):
                 _, k_closest_points = faiss_search_index.search(features, int(max_kval+1))
                 k_closest_classes = target_labels.reshape(-1)[k_closest_points[:,1:]]
 
+            if 'nearest_features_cosine' in self.requires:
+                faiss_search_index = faiss.IndexFlatIP(features_cosine.shape[-1])
+                if res is not None: faiss_search_index = faiss.index_cpu_to_gpu(res, 0, faiss_search_index)
+                faiss_search_index.add(normalize(features_cosine, axis=1))
+
+                max_kval = np.max([int(x.split('@')[-1]) for x in self.metric_names if 'recall' in x])
+                _, k_closest_points_cosine = faiss_search_index.search(normalize(features_cosine, axis=1), int(max_kval+1))
+                k_closest_classes_cosine = target_labels.reshape(-1)[k_closest_points_cosine[:,1:]]
+
             ###
             if self.pars.evaluate_on_gpu:
-                features = torch.from_numpy(features).to(self.pars.device)
+                features = torch.from_numpy(features).to(self.pars.device)
+                features_cosine = torch.from_numpy(features_cosine).to(self.pars.device)
 
             start = time.time()
             for metric in self.list_of_metrics:
                 input_dict = {}
                 if 'features' in metric.requires: input_dict['features'] = features
                 if 'target_labels' in metric.requires: input_dict['target_labels'] = target_labels
+
                 if 'kmeans' in metric.requires: input_dict['centroids'] = centroids
                 if 'kmeans_nearest' in metric.requires: input_dict['computed_cluster_labels'] = computed_cluster_labels
                 if 'nearest_features' in metric.requires: input_dict['k_closest_classes'] = k_closest_classes
+
+                if 'features_cosine' in metric.requires: input_dict['features_cosine'] = features_cosine
+                if 'kmeans_cosine' in metric.requires: input_dict['centroids_cosine'] = centroids_cosine
+                if 'kmeans_nearest_cosine' in metric.requires: input_dict['computed_cluster_labels_cosine'] = computed_cluster_labels_cosine
+                if 'nearest_features_cosine' in metric.requires: input_dict['k_closest_classes_cosine'] = k_closest_classes_cosine
+
                 computed_metrics[evaltype][metric.name] = metric(**input_dict)
 
             extra_infos[evaltype] = {'features':features, 'target_labels':target_labels,
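The diff above hinges on one property: after `normalize(features, axis=1)` every row has unit length, so the inner product computed by `faiss.IndexFlatIP` equals the cosine similarity, and the same faiss machinery used for the Euclidean metrics can serve the cosine variants. A minimal standalone sketch of that equivalence (not part of the commit; synthetic data only):

import numpy as np
import faiss
from sklearn.preprocessing import normalize

features = np.random.randn(200, 128).astype('float32')
features_cosine = normalize(features, axis=1)               # unit-length rows

index = faiss.IndexFlatIP(features_cosine.shape[-1])        # inner-product index
index.add(features_cosine)
sims, nn = index.search(features_cosine, 6)                  # column 0 is the query itself

# For unit vectors, inner product == cosine similarity.
brute = features_cosine @ features_cosine.T
np.fill_diagonal(brute, -np.inf)                             # ignore self-matches
assert np.allclose(sims[:, 0], 1.0, atol=1e-4)               # self-similarity is ~1
assert np.allclose(sims[:, 1], brute.max(axis=1), atol=1e-4) # best true cosine neighbour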

metrics/c_f1.py

+92 (new file)

import numpy as np
from scipy.special import comb, binom
import torch

class Metric():
    def __init__(self, **kwargs):
        self.requires = ['kmeans_cosine', 'kmeans_nearest_cosine', 'features_cosine', 'target_labels']
        self.name = 'c_f1'

    def __call__(self, target_labels, computed_cluster_labels_cosine, features_cosine, centroids_cosine):
        import time
        start = time.time()
        if isinstance(features_cosine, torch.Tensor):
            features_cosine = features_cosine.detach().cpu().numpy()
        d = np.zeros(len(features_cosine))
        for i in range(len(features_cosine)):
            d[i] = np.linalg.norm(features_cosine[i,:] - centroids_cosine[computed_cluster_labels_cosine[i],:])

        start = time.time()
        labels_pred = np.zeros(len(features_cosine))
        for i in np.unique(computed_cluster_labels_cosine):
            index = np.where(computed_cluster_labels_cosine == i)[0]
            ind = np.argmin(d[index])
            cid = index[ind]
            labels_pred[index] = cid

        start = time.time()
        N = len(target_labels)

        # cluster n_labels
        avail_labels = np.unique(target_labels)
        n_labels = len(avail_labels)

        # count the number of objects in each cluster
        count_cluster = np.zeros(n_labels)
        for i in range(n_labels):
            count_cluster[i] = len(np.where(target_labels == avail_labels[i])[0])

        # build a mapping from item_id to item index
        keys = np.unique(labels_pred)
        num_item = len(keys)
        values = range(num_item)
        item_map = dict()
        for i in range(len(keys)):
            item_map.update([(keys[i], values[i])])

        # count the number of objects of each item
        count_item = np.zeros(num_item)
        for i in range(N):
            index = item_map[labels_pred[i]]
            count_item[index] = count_item[index] + 1

        # compute True Positive (TP) plus False Positive (FP)
        # tp_fp = 0
        tp_fp = comb(count_cluster, 2).sum()
        # for k in range(n_labels):
        #     if count_cluster[k] > 1:
        #         tp_fp = tp_fp + comb(count_cluster[k], 2)

        # compute True Positive (TP)
        tp = 0
        start = time.time()
        for k in range(n_labels):
            member = np.where(target_labels == avail_labels[k])[0]
            member_ids = labels_pred[member]
            count = np.zeros(num_item)
            for j in range(len(member)):
                index = item_map[member_ids[j]]
                count[index] = count[index] + 1
            # for i in range(num_item):
            #     if count[i] > 1:
            #         tp = tp + comb(count[i], 2)
            tp += comb(count, 2).sum()
        # False Positive (FP)
        fp = tp_fp - tp

        # Compute False Negative (FN)
        count = comb(count_item, 2).sum()
        # count = 0
        # for j in range(num_item):
        #     if count_item[j] > 1:
        #         count = count + comb(count_item[j], 2)
        fn = count - tp

        # compute F measure
        P = tp / (tp + fp)
        R = tp / (tp + fn)
        beta = 1
        F = (beta*beta + 1) * P * R / (beta*beta * P + R)
        return F
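What `c_f1.py` returns is the pairwise F-measure over co-cluster membership: a pair of samples counts as a positive when both share a ground-truth class, and the predicted side is the cosine k-means assignment (relabelled to the id of the sample closest to each centroid, which leaves the pair counts unchanged). As a hedged cross-check, not part of the commit, the same number should be reproducible with scikit-learn's pair counting, since the F-measure is symmetric in which side plays the precision or recall role (requires scikit-learn >= 0.24 for `pair_confusion_matrix`):

import numpy as np
from sklearn.metrics.cluster import pair_confusion_matrix

def pairwise_f1(labels_true, labels_pred):
    # pair_confusion_matrix counts ordered pairs, hence the division by 2.
    (tn, fp), (fn, tp) = pair_confusion_matrix(labels_true, labels_pred) / 2
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

labels_true = np.array([0, 0, 0, 1, 1, 2, 2, 2])   # ground-truth classes
labels_pred = np.array([0, 0, 1, 1, 1, 2, 2, 0])   # (cosine) cluster assignments
print(pairwise_f1(labels_true, labels_pred))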

metrics/c_mAP_1000.py

+37 (new file)

import torch
import numpy as np
import faiss


class Metric():
    def __init__(self, **kwargs):
        self.requires = ['features_cosine', 'target_labels']
        self.name = 'c_mAP_1000'

    def __call__(self, target_labels, features_cosine):
        labels, freqs = np.unique(target_labels, return_counts=True)
        R = 1000

        faiss_search_index = faiss.IndexFlatIP(features_cosine.shape[-1])
        if isinstance(features_cosine, torch.Tensor):
            features_cosine = features_cosine.detach().cpu().numpy()
            res = faiss.StandardGpuResources()
            faiss_search_index = faiss.index_cpu_to_gpu(res, 0, faiss_search_index)
        faiss_search_index.add(features_cosine)
        nearest_neighbours = faiss_search_index.search(features_cosine, int(R+1))[1][:,1:]

        target_labels = target_labels.reshape(-1)
        nn_labels = target_labels[nearest_neighbours]

        avg_r_precisions = []
        for label, freq in zip(labels, freqs):
            rows_with_label = np.where(target_labels==label)[0]
            for row in rows_with_label:
                n_recalled_samples = np.arange(1, R+1)
                target_label_occ_in_row = nn_labels[row,:]==label
                cumsum_target_label_freq_row = np.cumsum(target_label_occ_in_row)
                avg_r_pr_row = np.sum(cumsum_target_label_freq_row*target_label_occ_in_row/n_recalled_samples)/freq
                avg_r_precisions.append(avg_r_pr_row)

        return np.mean(avg_r_precisions)
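`c_mAP_1000` and the `c_mAP_c` / `c_mAP_lim` variants below differ only in the retrieval depth R; the inner loop is the same in all three: for each query, precision is accumulated at every rank where a same-class sample is retrieved, and the sum is divided by the class frequency. A toy illustration of that per-query term (not from the commit):

import numpy as np

freq = 4                                    # gallery samples sharing the query's class
rel = np.array([1, 0, 1, 1, 0, 0, 1, 0])    # nn_labels[row, :] == label, self-match already removed
ranks = np.arange(1, len(rel) + 1)

ap_at_r = np.sum(np.cumsum(rel) * rel / ranks) / freq
print(ap_at_r)                              # (1/1 + 2/3 + 3/4 + 4/7) / 4 ~= 0.747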

metrics/c_mAP_c.py

+37 (new file)

import torch
import numpy as np
import faiss


class Metric():
    def __init__(self, **kwargs):
        self.requires = ['features_cosine', 'target_labels']
        self.name = 'c_mAP_c'

    def __call__(self, target_labels, features_cosine):
        labels, freqs = np.unique(target_labels, return_counts=True)
        R = np.max(freqs)

        faiss_search_index = faiss.IndexFlatIP(features_cosine.shape[-1])
        if isinstance(features_cosine, torch.Tensor):
            features_cosine = features_cosine.detach().cpu().numpy()
            res = faiss.StandardGpuResources()
            faiss_search_index = faiss.index_cpu_to_gpu(res, 0, faiss_search_index)
        faiss_search_index.add(features_cosine)
        nearest_neighbours = faiss_search_index.search(features_cosine, int(R+1))[1][:,1:]

        target_labels = target_labels.reshape(-1)
        nn_labels = target_labels[nearest_neighbours]

        avg_r_precisions = []
        for label, freq in zip(labels, freqs):
            rows_with_label = np.where(target_labels==label)[0]
            for row in rows_with_label:
                n_recalled_samples = np.arange(1, freq+1)
                target_label_occ_in_row = nn_labels[row,:freq]==label
                cumsum_target_label_freq_row = np.cumsum(target_label_occ_in_row)
                avg_r_pr_row = np.sum(cumsum_target_label_freq_row*target_label_occ_in_row/n_recalled_samples)/freq
                avg_r_precisions.append(avg_r_pr_row)

        return np.mean(avg_r_precisions)

metrics/c_mAP_lim.py

+38 (new file)

import torch
import numpy as np
import faiss


class Metric():
    def __init__(self, **kwargs):
        self.requires = ['features_cosine', 'target_labels']
        self.name = 'c_mAP_lim'

    def __call__(self, target_labels, features_cosine):
        labels, freqs = np.unique(target_labels, return_counts=True)
        ## Account for faiss-limit at k=1023
        R = min(1023, len(features_cosine))

        faiss_search_index = faiss.IndexFlatIP(features_cosine.shape[-1])
        if isinstance(features_cosine, torch.Tensor):
            features_cosine = features_cosine.detach().cpu().numpy()
            res = faiss.StandardGpuResources()
            faiss_search_index = faiss.index_cpu_to_gpu(res, 0, faiss_search_index)
        faiss_search_index.add(features_cosine)
        nearest_neighbours = faiss_search_index.search(features_cosine, int(R+1))[1][:,1:]

        target_labels = target_labels.reshape(-1)
        nn_labels = target_labels[nearest_neighbours]

        avg_r_precisions = []
        for label, freq in zip(labels, freqs):
            rows_with_label = np.where(target_labels==label)[0]
            for row in rows_with_label:
                n_recalled_samples = np.arange(1, R+1)
                target_label_occ_in_row = nn_labels[row,:]==label
                cumsum_target_label_freq_row = np.cumsum(target_label_occ_in_row)
                avg_r_pr_row = np.sum(cumsum_target_label_freq_row*target_label_occ_in_row/n_recalled_samples)/freq
                avg_r_precisions.append(avg_r_pr_row)

        return np.mean(avg_r_precisions)

metrics/c_nmi.py

+10 (new file)

from sklearn import metrics

class Metric():
    def __init__(self, **kwargs):
        self.requires = ['kmeans_nearest_cosine', 'target_labels']
        self.name = 'c_nmi'

    def __call__(self, target_labels, computed_cluster_labels_cosine):
        NMI = metrics.cluster.normalized_mutual_info_score(computed_cluster_labels_cosine.reshape(-1), target_labels.reshape(-1))
        return NMI
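`c_nmi.py` is a thin wrapper around scikit-learn's NMI applied to the cosine k-means assignments; the score only compares grouping structure, so cluster ids can be permuted freely. A toy call (not from the commit):

from sklearn import metrics

print(metrics.cluster.normalized_mutual_info_score([0, 0, 1, 1], [1, 1, 0, 0]))  # 1.0, same grouping with relabelled clusters
print(metrics.cluster.normalized_mutual_info_score([0, 0, 1, 1], [0, 1, 0, 1]))  # 0.0, statistically independent grouping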

metrics/c_recall.py

+11 (new file)

import numpy as np

class Metric():
    def __init__(self, k, **kwargs):
        self.k = k
        self.requires = ['nearest_features_cosine', 'target_labels']
        self.name = 'c_recall@{}'.format(k)

    def __call__(self, target_labels, k_closest_classes_cosine, **kwargs):
        recall_at_k = np.sum([1 for target, recalled_predictions in zip(target_labels, k_closest_classes_cosine) if target in recalled_predictions[:self.k]])/len(target_labels)
        return recall_at_k
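`c_recall.Metric(k)` mirrors `e_recall`, just consuming the cosine neighbour lists: a query counts as a hit if its class appears among the classes of its k nearest retrieved samples. A small usage sketch (not from the commit; it assumes the repository's `metrics` package is importable):

import numpy as np
from metrics import c_recall

metric = c_recall.Metric(k=2)
target_labels = np.array([0, 1, 2])
k_closest_classes_cosine = np.array([[0, 1, 2],    # hit at rank 1
                                     [2, 1, 0],    # hit at rank 2
                                     [0, 1, 1]])   # class 2 not in the top 2 -> miss
print(metric(target_labels, k_closest_classes_cosine))   # 2/3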

metrics/e_recall.py

+1 -1

@@ -6,6 +6,6 @@ def __init__(self, k, **kwargs):
         self.requires = ['nearest_features', 'target_labels']
         self.name = 'e_recall@{}'.format(k)
 
-    def __call__(self, target_labels, k_closest_classes):
+    def __call__(self, target_labels, k_closest_classes, **kwargs):
         recall_at_k = np.sum([1 for target, recalled_predictions in zip(target_labels, k_closest_classes) if target in recalled_predictions[:self.k]])/len(target_labels)
         return recall_at_k
