-
Notifications
You must be signed in to change notification settings - Fork 1
/
loss.py
46 lines (36 loc) · 1.46 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
"""
author: mengxue
email: [email protected]
last date: May 29 2024
"""
import torch
import torch.nn.functional as F
from torchvision.ops import sigmoid_focal_loss
def classification_loss(batch_pred, batch_gt, batch_mask=None):
    """Masked binary classification loss computed from logits.

    This is the element-wise ``sigmoid_focal_loss(alpha=0.5, gamma=0.)`` the
    function has always used, written in its reduced form ``0.5 * BCE``:
    with ``gamma=0`` the focal modulating factor ``(1 - p_t) ** gamma`` is 1,
    and with ``alpha=0.5`` the class weight ``alpha * t + (1 - alpha) * (1 - t)``
    is 0.5 for every target, so only a constant 0.5 scale remains.

    Args:
        batch_pred: Raw logits (any shape; flattened internally).
        batch_gt: Binary targets with the same number of elements as
            ``batch_pred``. Non-floating labels (int/bool) are cast to the
            logits' dtype, since BCE-with-logits requires floating targets.
        batch_mask: Optional per-element mask/weights with the same number
            of elements. Every element counts when omitted.

    Returns:
        Scalar tensor: sum of masked per-element losses divided by the mask
        sum plus a dtype epsilon (so an all-zero mask gives 0, not NaN).
    """
    pred = batch_pred.reshape(-1)
    gt = batch_gt.reshape(-1)
    if not gt.is_floating_point():
        # int/bool labels would make binary_cross_entropy_with_logits raise.
        gt = gt.to(pred.dtype)
    if batch_mask is None:
        mask = torch.ones_like(gt)
    else:
        mask = batch_mask.reshape(-1)
    # 0.5 reproduces the alpha=0.5 weighting of the original focal-loss call.
    per_elem = 0.5 * F.binary_cross_entropy_with_logits(pred, gt, reduction='none')
    eps = torch.finfo(pred.dtype).eps
    return (per_elem * mask).sum() / (mask.sum() + eps)
def reconstruction_loss(recon_x, x, x_mask=None):
    """Masked mean-squared reconstruction error.

    The inputs and the optional mask are flattened to a column. The
    per-element squared error is weighted by the mask, and the weighted sum
    is divided by the mask total. A dtype epsilon in the denominator keeps an
    all-zero mask from producing NaN.

    Args:
        recon_x: Reconstructed values.
        x: Reference values with the same number of elements as ``recon_x``.
        x_mask: Optional per-element weights. Every element counts when omitted.

    Returns:
        Scalar tensor holding the masked mean squared error.
    """
    weights = torch.ones_like(x) if x_mask is None else x_mask
    flat_w = weights.reshape(-1, 1)
    sq_err = F.mse_loss(recon_x.reshape(-1, 1), x.reshape(-1, 1), reduction='none')
    denom = flat_w.sum() + torch.finfo(recon_x.dtype).eps
    return (sq_err * flat_w).sum() / denom
def kl_loss(mu1, logvar1, mu2=None, logvar2=None):
    """KL divergence KL(N(mu1, exp(logvar1)) || N(mu2, exp(logvar2))), diagonal Gaussians.

    With no prior parameters the target is the standard normal N(0, I),
    using the closed form from Appendix B of the VAE paper
    (https://arxiv.org/abs/1312.6114). Otherwise the general Gaussian-to-
    Gaussian form is used (Equation 6~7: https://arxiv.org/abs/1606.05908).
    If only one of ``mu2`` / ``logvar2`` is given, the missing one takes its
    standard-normal value (mean 0, log-variance 0) instead of the supplied
    one being silently ignored.

    Note: the per-element terms are reduced with ``torch.mean`` over all
    elements, not summed over latent dimensions as in the paper.

    Args:
        mu1, logvar1: Mean and log-variance of the first distribution.
        mu2, logvar2: Optional mean and log-variance of the second
            distribution (broadcastable against ``mu1`` / ``logvar1``).

    Returns:
        Scalar tensor: 0.5 * mean of the per-element KL terms.
    """
    if mu2 is None and logvar2 is None:
        # Closed form against N(0, I).
        return 0.5 * torch.mean(logvar1.exp() + mu1.pow(2) - 1 - logvar1)
    # Partially specified prior: fill the missing parameter with N(0, I) values.
    if mu2 is None:
        mu2 = torch.zeros_like(mu1)
    if logvar2 is None:
        logvar2 = torch.zeros_like(logvar1)
    var_ratio = (logvar1 - logvar2).exp()           # sigma1^2 / sigma2^2
    mean_term = (mu2 - mu1).pow(2) / logvar2.exp()  # (mu2 - mu1)^2 / sigma2^2
    log_ratio = logvar2 - logvar1                   # log(sigma2^2 / sigma1^2)
    return 0.5 * torch.mean(var_ratio + mean_term + (-1.) + log_ratio)
# Module-level short aliases for the classification, reconstruction and KL terms.
CLF_LOSS, REC_LOSS, KL_LOSS = classification_loss, reconstruction_loss, kl_loss