--[[
encoder-decoder-coupling.lua

Example of coupling two separate networks, an encoder and a decoder, as in
sequence-to-sequence models: on the forward pass the encoder's final hidden
state seeds the decoder's initial state, and on the backward pass the gradient
of the decoder's initial state flows back into the encoder.
]]--
require 'rnn'
local opt = {}
opt.version = 1.5 -- fixed setHiddenState(0) bug
opt.learningRate = 0.1
opt.hiddenSize = 6
opt.numLayers = 1
opt.vocabSize = 7
opt.seqLen = 7 -- length of the encoded sequence (with padding)
opt.niter = 1000
--[[ Forward coupling: Copy encoder cell and output to decoder LSTM ]]--
function forwardConnect(enc, dec)
   for i=1,#enc.lstmLayers do
      dec.lstmLayers[i]:setHiddenState(0, enc.lstmLayers[i]:getHiddenState(opt.seqLen))
   end
end
--[[ Backward coupling: Copy decoder gradients to encoder LSTM ]]--
function backwardConnect(enc, dec)
   for i=1,#enc.lstmLayers do
      enc.lstmLayers[i]:setGradHiddenState(opt.seqLen, dec.lstmLayers[i]:getGradHiddenState(0))
   end
end
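-- Indexing note: step 0 holds a recurrent layer's user-set initial state, and
-- step opt.seqLen holds its state after the last input, so the pair of
-- functions above stitches the encoder's last step to the decoder's first
-- (and routes the corresponding gradient back the other way).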
-- Encoder
local enc = nn.Sequential()
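-- LookupTableMaskZero behaves like nn.LookupTable but maps index 0 to a zero
-- embedding, so the zero-padding contributes nothing to the sequence.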
enc:add(nn.LookupTableMaskZero(opt.vocabSize, opt.hiddenSize))
enc.lstmLayers = {}
for i=1,opt.numLayers do
   enc.lstmLayers[i] = nn.Sequencer(nn.RecLSTM(opt.hiddenSize, opt.hiddenSize):maskZero())
   enc:add(enc.lstmLayers[i])
end
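-- keep only the last time-step of the encoder output
-- (seqlen x batchsize x hiddenSize -> batchsize x hiddenSize)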
enc:add(nn.Select(1, -1))
-- Decoder
local dec = nn.Sequential()
dec:add(nn.LookupTableMaskZero(opt.vocabSize, opt.hiddenSize))
dec.lstmLayers = {}
for i=1,opt.numLayers do
   dec.lstmLayers[i] = nn.Sequencer(nn.RecLSTM(opt.hiddenSize, opt.hiddenSize):maskZero())
   dec:add(dec.lstmLayers[i])
end
dec:add(nn.Sequencer(nn.MaskZero(nn.Linear(opt.hiddenSize, opt.vocabSize), 1)))
dec:add(nn.Sequencer(nn.MaskZero(nn.LogSoftMax(), 1)))
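-- For a seqlen x batchsize input, the decoder outputs a
-- seqlen x batchsize x vocabSize tensor of per-step log-probabilities.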
local criterion = nn.SequencerCriterion(nn.MaskZeroCriterion(nn.ClassNLLCriterion(),1))
-- Some example data (batchsize = 2) with variable-length input and output sequences.
-- The input sentences to the encoder, padded with zeros from the left:
local encInSeq = torch.Tensor({{0,0,0,0,1,2,3},{0,0,0,4,3,2,1}}):t()
-- The input sentences to the decoder, padded with zeros from the right.
-- Label '6' represents the start of a sentence (GO).
local decInSeq = torch.Tensor({{6,1,2,3,4,0,0,0},{6,5,4,3,2,1,0,0}}):t()
-- The expected output from the decoder (one token per time-step),
-- padded with zeros from the right.
-- Label '7' represents the end of a sentence (EOS).
local decOutSeq = torch.Tensor({{1,2,3,4,7,0,0,0},{5,4,3,2,1,7,0,0}}):t()
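-- Note the alignment: each target row is the corresponding decoder input
-- shifted left by one step, with the GO token dropped and EOS (7) appended,
-- so the decoder learns to predict the next token at every step.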
-- The zeroMasks instruct the RNNs to zero their state at time-steps where
-- zeroMask = 1, i.e. at the zero-padded positions.
-- To demonstrate both APIs, each mask is either derived automatically from the
-- input sequence or specified explicitly (chosen at random):
local encZeroMask = math.random() < 0.5 and nn.utils.getZeroMaskSequence(encInSeq) -- auto zeroMask from the input sequence
   or torch.ByteTensor({{1,1,1,1,0,0,0},{1,1,1,0,0,0,0}}):t():contiguous() -- explicit zeroMask
local decZeroMask = math.random() < 0.5 and nn.utils.getZeroMaskSequence(decInSeq)
   or torch.ByteTensor({{0,0,0,0,0,1,1,1},{0,0,0,0,0,0,1,1}}):t():contiguous()
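-- Quick sanity check of the mask semantics (a sketch; safe to delete):
-- getZeroMaskSequence marks exactly the zero-padded positions, so for encInSeq
-- it reproduces the explicit ByteTensor above. Uncomment to verify:
-- assert(nn.utils.getZeroMaskSequence(encInSeq):eq(
--    torch.ByteTensor({{1,1,1,1,0,0,0},{1,1,1,0,0,0,0}}):t():contiguous()):all())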
for i=1,opt.niter do
   enc:zeroGradParameters()
   dec:zeroGradParameters()

   -- Zero-masking
   enc:setZeroMask(encZeroMask)
   dec:setZeroMask(decZeroMask)
   criterion:setZeroMask(decZeroMask)

   -- Forward pass
   local encOut = enc:forward(encInSeq)
   forwardConnect(enc, dec)
   local decOut = dec:forward(decInSeq)
   --print(decOut)
   local err = criterion:forward(decOut, decOutSeq)

   print(string.format("Iteration %d ; NLL err = %f ", i, err))

   -- Backward pass
   local gradOutput = criterion:backward(decOut, decOutSeq)
   dec:backward(decInSeq, gradOutput)
   backwardConnect(enc, dec)
   -- The encoder receives a zero gradOutput: the real gradient was injected by
   -- backwardConnect. Allocate a fresh tensor here; torch.Tensor(encOut) would
   -- share encOut's storage, so zeroing it would also clobber encOut.
   local zeroTensor = torch.zeros(encOut:size())
   enc:backward(encInSeq, zeroTensor)

   dec:updateParameters(opt.learningRate)
   enc:updateParameters(opt.learningRate)
end
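--[[
Inference sketch (an illustration, not part of the original training script):
encode a padded input, couple the states with forwardConnect, then decode
greedily, feeding the decoder its own argmax prediction one step at a time up
to maxLen steps, stopping in practice at EOS (7). The names greedyDecode and
maxLen are made up for this sketch. It assumes the input has exactly
opt.seqLen time-steps (since forwardConnect reads getHiddenState(opt.seqLen))
and relies on nn.Sequencer's remember('both') mode so the decoder LSTMs keep
their state between single-step calls.
]]--
local function greedyDecode(input, maxLen)
   enc:setZeroMask(nn.utils.getZeroMaskSequence(input))
   enc:forward(input)
   for i=1,#dec.lstmLayers do
      dec.lstmLayers[i]:remember('both')
      dec.lstmLayers[i]:forget() -- reset step counters before re-seeding the state
   end
   forwardConnect(enc, dec)
   local batchsize = input:size(2)
   local token = torch.Tensor(1, batchsize):fill(6) -- GO token
   local prediction = {}
   for t=1,maxLen do
      dec:setZeroMask(nn.utils.getZeroMaskSequence(token))
      local logp = dec:forward(token) -- 1 x batchsize x vocabSize
      local _, argmax = logp[1]:max(2) -- greedy: pick the most likely next token
      token = argmax:double():view(1, batchsize)
      table.insert(prediction, token:clone())
   end
   return prediction -- table of 1 x batchsize tensors; read each column up to the first 7 (EOS)
end
-- e.g. local pred = greedyDecode(encInSeq, 8)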