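"""
PyTorch implementation of PolyLoss from "PolyLoss: A Polynomial Expansion
Perspective of Classification Loss Functions" (Leng et al., ICLR 2022).

Poly-1 keeps only the leading term of the polynomial expansion of a base
loss and adds it back with an epsilon weight:

    Poly1-CE    = CE + epsilon * (1 - pt)
    Poly1-Focal = FL + epsilon * (1 - pt) ** (gamma + 1)

where pt is the probability the model assigns to the target class.
"""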
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class Poly1CrossEntropyLoss(nn.Module):
    def __init__(self,
                 num_classes: int,
                 epsilon: float = 1.0,
                 reduction: str = "none",
                 weight: Optional[Tensor] = None):
        """
        Create instance of Poly1CrossEntropyLoss
        :param num_classes: number of classes
        :param epsilon: poly loss epsilon, weight of the leading polynomial term
        :param reduction: one of none|sum|mean, apply reduction to final loss tensor
        :param weight: manual rescaling weight for each class, passed to cross-entropy loss
        """
        super(Poly1CrossEntropyLoss, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.reduction = reduction
        self.weight = weight

    def forward(self, logits, labels):
        """
        Forward pass
        :param logits: tensor of shape [N, num_classes]
        :param labels: tensor of shape [N]
        :return: poly cross-entropy loss
        """
        labels_onehot = F.one_hot(labels, num_classes=self.num_classes).to(device=logits.device,
                                                                           dtype=logits.dtype)
        # pt is the softmax probability assigned to the true class of each sample
        pt = torch.sum(labels_onehot * F.softmax(logits, dim=-1), dim=-1)
        CE = F.cross_entropy(input=logits,
                             target=labels,
                             reduction='none',
                             weight=self.weight)
        # Poly-1: cross-entropy plus the epsilon-weighted leading polynomial term
        poly1 = CE + self.epsilon * (1 - pt)
        if self.reduction == "mean":
            poly1 = poly1.mean()
        elif self.reduction == "sum":
            poly1 = poly1.sum()
        return poly1
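
# Example usage (illustrative sketch; `model`, the shapes, and the values
# below are hypothetical, not part of this module):
#
#   criterion = Poly1CrossEntropyLoss(num_classes=10, epsilon=1.0, reduction="mean")
#   logits = model(images)              # [N, 10]
#   loss = criterion(logits, targets)   # targets: [N] with integer class ids
#   loss.backward()
#
# With epsilon=0 the loss reduces to plain cross-entropy; larger epsilon puts
# more weight on the (1 - pt) term for low-confidence predictions.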


class Poly1FocalLoss(nn.Module):
    def __init__(self,
                 num_classes: int,
                 epsilon: float = 1.0,
                 alpha: float = 0.25,
                 gamma: float = 2.0,
                 reduction: str = "none",
                 weight: Optional[Tensor] = None,
                 pos_weight: Optional[Tensor] = None,
                 label_is_onehot: bool = False):
        """
        Create instance of Poly1FocalLoss
        :param num_classes: number of classes
        :param epsilon: poly loss epsilon
        :param alpha: focal loss alpha
        :param gamma: focal loss gamma
        :param reduction: one of none|sum|mean, apply reduction to final loss tensor
        :param weight: manual rescaling weight for each class, passed to binary cross-entropy loss
        :param pos_weight: weight of positive examples, passed to binary cross-entropy loss
        :param label_is_onehot: set to True if labels are one-hot encoded
        """
        super(Poly1FocalLoss, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.weight = weight
        self.pos_weight = pos_weight
        self.label_is_onehot = label_is_onehot

    def forward(self, logits, labels):
        """
        Forward pass
        :param logits: output of neural network of shape [N, num_classes] or [N, num_classes, ...]
        :param labels: ground truth tensor of shape [N] or [N, ...] with class ids if label_is_onehot was set
            to False, otherwise one-hot encoded tensor of same shape as logits
        :return: poly focal loss
        """
        # focal loss implementation taken from
        # https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/focal_loss.py
        p = torch.sigmoid(logits)

        if not self.label_is_onehot:
            # if labels are of shape [N],
            # convert to one-hot tensor of shape [N, num_classes]
            if labels.ndim == 1:
                labels = F.one_hot(labels, num_classes=self.num_classes)
            # if labels are of shape [N, ...], e.g. a segmentation task,
            # convert to one-hot tensor of shape [N, num_classes, ...]
            else:
                labels = F.one_hot(labels.unsqueeze(1), self.num_classes).transpose(1, -1).squeeze_(-1)

        labels = labels.to(device=logits.device,
                           dtype=logits.dtype)

        ce_loss = F.binary_cross_entropy_with_logits(input=logits,
                                                     target=labels,
                                                     reduction="none",
                                                     weight=self.weight,
                                                     pos_weight=self.pos_weight)
        # pt is the probability assigned to the correct binary label of each logit
        pt = labels * p + (1 - labels) * (1 - p)
        FL = ce_loss * ((1 - pt) ** self.gamma)

        if self.alpha >= 0:
            alpha_t = self.alpha * labels + (1 - self.alpha) * (1 - labels)
            FL = alpha_t * FL

        # Poly-1: focal loss plus the epsilon-weighted leading polynomial term
        poly1 = FL + self.epsilon * torch.pow(1 - pt, self.gamma + 1)

        if self.reduction == "mean":
            poly1 = poly1.mean()
        elif self.reduction == "sum":
            poly1 = poly1.sum()

        return poly1
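

if __name__ == "__main__":
    # Minimal smoke test / usage sketch; the shapes and hyperparameters below
    # are illustrative, not values from the PolyLoss paper.
    torch.manual_seed(0)

    # classification-shaped inputs: [N, num_classes] logits, [N] class ids
    logits = torch.randn(8, 5)
    labels = torch.randint(0, 5, (8,))
    focal = Poly1FocalLoss(num_classes=5, epsilon=1.0, reduction="mean")
    print("poly1 focal (classification):", focal(logits, labels).item())

    # segmentation-shaped inputs: [N, num_classes, H, W] logits, [N, H, W] class ids
    seg_logits = torch.randn(2, 5, 4, 4)
    seg_labels = torch.randint(0, 5, (2, 4, 4))
    print("poly1 focal (segmentation):", focal(seg_logits, seg_labels).item())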