-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathsimple-bisequencer-network.lua
88 lines (65 loc) · 2.42 KB
/
simple-bisequencer-network.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
require 'rnn'
-- hyper-parameters (declared local so the script does not leak globals)
local batchSize = 8   -- number of parallel streams drawn from the dataset per step
local seqlen = 5      -- sequence length (time-steps per training iteration)
local hiddenSize = 7  -- output width of each directional RNN
local nIndex = 10     -- number of distinct item indices (input vocab / output classes)
local lr = 0.1        -- learning rate passed to updateParameters
-- Model: a bidirectional recurrent language model whose per-time-step output
-- is projected to nIndex classes and turned into log-probabilities.

-- Forward-direction RNN: embeds item indices and runs a simple recurrence.
local fwd = nn.LookupRNN(nIndex, hiddenSize)

-- Backward-direction RNN: same architecture as fwd but with its parameters
-- re-initialised; it is applied to the input sequence in reverse order.
local bwd = fwd:clone()
bwd:reset()

-- Combines one time-step of fwd and bwd output. JoinTable concatenates the
-- two, hence the 2 * hiddenSize input width of the projection below.
-- Alternatives worth trying: nn.AddTable(), nn.Identity(), ...
local merge = nn.JoinTable(1, 1)

-- BiSequencerLM (not BiSequencer) because this is a language model: the
-- previous and next words predict the current word. With a plain
-- BiSequencer, x[t] would be used to predict y[t] = x[t], which is cheating.
-- The bwd and merge arguments are optional and default to the choices above.
local brnn = nn.BiSequencerLM(fwd, bwd, merge)

local rnn = nn.Sequential()
rnn:add(brnn)
rnn:add(nn.Sequencer(nn.Linear(2 * hiddenSize, nIndex)))
rnn:add(nn.Sequencer(nn.LogSoftMax()))

print(rnn)
-- build criterion: ClassNLLCriterion applied to every time-step of the output
-- sequence, consuming the log-probabilities produced by the LogSoftMax layer.
local criterion = nn.SequencerCriterion(nn.ClassNLLCriterion())
-- build dummy dataset: one long periodic sequence 1,2,...,nIndex,1,2,...,nIndex,...
-- Every value lies in 1..nIndex, so it is a valid LookupRNN input and a valid
-- ClassNLLCriterion target (a hard-coded period larger than nIndex would not be).
local nRepeat = 100 -- number of periods concatenated into the dataset
local sequence_ = torch.LongTensor():range(1, nIndex) -- 1,2,...,nIndex
local sequence = torch.LongTensor(nRepeat, nIndex)
   :copy(sequence_:view(1, nIndex):expand(nRepeat, nIndex))
sequence:resize(nRepeat * nIndex) -- flatten rows into one long sequence

-- One random starting position per batch stream.
-- math.random(n) returns an integer in [1, n]; the form
-- math.ceil(math.random() * n) can produce 0 (math.random() may return 0),
-- which is not a valid index into `sequence`.
local starts = {}
for i = 1, batchSize do
   starts[i] = math.random(sequence:size(1))
end
local offsets = torch.LongTensor(starts)
-- training
-- The loop below has no stopping criterion: it runs until the process is
-- interrupted, doing one forward/backward pass and one parameter update per
-- iteration.

-- Builds one seqlen-long batch and advances `offsets` in place.
-- inputs[t] holds the items at the current offsets; targets[t] holds the item
-- after each of them (a stream that runs past the end wraps to index 1).
local function nextSequenceBatch()
   local nItems = sequence:size(1)
   local inputs, targets = {}, {}
   for step = 1, seqlen do
      -- a batch of inputs
      inputs[step] = sequence:index(1, offsets)
      -- increment indices; offsets grow by exactly 1, so any that exceed
      -- nItems are exactly one past the end and wrap back to the start
      offsets:add(1)
      for j = 1, batchSize do
         if offsets[j] > nItems then
            offsets[j] = 1
         end
      end
      -- NOTE(review): BiSequencerLM is described as predicting the current
      -- item from its previous and next neighbours, but these targets are the
      -- *next* items -- confirm which target this example intends.
      targets[step] = sequence:index(1, offsets)
   end
   return inputs, targets
end

local iteration = 1
while true do
   -- 1. create a sequence of seqlen time-steps
   local inputs, targets = nextSequenceBatch()

   -- 2. forward sequence through rnn
   rnn:zeroGradParameters()
   local outputs = rnn:forward(inputs)
   local err = criterion:forward(outputs, targets)
   print(string.format("Iteration %d ; NLL err = %f ", iteration, err))

   -- 3. backward sequence through rnn (i.e. backprop through time); the
   -- gradient w.r.t. the inputs that rnn:backward returns is not needed here
   local gradOutputs = criterion:backward(outputs, targets)
   rnn:backward(inputs, gradOutputs)

   -- 4. update
   rnn:updateParameters(lr)
   iteration = iteration + 1
end