-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathloss.py
61 lines (51 loc) · 2.06 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.nn as nn
from torchvision import models
class VGG19Loss(object):
    """Perceptual (feature-matching) loss computed on VGG-19 feature maps.

    The convolutional ``features`` stack of VGG-19 is cut into consecutive
    ``nn.Sequential`` slices, each ending at one of the requested layers.
    Calling the object normalises ``source`` and ``target`` channel-wise,
    pushes both through the slices in order and, after every slice, compares
    the two activations with ``loss_function``.  The returned value is the
    mean of the per-layer losses, each multiplied by its weight, and it stays
    attached to the autograd graph so it can be back-propagated.

    Layer names are generated while walking the stack: ``conv{b}_{i}`` and
    ``relu{b}_{i}`` for the i-th convolution / ReLU of block ``b``, and
    ``pool{b}`` for the max-pool that closes block ``b`` (all counted from 1,
    e.g. ``'relu1_1'``, ``'conv3_2'``, ``'pool4'``).
    """

    def __init__(self, layers, weights=None, loss_function=None,
                 features=None):
        """Build the sliced, frozen feature extractor.

        Args:
            layers: one layer name, or a sequence of layer names, at which
                the features are compared.
            weights: one weight per entry of ``layers``, matched by position.
                Defaults to ``1.0`` for every layer.
            loss_function: callable ``(source_feat, target_feat) -> tensor``
                applied at every selected layer.  Defaults to a new
                ``nn.MSELoss()`` per instance.
            features: optional module whose children form a VGG-style stack
                (``Conv2d`` / ``ReLU`` / ``MaxPool2d``).  Defaults to
                ``models.vgg19(pretrained=True).features``.  The parameters
                of this module get ``requires_grad = False``.

        Raises:
            ValueError: if ``layers`` is empty, contains duplicates or names
                not present in the stack, or its length differs from the
                length of ``weights``.
        """
        super().__init__()
        if isinstance(layers, str):
            layers = [layers]
        layers = list(layers)
        # `is None` rather than `== None`: `==` is element-wise on tensors and
        # arrays and would not yield a single bool.
        if weights is None:
            weights = [1.0] * len(layers)
        weights = [float(w) for w in weights]
        if not layers:
            raise ValueError('at least one layer name is required')
        if len(layers) != len(weights):
            raise ValueError(
                f'got {len(layers)} layer names but {len(weights)} weights')
        weight_of = dict(zip(layers, weights))
        if len(weight_of) != len(layers):
            raise ValueError(f'duplicate layer names in {layers}')

        if features is None:
            # NOTE: `pretrained=` is deprecated in recent torchvision in
            # favour of `weights=`; callers can pass `features=` instead.
            features = models.vgg19(pretrained=True).features
        if loss_function is None:
            loss_function = nn.MSELoss()

        cuts = self._slice_features(features, weight_of)
        found = [name for name, _ in cuts]
        missing = [name for name in layers if name not in found]
        if missing:
            # Silently skipping a name would shorten self.modules and shift
            # every later weight onto the wrong layer.
            raise ValueError(f'unknown VGG layer name(s): {missing}')

        self.modules = nn.ModuleList([block for _, block in cuts])
        # self.modules is in network order, which may differ from the order
        # of `layers`; look weights up by name so each stays with its layer.
        self.weights = torch.tensor([weight_of[name] for name in found],
                                    dtype=torch.float32)
        self.loss_function = loss_function
        for param in self.modules.parameters():
            param.requires_grad = False
        self.modules.eval()
        # Per-channel normalisation statistics, shaped (C, 1, 1) so they
        # broadcast over the two trailing (spatial) dimensions.
        self.mean = torch.tensor([0.485, 0.456, 0.406],
                                 dtype=torch.float32).view(-1, 1, 1)
        self.stddev = torch.tensor([0.229, 0.224, 0.225],
                                   dtype=torch.float32).view(-1, 1, 1)

    @staticmethod
    def _slice_features(features, wanted):
        """Cut ``features`` into slices that end at each name in ``wanted``.

        Returns a list of ``(layer_name, nn.Sequential)`` pairs in network
        order; children after the last wanted layer are dropped.
        """
        cuts = []
        pending = []
        block, conv_idx, relu_idx = 1, 1, 1
        for module in features.children():
            if isinstance(module, nn.Conv2d):
                name = f'conv{block}_{conv_idx}'
                conv_idx += 1
            elif isinstance(module, nn.ReLU):
                name = f'relu{block}_{relu_idx}'
                relu_idx += 1
                # torchvision's VGG uses ReLU(inplace=True), which would
                # overwrite the preceding conv output; when that output is a
                # compared feature it is saved for backward, so use a fresh
                # non-inplace ReLU instead.
                module = nn.ReLU()
            elif isinstance(module, nn.MaxPool2d):
                name = f'pool{block}'
                block += 1
                conv_idx, relu_idx = 1, 1
            else:
                # Any other child (e.g. BatchNorm2d) gets no name, so it is
                # never a cut point.
                name = None
            pending.append(module)
            if name in wanted:
                cuts.append((name, nn.Sequential(*pending)))
                pending = []
        return cuts

    def __call__(self, source, target):
        """Return the weighted perceptual loss between ``source`` and ``target``.

        Both inputs are normalised with ``self.mean`` / ``self.stddev`` and
        then run through the feature slices in order.  Presumably they are
        RGB image batches shaped (N, 3, H, W) with values in [0, 1] -- TODO
        confirm with callers.

        Returns:
            A 0-d tensor equal to the mean of the weighted per-layer losses,
            still connected to the autograd graph.
        """
        source = (source - self.mean) / self.stddev
        target = (target - self.mean) / self.stddev
        losses = []
        for weight, block in zip(self.weights, self.modules):
            source = block(source)
            target = block(target)
            losses.append(weight * self.loss_function(source, target))
        # torch.stack keeps the graph and the device; rebuilding a new
        # tensor from the loss values would detach it and move it to CPU.
        return torch.stack(losses).mean()

    def to(self, device):
        """Move weights, feature slices and normalisation stats to ``device``.

        Returns ``self`` so that calls can be chained, like ``nn.Module.to``.
        """
        self.weights = self.weights.to(device)
        self.modules = self.modules.to(device)
        self.mean = self.mean.to(device)
        self.stddev = self.stddev.to(device)
        return self