-
Notifications
You must be signed in to change notification settings - Fork 92
/
loss.py
42 lines (35 loc) · 1.05 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.autograd import Variable
class NLLLoss(nn.Module):
    """Class-weighted negative log-likelihood style loss, summed over the batch.

    For every sample ``n`` the score of its target class is picked from
    ``prob``, multiplied by that class's weight, and the negated sum over all
    samples is returned::

        loss = -sum_n weight[target[n]] * prob[n, target[n]]

    No reduction other than the sum is applied (the result is not averaged).

    Args:
        weight: Tensor of shape (num_class,) holding one weight per class.
    """

    def __init__(self, weight):
        super(NLLLoss, self).__init__()
        # Kept as a plain attribute (not a registered buffer) so the module's
        # state_dict stays exactly as it was before.
        self.weight = weight

    def forward(self, prob, target):
        """Compute the weighted loss.

        Args:
            prob: Tensor of shape (N, C) with per-class scores; presumably
                log-probabilities, given the loss name -- confirm with caller.
            target: Integer (long) tensor of shape (N,) with class indices in
                ``[0, C)``.

        Returns:
            A 0-dim tensor holding the summed, negated, weighted target scores.
        """
        # Follow ``prob``'s device instead of calling ``.cuda()``, which
        # ignores the device index and breaks on non-default GPUs.
        weight = self.weight.to(prob.device).view(1, -1)
        # Broadcasting (1, C) against (N, C) replaces the explicit expand.
        weighted = prob * weight
        # Pick each row's target column directly. This avoids building a dense
        # one-hot mask and the uint8 mask that ``masked_select`` no longer
        # accepts in current PyTorch releases.
        index = target.to(prob.device).reshape(-1, 1)
        picked = weighted.gather(1, index)
        return -picked.sum()