-- simple-sequencer-network.lua
require 'rnn'
-- Hyper-parameters. These are plain globals; the rest of the script reads
-- them by name.
batchSize, seqlen = 8, 5    -- sequences per mini-batch, sequence length (time-steps)
nIndex, hiddenSize = 10, 7  -- sizes handed to nn.LookupRNN / nn.Linear
lr = 0.1                    -- learning rate handed to updateParameters
-- Per-time-step network: LookupRNN -> Linear -> LogSoftMax.
local stepModel = nn.Sequential()
stepModel:add(nn.LookupRNN(nIndex, hiddenSize))
stepModel:add(nn.Linear(hiddenSize, nIndex))
stepModel:add(nn.LogSoftMax())

-- Wrap the step network in a Sequencer. Internally, the step network will be
-- wrapped into a Recursor to make it an AbstractRecurrent instance.
local rnn = nn.Sequencer(stepModel)
print(rnn)
-- Build the criterion: a ClassNLLCriterion wrapped in a SequencerCriterion.
-- `criterion` is a global, read by the training loop.
local stepCriterion = nn.ClassNLLCriterion()
criterion = nn.SequencerCriterion(stepCriterion)
-- Build the dummy dataset (the task is to predict the next item, given the
-- previous one).
sequence_ = torch.LongTensor():range(1, 10) -- one period: 1,2,3,4,5,6,7,8,9,10

-- Repeat the period over 100 rows, copy it into a fresh 100x10 tensor, then
-- flatten that into one long sequence of 1,2,3,...,10,1,2,3,...,10,...
local tiledPeriod = sequence_:view(1, 10):expand(100, 10)
sequence = torch.LongTensor(100, 10):copy(tiledPeriod)
sequence:resize(100 * 10)
-- Pick one random starting position inside `sequence` for each of the
-- batchSize rows of a mini-batch. The result is stored in the global
-- `offsets` as a LongTensor, which the training loop advances in place.
local nPositions = sequence:size(1)
local startPositions = {}
for i = 1, batchSize do
  -- math.random(n) returns an integer in [1, n], so every offset is a valid
  -- 1-based index. Scaling math.random() (range [0, 1)) and rounding up could
  -- produce 0, which is not a valid index for sequence:index().
  startPositions[i] = math.random(nPositions)
end
offsets = torch.LongTensor(startPositions)
-- training

-- Moves every batch row one position forward in `sequence`, wrapping a row
-- back to the first element once it runs past the end.
local function advanceOffsets()
  local lastPosition = sequence:size(1)
  offsets:add(1)
  for row = 1, batchSize do
    if offsets[row] > lastPosition then
      offsets[row] = 1
    end
  end
end

-- Builds seqlen time-steps of (input, target) batches. Each target batch is
-- read from the positions one step ahead of the matching input batch.
local function nextTrainingSequence()
  local inputs, targets = {}, {}
  for step = 1, seqlen do
    -- a batch of inputs
    inputs[step] = sequence:index(1, offsets)
    -- increment indices, then read the targets from the new positions
    advanceOffsets()
    targets[step] = sequence:index(1, offsets)
  end
  return inputs, targets
end

-- The loop has no exit condition: it trains until the process is interrupted.
local iteration = 1
while true do
  -- 1. create a sequence of seqlen time-steps
  local inputs, targets = nextTrainingSequence()

  -- 2. forward sequence through rnn
  rnn:zeroGradParameters()
  local outputs = rnn:forward(inputs)
  local err = criterion:forward(outputs, targets)
  print(string.format("Iteration %d ; NLL err = %f ", iteration, err))

  -- 3. backward sequence through rnn (i.e. backprop through time)
  local gradOutputs = criterion:backward(outputs, targets)
  rnn:backward(inputs, gradOutputs) -- the returned input gradients are not used

  -- 4. update
  rnn:updateParameters(lr)

  iteration = iteration + 1
end