-
Notifications
You must be signed in to change notification settings - Fork 2
/
loss.py
74 lines (54 loc) · 2.42 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
def mse(inp, target):
    """Return the mean squared error between ``inp`` and ``target``.

    The two tensors are subtracted (with normal broadcasting) and the squared
    residual is averaged over every element, giving a 0-d tensor.
    """
    residual = inp - target
    return residual.pow(2).mean()
def tv_chan(inp, target):
    """Total-variation mismatch between ``inp`` and ``target`` along dim 2.

    For every index in the first two dimensions, the total variation
    ``sum_t |x[:, :, t + 1] - x[:, :, t]|`` is computed along dimension 2 for
    both tensors. The loss is the mean absolute difference of those totals,
    divided by the number of adjacent pairs.

    Assumes 3-D tensors, presumably (batch, channel, length) -- TODO confirm.
    The normaliser uses ``inp.shape[-1]``, which equals dim 2 only for 3-D
    input.

    Args:
        inp: prediction tensor.
        target: reference tensor with the same shape as ``inp``.

    Returns:
        0-d tensor. For length-1 sequences (no adjacent pairs) this is 0
        rather than NaN.
    """
    n_pairs = inp.shape[-1] - 1
    tv_inp = torch.sum(torch.abs(inp[:, :, 1:] - inp[:, :, :-1]), dim=2)
    tv_target = torch.sum(torch.abs(target[:, :, 1:] - target[:, :, :-1]), dim=2)
    # With a single step both totals are 0; dividing by 0 would give NaN, so
    # clamp the normaliser to at least 1.
    return torch.mean(torch.abs(tv_target - tv_inp)) / max(n_pairs, 1)
def tv_no_chan(inp, target):
    """Total-variation mismatch between ``inp`` and ``target`` along dim 1.

    For every row, the total variation ``sum_t |x[:, t + 1] - x[:, t]|`` is
    computed along dimension 1 for both tensors. The loss is the mean
    absolute difference of those totals, divided by the number of adjacent
    pairs.

    Assumes 2-D tensors, presumably (batch, length) -- TODO confirm. The
    normaliser uses ``inp.shape[-1]``, which equals dim 1 only for 2-D input.

    Args:
        inp: prediction tensor.
        target: reference tensor with the same shape as ``inp``.

    Returns:
        0-d tensor. For length-1 sequences (no adjacent pairs) this is 0
        rather than NaN.
    """
    n_pairs = inp.shape[-1] - 1
    tv_inp = torch.sum(torch.abs(inp[:, 1:] - inp[:, :-1]), dim=1)
    tv_target = torch.sum(torch.abs(target[:, 1:] - target[:, :-1]), dim=1)
    # With a single step both totals are 0; dividing by 0 would give NaN, so
    # clamp the normaliser to at least 1.
    return torch.mean(torch.abs(tv_target - tv_inp)) / max(n_pairs, 1)
def mse_tv(inp, target, mse_weight=0.9, tv_weight=0.1):
    """Weighted blend of ``mse`` and ``tv_no_chan`` for the same pair of tensors.

    Args:
        inp: prediction tensor (passed unchanged to both component losses).
        target: reference tensor.
        mse_weight: multiplier for the mean-squared-error term.
        tv_weight: multiplier for the total-variation mismatch term.

    The defaults (0.9 / 0.1) reproduce the previously hard-coded blend.

    Returns:
        0-d tensor ``mse_weight * mse + tv_weight * tv_no_chan``.
    """
    return mse_weight * mse(inp, target) + tv_weight * tv_no_chan(inp, target)
# def mmd_multiscale_on(dev):
# def mmd_multiscale(x, y):
# xx, yy, zz = torch.mm(x,x.t()), torch.mm(y,y.t()), torch.mm(x,y.t())
# rx = (xx.diag().unsqueeze(0).expand_as(xx))
# ry = (yy.diag().unsqueeze(0).expand_as(yy))
# dxx = rx.t() + rx - 2.*xx
# dyy = ry.t() + ry - 2.*yy
# dxy = rx.t() + ry - 2.*zz
# XX, YY, XY = (torch.zeros(xx.shape).to(dev),
# torch.zeros(xx.shape).to(dev),
# torch.zeros(xx.shape).to(dev))
# for a in [0.2, 0.5, 0.9, 1.3, 2.4, 5.0, 10.0, 20.0, 40.0]:
# # for a in [0.05, 0.125, 0.225, 0.325]:
# XX += a**2 * (a**2 + dxx)**-1
# YY += a**2 * (a**2 + dyy)**-1
# XY += a**2 * (a**2 + dxy)**-1
# return torch.mean(XX + YY - 2.*XY)
# return mmd_multiscale
def mmd_multiscale_on(dev, alphas=None):
    """Build a multi-scale MMD loss based on inverse multiquadric kernels.

    The returned callable estimates the squared MMD between two sample sets
    using the summed kernel
    ``k(u, v) = sum_a a**2 / (a**2 + ||u - v||**2)`` over every scale ``a``
    in ``alphas``. All pairs, including the diagonal, enter the averages,
    i.e. this is the biased (V-statistic) estimate.

    Args:
        dev: device on which the kernel accumulators are allocated. It must
            be the device of the tensors later passed to the returned
            function.
        alphas: kernel scales. Defaults to
            ``[0.2, 0.5, 0.9, 1.3, 2.4, 5.0, 10.0, 20.0, 40.0]``. An
            alternative set previously tried here was
            ``[0.05, 0.125, 0.225, 0.325]``.

    Returns:
        ``mmd_multiscale(x, y)``, which takes 2-D tensors of shape ``(n, d)``
        and ``(m, d)`` (``n`` and ``m`` may differ) and returns a 0-d tensor
        with the inputs' dtype.
    """
    if alphas is None:
        alphas = [0.2, 0.5, 0.9, 1.3, 2.4, 5.0, 10.0, 20.0, 40.0]

    def mmd_multiscale(x, y):
        xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())
        # Squared norm of every sample = diagonal of its Gram matrix.
        rx = xx.diag()
        ry = yy.diag()
        # Pairwise squared distances ||u||^2 + ||v||^2 - 2<u, v>, built by
        # broadcasting a column vector against a row vector so that x and y
        # may hold different numbers of samples.
        dxx = rx.unsqueeze(1) + rx.unsqueeze(0) - 2. * xx
        dyy = ry.unsqueeze(1) + ry.unsqueeze(0) - 2. * yy
        dxy = rx.unsqueeze(1) + ry.unsqueeze(0) - 2. * zz
        # zeros_like keeps the inputs' dtype; a plain torch.zeros would be
        # float32 and silently downcast float64 inputs during the += below.
        XX = torch.zeros_like(dxx, device=dev)
        YY = torch.zeros_like(dyy, device=dev)
        XY = torch.zeros_like(dxy, device=dev)
        for a in alphas:
            XX += a**2 * (a**2 + dxx)**-1
            YY += a**2 * (a**2 + dyy)**-1
            XY += a**2 * (a**2 + dxy)**-1
        if XX.shape == YY.shape == XY.shape:
            # Equal sample counts: single fused reduction.
            return torch.mean(XX + YY - 2. * XY)
        # Unequal sample counts: average each kernel matrix separately.
        return torch.mean(XX) + torch.mean(YY) - 2. * torch.mean(XY)

    return mmd_multiscale