-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathlosses.py
More file actions
104 lines (86 loc) · 4.6 KB
/
losses.py
File metadata and controls
104 lines (86 loc) · 4.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Third party import
import logging
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as func
from sklearn.metrics.pairwise import rbf_kernel
import numpy as np
class GeneralizedSupervisedNTXenLoss(nn.Module):
def __init__(self, kernel='rbf', temperature=0.1, return_logits=False, sigma=1.0):
"""
:param kernel: a callable function f: [K, *] x [K, *] -> [K, K]
y1, y2 -> f(y1, y2)
where (*) is the dimension of the labels (yi)
default: an rbf kernel parametrized by 'sigma' which corresponds to gamma=1/(2*sigma**2)
:param temperature:
:param return_logits:
"""
# sigma = prior over the label's range
super().__init__()
self.kernel = kernel
self.sigma = sigma
if self.kernel == 'rbf':
self.kernel = lambda y1, y2: rbf_kernel(y1, y2, gamma=1./(2*self.sigma**2))
else:
assert hasattr(self.kernel, '__call__'), 'kernel must be a callable'
self.temperature = temperature
self.return_logits = return_logits
self.INF = 1e8
def forward(self, z_i, z_j, labels):
N = len(z_i)
assert N == len(labels), "Unexpected labels length: %i"%len(labels)
z_i = func.normalize(z_i, p=2, dim=-1) # dim [N, D]
z_j = func.normalize(z_j, p=2, dim=-1) # dim [N, D]
sim_zii= (z_i @ z_i.T) / self.temperature # dim [N, N] => Upper triangle contains incorrect pairs
sim_zjj = (z_j @ z_j.T) / self.temperature # dim [N, N] => Upper triangle contains incorrect pairs
sim_zij = (z_i @ z_j.T) / self.temperature # dim [N, N] => the diag contains the correct pairs (i,j) (x transforms via T_i and T_j)
# 'Remove' the diag terms by penalizing it (exp(-inf) = 0)
sim_zii = sim_zii - self.INF * torch.eye(N, device=z_i.device)
sim_zjj = sim_zjj - self.INF * torch.eye(N, device=z_i.device)
all_labels = labels.view(N, -1).repeat(2, 1).detach().cpu().numpy() # [2N, *]
weights = self.kernel(all_labels, all_labels) # [2N, 2N]
weights = weights * (1 - np.eye(2*N)) # puts 0 on the diagonal
weights /= weights.sum(axis=1)
# if 'rbf' kernel and sigma->0, we retrieve the classical NTXenLoss (without labels)
sim_Z = torch.cat([torch.cat([sim_zii, sim_zij], dim=1), torch.cat([sim_zij.T, sim_zjj], dim=1)], dim=0) # [2N, 2N]
log_sim_Z = func.log_softmax(sim_Z, dim=1)
loss = -1./N * (torch.from_numpy(weights).to(z_i.device) * log_sim_Z).sum()
correct_pairs = torch.arange(N, device=z_i.device).long()
if self.return_logits:
return loss, sim_zij, correct_pairs
return loss
def __str__(self):
return "{}(temp={}, kernel={}, sigma={})".format(type(self).__name__, self.temperature,
self.kernel.__name__, self.sigma)
class NTXenLoss(nn.Module):
"""
Normalized Temperature Cross-Entropy Loss for Constrastive Learning
Refer for instance to:
Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton
A Simple Framework for Contrastive Learning of Visual Representations, arXiv 2020
"""
def __init__(self, temperature=0.1, return_logits=False):
super().__init__()
self.temperature = temperature
self.INF = 1e8
self.return_logits = return_logits
def forward(self, z_i, z_j):
N = len(z_i)
z_i = func.normalize(z_i, p=2, dim=-1) # dim [N, D]
z_j = func.normalize(z_j, p=2, dim=-1) # dim [N, D]
sim_zii= (z_i @ z_i.T) / self.temperature # dim [N, N] => Upper triangle contains incorrect pairs
sim_zjj = (z_j @ z_j.T) / self.temperature # dim [N, N] => Upper triangle contains incorrect pairs
sim_zij = (z_i @ z_j.T) / self.temperature # dim [N, N] => the diag contains the correct pairs (i,j) (x transforms via T_i and T_j)
# 'Remove' the diag terms by penalizing it (exp(-inf) = 0)
sim_zii = sim_zii - self.INF * torch.eye(N, device=z_i.device)
sim_zjj = sim_zjj - self.INF * torch.eye(N, device=z_i.device)
correct_pairs = torch.arange(N, device=z_i.device).long()
loss_i = func.cross_entropy(torch.cat([sim_zij, sim_zii], dim=1), correct_pairs)
loss_j = func.cross_entropy(torch.cat([sim_zij.T, sim_zjj], dim=1), correct_pairs)
if self.return_logits:
return (loss_i + loss_j), sim_zij, correct_pairs
return (loss_i + loss_j)
def __str__(self):
return "{}(temp={})".format(type(self).__name__, self.temperature)