-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathNCECriterion.lua
119 lines (97 loc) · 3.34 KB
/
NCECriterion.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
------------------------------------------------------------------------
--[[ Noise Contrastive Estimation (NCE) Criterion ]]--
-- Ref.: A. http://mi.eng.cam.ac.uk/~xc257/papers/ICASSP2015-rnnlm-nce.pdf
-- B. https://www.cs.toronto.edu/~amnih/papers/ncelm.pdf
------------------------------------------------------------------------
-- Register nn.NCECriterion as a torch class derived from nn.Criterion.
-- `parent` is the base-class table used below to reach the inherited methods.
local NCECriterion, parent = torch.class('nn.NCECriterion', 'nn.Criterion')
-- Small constant added to denominators so the element-wise divisions below
-- never divide by exactly zero.
local eps = 1e-7
--- Constructor.
-- Runs the base-class initialiser, enables averaging of the loss over the
-- first dimension (sizeAverage = true) and allocates one empty gradient
-- buffer per entry of the four-element input table.
function NCECriterion:__init()
   parent.__init(self)
   self.sizeAverage = true

   -- gradInput mirrors the input table: {Pmt, Pms, Pnt, Pns}
   local buffers = {}
   for i = 1, 4 do
      buffers[i] = torch.Tensor()
   end
   self.gradInput = buffers
end
--- Forward pass: negative NCE log-likelihood (equations 5 and 6 of ref. A).
-- Caches the denominators of eq. 5 in self._Pomdiv / self._Pondiv; they are
-- reused by updateGradInput.
-- @param inputTable table {Pmt, Pms, Pnt, Pns} where
--   Pmt = P_model(target), 1D tensor
--   Pms = P_model(sample), 2D tensor; k is taken from its second dimension
--   Pnt = P_noise(target), 1D tensor
--   Pns = P_noise(sample), 2D tensor
-- @param target not read by this method
-- @return number the loss, divided by Pmt:size(1) when self.sizeAverage is set
function NCECriterion:updateOutput(inputTable, target)
   -- P_model(target), P_model(sample), P_noise(target), P_noise(sample)
   local Pmt, Pms, Pnt, Pns = unpack(inputTable)

   -- Check the ranks BEFORE reading Pms:size(2): on a mis-shaped input the
   -- size lookup would otherwise fail first with an unrelated tensor error.
   assert(Pmt:dim() == 1, "NCECriterion: P_model(target) must be a 1D tensor")
   assert(Pms:dim() == 2, "NCECriterion: P_model(sample) must be a 2D tensor")
   assert(Pnt:dim() == 1, "NCECriterion: P_noise(target) must be a 1D tensor")
   assert(Pns:dim() == 2, "NCECriterion: P_noise(sample) must be a 2D tensor")

   local k = Pms:size(2)

   -- equation 5 in ref. A

   -- eq 5.1 : P(origin=model) = Pmt / (Pmt + k*Pnt)
   self._Pom = self._Pom or Pmt.new()
   self._Pom:resizeAs(Pmt):copy(Pmt)

   self._Pomdiv = self._Pomdiv or Pmt.new()
   self._Pomdiv:resizeAs(Pmt):copy(Pmt)
   self._Pomdiv:add(k, Pnt):add(eps) -- eps keeps the denominator non-zero
   self._Pom:cdiv(self._Pomdiv)

   -- eq 5.2 : P(origin=noise) = k*Pns / (Pms + k*Pns)
   self._Pon = self._Pon or Pns.new()
   self._Pon:resizeAs(Pns):copy(Pns):mul(k)

   self._Pondiv = self._Pondiv or Pms.new()
   self._Pondiv:resizeAs(Pms):copy(Pms)
   self._Pondiv:add(k, Pns):add(eps) -- eps keeps the denominator non-zero
   self._Pon:cdiv(self._Pondiv)

   -- equation 6 in ref. A : loss = -( sum ln(Pom) + sum ln(Pon) )
   self._lnPom = self._lnPom or self._Pom.new()
   self._lnPom:log(self._Pom)

   self._lnPon = self._lnPon or self._Pon.new()
   self._lnPon:log(self._Pon)

   local lnPomsum = self._lnPom:sum()
   local lnPonsum = self._lnPon:sum()

   self.output = - (lnPomsum + lnPonsum)

   if self.sizeAverage then
      self.output = self.output / Pmt:size(1)
   end

   return self.output
end
--- Backward pass: gradient of the loss returned by updateOutput with respect
-- to each entry of inputTable (equation 7 of ref. A).
-- Must be called after updateOutput on the same inputTable, because it reuses
-- the denominators cached there (self._Pomdiv, self._Pondiv).
-- The input tensors are only read; they are never modified.
-- @param inputTable table {Pmt, Pms, Pnt, Pns}, as given to updateOutput
-- @param target not read by this method
-- @return table self.gradInput, four tensors shaped like the four inputs;
--   entries 3 and 4 (the noise-probability inputs) are filled with zeros
function NCECriterion:updateGradInput(inputTable, target)
   self.gradInput = self.gradInput or nn.utils.recursiveNew(inputTable)
   assert(#self.gradInput == 4, "NCECriterion: gradInput must hold 4 tensors")
   assert(self._Pomdiv and self._Pondiv,
      "NCECriterion: updateOutput must be called before updateGradInput")

   local Pmt, Pms, Pnt, Pns = unpack(inputTable)
   local k = Pms:size(2)

   -- equation 7 in ref. A

   -- d loss / d Pmt = -k*Pnt / ( Pmt * (Pmt + k*Pnt) )
   local dlnPom = self.gradInput[1]
   dlnPom:resizeAs(Pnt):copy(Pnt):mul(-k)
   dlnPom:cdiv(self._Pomdiv)
   -- Divide by (Pmt + eps) through a scratch buffer. Adding eps to Pmt in
   -- place and subtracting it afterwards does not restore the exact values in
   -- floating point and would alter the caller's tensor.
   self._Pmteps = self._Pmteps or Pmt.new()
   self._Pmteps:resizeAs(Pmt):copy(Pmt):add(eps)
   dlnPom:cdiv(self._Pmteps)

   -- d loss / d Pms = Pms / ( Pms * (Pms + k*Pns) )
   local dlnPon = self.gradInput[2]
   dlnPon:resizeAs(Pms):copy(Pms)
   dlnPon:cdiv(self._Pondiv)
   -- Same scratch-buffer approach as above, so Pms is left untouched.
   self._Pmseps = self._Pmseps or Pms.new()
   self._Pmseps:resizeAs(Pms):copy(Pms):add(eps)
   dlnPon:cdiv(self._Pmseps)

   -- No gradient is propagated to the noise probabilities. Resize on every
   -- call so that a shape change with an unchanged element count cannot leave
   -- a stale shape behind.
   self.gradInput[3]:resizeAs(Pnt):zero()
   self.gradInput[4]:resizeAs(Pns):zero()

   if self.sizeAverage then
      -- match the averaging applied to the loss in updateOutput
      local n = Pmt:size(1)
      dlnPom:div(n)
      dlnPon:div(n)
   end

   return self.gradInput
end

--- Release every cached intermediate tensor, including the scratch buffers
-- used by updateGradInput, and gradInput itself, then defer to the parent
-- class.
-- @return self
function NCECriterion:clearState()
   self._Pom = nil
   self._Pomdiv = nil
   self._Pon = nil
   self._Pondiv = nil
   self._lnPon = nil
   self._lnPom = nil
   self._Pmteps = nil
   self._Pmseps = nil
   self.gradInput = nil
   parent.clearState(self)
   return self
end
--- Tensor-type conversion hook.
-- Clears all cached buffers (via clearState) before handing the call, with
-- its arguments unchanged, to the parent class; returns whatever the parent
-- implementation returns.
function NCECriterion:type(...)
-- drop the cached tensors first; the other methods re-create them lazily
self:clearState()
return parent.type(self, ...)
end