-
Notifications
You must be signed in to change notification settings - Fork 0
/
runner.jl
121 lines (106 loc) · 2.11 KB
/
runner.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
include("model.jl")
using Knet: Param, @diff, value, grad
# Hyperparameters of the toy training run.
const seq_len = 10      # number of timesteps in each generated sequence
const hm_data = 20      # number of sequences in the dataset
const hm_epochs = 10    # number of passes `runner` makes over the dataset
const lr = 0.01         # step size applied to the gradients in `runner`

# Network to train; `make` is not defined in this file (presumably in model.jl).
model = make()

# Synthetic dataset: `hm_data` sequences, each a vector of `seq_len` random
# 1×in_size matrices drawn with `randn`. `in_size` is not defined in this file
# (presumably it comes from model.jl).
data = map(1:hm_data) do _
    map(_ -> randn(1, in_size), 1:seq_len)
end
"""
    runner()

Train the global `model` on the global `data` for `hm_epochs` epochs and print
the epoch index together with the loss summed over all sequences after each epoch.

For every sequence, all elements but the last are passed to `prop` and the
response is scored by `loss` against the same sequence shifted by one element
(`seq[2:end]`). After each sequence, every field of every layer is replaced by a
new `Param` holding `field - gradient * lr`. Fields for which no gradient is
available are left untouched.

Returns `nothing`.
"""
function runner()
    for i in 1:hm_epochs
        # Start from a float so the accumulator does not change type on the first `+=`.
        l = 0.0
        for seq in data
            input = seq[1:end-1]
            label = seq[2:end]
            result = @diff begin
                response, state, memory = prop(model, input)
                loss(response, label)
            end
            l += value(result)
            for layer in model
                for param in fieldnames(Layer)
                    weights = getfield(layer, param)
                    g = grad(result, weights)
                    # `grad` yields `nothing` for a zero gradient; `nothing * lr`
                    # would throw, so keep the current value of this field instead.
                    g === nothing && continue
                    setfield!(layer, param, Param(weights - g * lr))
                end
            end
        end
        @show i, l
    end
    return nothing
end

# Run the training loop as soon as the file is loaded.
runner()
# Initial test code: an earlier single-layer prototype, kept commented out for reference.
# in_size = 10
# l_size = 10
#
# seq_len = 10
# hm_data = 20
#
# lstm = Layer(in_size,l_size)
#
# state = zeros(1,l_size)
#
# memory = zeros(1,memory_size)
#
#
# data = [[randn(1,in_size) for _ in 1:seq_len] for __ in 1:hm_data]
#
#
#
# main(model, state, memory, data) =
# begin
#
# for datapoint in data
#
# in_data = datapoint[1:end-1]
# out_data = datapoint[2:end]
#
#
# g = @diff begin
#
# outs = []
# for timestep in in_data
#
# out, state, memory = lstm(timestep, state, memory)
# push!(outs, out)
#
# end
#
# sum(sum([(e1-e2).^2 for (e1,e2) in zip(outs, out_data)]))
# end
#
# # for param in params(lstm)
# # param -= .01 .* grad(g, param)
# # end
#
# for field in fieldnames(Layer)
# setfield!(model, field, Param(getfield(model, field) - .01 .* grad(g, getfield(model, field))))
# end
#
# end
#
# end
#
#
# test(model, state, memory, data) =
#
# for timestep in data[1]
#
# out, state, memory = lstm(timestep, state, memory)
#
# # @show out
# # @show state
# @show memory
# println(" ")
#
# end
#
#
#
# test(lstm, state, memory, data)
#
#
# println("---")
#
# main(lstm, state, memory, data)
#
# println("---")
#
#
# test(lstm, state, memory, data)