Skip to content

Commit 75cd651

Browse files
author
pblouw
committed
Updates to tutorial
1 parent 2bf1546 commit 75cd651

5 files changed

+52
-67
lines changed

Reinforcement Learning in Nengo.ipynb

+14-14
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@
8484
},
8585
{
8686
"cell_type": "code",
87-
"execution_count": 1,
87+
"execution_count": 3,
8888
"metadata": {
8989
"slideshow": {
9090
"slide_type": "subslide"
@@ -134,7 +134,7 @@
134134
},
135135
{
136136
"cell_type": "code",
137-
"execution_count": 2,
137+
"execution_count": 4,
138138
"metadata": {},
139139
"outputs": [],
140140
"source": [
@@ -153,11 +153,11 @@
153153
},
154154
{
155155
"cell_type": "code",
156-
"execution_count": 3,
156+
"execution_count": 5,
157157
"metadata": {},
158158
"outputs": [],
159159
"source": [
160-
"with nengo.Network() as model:\n",
160+
"with nengo.Network(seed=2) as model:\n",
161161
" env = td_grid.GridNode(environment, dt=0.001)\n",
162162
" \n",
163163
" # define nodes for plotting data, managing agent's interface with environment\n",
@@ -171,7 +171,7 @@
171171
" intercepts=nengo.dists.Choice([0.15]), radius=2)\n",
172172
" \n",
173173
" # define neurons that compute the learning signal\n",
174-
" learn_signal = nengo.Ensemble(n_neurons=1000, dimensions=4, neuron_type=nengo.LIF())\n",
174+
" learn_signal = nengo.Ensemble(n_neurons=1000, dimensions=4)\n",
175175
" \n",
176176
" # connect the sensor to state ensemble\n",
177177
" nengo.Connection(sensor_node, state, synapse=None)\n",
@@ -197,7 +197,7 @@
197197
},
198198
{
199199
"cell_type": "code",
200-
"execution_count": 4,
200+
"execution_count": 6,
201201
"metadata": {
202202
"slideshow": {
203203
"slide_type": "subslide"
@@ -208,9 +208,9 @@
208208
"data": {
209209
"text/html": [
210210
"\n",
211-
" <div id=\"0f6d33cf-7e1f-457b-9444-f1e3f778dd12\">\n",
211+
" <div id=\"70bbc289-1bbc-4cfc-bd1f-4bc3db75e00f\">\n",
212212
" <iframe\n",
213-
" src=\"http://localhost:60833/?token=90d6ab7ca7678f0607c082d335727f5c53810a699d4bcc03\"\n",
213+
" src=\"http://localhost:53191/?token=8eb3f65da8a76e6cf31820ef8613cda506f3eeefb3632442\"\n",
214214
" width=\"100%\"\n",
215215
" height=\"600\"\n",
216216
" frameborder=\"0\"\n",
@@ -255,7 +255,7 @@
255255
},
256256
{
257257
"cell_type": "code",
258-
"execution_count": 5,
258+
"execution_count": 7,
259259
"metadata": {},
260260
"outputs": [],
261261
"source": [
@@ -280,7 +280,7 @@
280280
},
281281
{
282282
"cell_type": "code",
283-
"execution_count": 6,
283+
"execution_count": 8,
284284
"metadata": {},
285285
"outputs": [],
286286
"source": [
@@ -318,7 +318,7 @@
318318
},
319319
{
320320
"cell_type": "code",
321-
"execution_count": 7,
321+
"execution_count": 9,
322322
"metadata": {
323323
"slideshow": {
324324
"slide_type": "subslide"
@@ -365,7 +365,7 @@
365365
},
366366
{
367367
"cell_type": "code",
368-
"execution_count": 8,
368+
"execution_count": 10,
369369
"metadata": {
370370
"slideshow": {
371371
"slide_type": "subslide"
@@ -376,9 +376,9 @@
376376
"data": {
377377
"text/html": [
378378
"\n",
379-
" <div id=\"392796a8-b128-421f-920a-a6ecb6efdadb\">\n",
379+
" <div id=\"bf33c426-004c-4dac-a9bc-93ae7b00137b\">\n",
380380
" <iframe\n",
381-
" src=\"http://localhost:61025/?token=441e808a24bcd81a6497f66689838f3785b6ead1e0d41334\"\n",
381+
" src=\"http://localhost:53264/?token=3cd1018bbdcfa030679f51c0bde3a6a54ac1001f9101ec8c\"\n",
382382
" width=\"100%\"\n",
383383
" height=\"600\"\n",
384384
" frameborder=\"0\"\n",

configs/default.py.cfg

+7-7
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ _viz_config[_viz_1].height = 0.40030392521829083
2222
_viz_config[_viz_1].label_visible = True
2323
_viz_2 = nengo_gui.components.Raster(learn_signal)
2424
_viz_config[_viz_2].n_neurons = 100
25-
_viz_config[_viz_2].x = 0.9562372057490022
26-
_viz_config[_viz_2].y = 0.261286048718174
25+
_viz_config[_viz_2].x = 0.9151104251757648
26+
_viz_config[_viz_2].y = 0.22700986118226302
2727
_viz_config[_viz_2].width = 0.2455550152884971
2828
_viz_config[_viz_2].height = 0.35734453744254163
2929
_viz_config[_viz_2].label_visible = True
@@ -45,8 +45,8 @@ _viz_config[_viz_4].width = 0.555808142243912
4545
_viz_config[_viz_4].height = 0.5104425418379376
4646
_viz_config[_viz_4].label_visible = True
4747
_viz_8 = nengo_gui.components.SpikeGrid(state)
48-
_viz_config[_viz_8].x = 1.5088625183590774
49-
_viz_config[_viz_8].y = 0.18595803714975423
48+
_viz_config[_viz_8].x = 1.4912367552562615
49+
_viz_config[_viz_8].y = 0.15168184961384323
5050
_viz_config[_viz_8].width = 0.2552559706112676
5151
_viz_config[_viz_8].height = 0.3554230847816174
5252
_viz_config[_viz_8].label_visible = True
@@ -64,13 +64,13 @@ _viz_config[env].pos=(2.0688976229564138, 0.3195690695620929)
6464
_viz_config[env].size=(0.1, 0.1)
6565
_viz_config[learn_signal].pos=(0.8142521273436952, 0.8820053350648869)
6666
_viz_config[learn_signal].size=(0.1, 0.1)
67-
_viz_config[model].pos=(1.0815485124596864, 1.1291685346823868)
68-
_viz_config[model].size=(0.3459459281303725, 0.3459459281303725)
67+
_viz_config[model].pos=(1.3653844099521146, 1.336242672340802)
68+
_viz_config[model].size=(0.30168332401368336, 0.30168332401368336)
6969
_viz_config[model].expanded=True
7070
_viz_config[model].has_layout=True
7171
_viz_config[qvalue_node].pos=(1.5219768231749249, 0.9078091928089596)
7272
_viz_config[qvalue_node].size=(0.1, 0.1)
73-
_viz_config[reward_node].pos=(-0.8019821637475498, 1.8231801626252788)
73+
_viz_config[reward_node].pos=(-0.5682603312530061, 2.485424050468169)
7474
_viz_config[reward_node].size=(0.1, 0.1)
7575
_viz_config[sensor_node].pos=(0.8029007404374915, 1.336993998970692)
7676
_viz_config[sensor_node].size=(0.1, 0.12514175766245555)

configs/learning6-value.py.cfg

+30-41
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,30 @@
11
_viz_0 = nengo_gui.components.HTMLView(env)
2-
_viz_config[_viz_0].x = 0.779824356063676
3-
_viz_config[_viz_0].y = 0.2994841745479973
4-
_viz_config[_viz_0].width = 0.13850415512465375
5-
_viz_config[_viz_0].height = 0.18083182640144665
2+
_viz_config[_viz_0].x = 0.8774793077580895
3+
_viz_config[_viz_0].y = 0.42583851046818516
4+
_viz_config[_viz_0].width = 0.18482451594727958
5+
_viz_config[_viz_0].height = 0.2415716648986474
66
_viz_config[_viz_0].label_visible = True
77
_viz_1 = nengo_gui.components.XYValue(state)
88
_viz_config[_viz_1].max_value = 1
99
_viz_config[_viz_1].min_value = -1
1010
_viz_config[_viz_1].index_x = 0
1111
_viz_config[_viz_1].index_y = 1
12-
_viz_config[_viz_1].x = 0.10813922988152525
13-
_viz_config[_viz_1].y = 0.7138886698485432
14-
_viz_config[_viz_1].width = 0.13850415512465375
15-
_viz_config[_viz_1].height = 0.18083182640144665
12+
_viz_config[_viz_1].x = 0.02613411018942154
13+
_viz_config[_viz_1].y = 0.7770658378086368
14+
_viz_config[_viz_1].width = 0.1679226973847849
15+
_viz_config[_viz_1].height = 0.3047488328587414
1616
_viz_config[_viz_1].label_visible = True
17-
_viz_2 = nengo_gui.components.Value(reward)
18-
_viz_config[_viz_2].max_value = 10
19-
_viz_config[_viz_2].min_value = -10
17+
_viz_2 = nengo_gui.components.Value(value)
18+
_viz_config[_viz_2].max_value = 3
19+
_viz_config[_viz_2].min_value = -3
2020
_viz_config[_viz_2].show_legend = False
2121
_viz_config[_viz_2].legend_labels = ['label_0']
2222
_viz_config[_viz_2].synapse = 0.01
23-
_viz_config[_viz_2].x = 0.3803227744877454
24-
_viz_config[_viz_2].y = 0.7064717582566888
25-
_viz_config[_viz_2].width = 0.10964955033075563
26-
_viz_config[_viz_2].height = 0.16543033364861992
23+
_viz_config[_viz_2].x = 0.3151722259100631
24+
_viz_config[_viz_2].y = 0.8123759102576043
25+
_viz_config[_viz_2].width = 0.12519865601847896
26+
_viz_config[_viz_2].height = 0.24298910753882308
2727
_viz_config[_viz_2].label_visible = True
28-
_viz_3 = nengo_gui.components.Value(value)
29-
_viz_config[_viz_3].max_value = 1.5
30-
_viz_config[_viz_3].min_value = -1.5
31-
_viz_config[_viz_3].show_legend = False
32-
_viz_config[_viz_3].legend_labels = ['label_0']
33-
_viz_config[_viz_3].synapse = 0.01
34-
_viz_config[_viz_3].x = 0.8016667532806668
35-
_viz_config[_viz_3].y = 0.7106255243594504
36-
_viz_config[_viz_3].width = 0.13850415512465375
37-
_viz_config[_viz_3].height = 0.18083182640144665
38-
_viz_config[_viz_3].label_visible = True
3928
_viz_net_graph = nengo_gui.components.NetGraph()
4029
_viz_progress = nengo_gui.components.Progress()
4130
_viz_config[_viz_progress].x = 0
@@ -48,21 +37,21 @@ _viz_config[_viz_sim_control].shown_time = 0.5
4837
_viz_config[_viz_sim_control].kept_time = 4.0
4938
_viz_config[env].pos=(0.49863439098327234, 1.2535465790590687)
5039
_viz_config[env].size=(0.29494106379918655, 0.05045044512354506)
51-
_viz_config[model].pos=(0.0634570655505291, -0.08615261133288868)
52-
_viz_config[model].size=(0.8954611734951313, 0.8954611734951313)
40+
_viz_config[model].pos=(0.19267060726875407, 0.016160128715177136)
41+
_viz_config[model].size=(0.7858385976587554, 0.7858385976587554)
5342
_viz_config[model].expanded=True
5443
_viz_config[model].has_layout=True
55-
_viz_config[movement].pos=(0.8255734843067148, 1.0690267739030497)
56-
_viz_config[movement].size=(0.075181055478224, 0.050450445123545046)
57-
_viz_config[position].pos=(0.12604643094326723, 0.19045870165051)
58-
_viz_config[position].size=(0.0723440345167816, 0.05045044512354506)
59-
_viz_config[radar].pos=(0.4843601683374062, 1.071451576862655)
60-
_viz_config[radar].size=(0.09397631934778002, 0.06306305640443131)
61-
_viz_config[reward].pos=(0.3215221418227118, 0.4370894765073292)
62-
_viz_config[reward].size=(0.027346424993232085, 0.040376746672174134)
63-
_viz_config[state].pos=(0.3771058783570984, 0.1908731433935878)
44+
_viz_config[movement].pos=(0.15548512716379476, 0.24996069309428318)
45+
_viz_config[movement].size=(0.1, 0.1)
46+
_viz_config[position].pos=(0.5752419773806918, 1.0152372365661286)
47+
_viz_config[position].size=(0.1, 0.1)
48+
_viz_config[radar].pos=(0.1168014893583144, 0.37145651012735564)
49+
_viz_config[radar].size=(0.1, 0.1)
50+
_viz_config[reward].pos=(0.5452335256050307, 0.28120210217205893)
51+
_viz_config[reward].size=(0.1, 0.1)
52+
_viz_config[state].pos=(0.4759960334248104, 0.7624876369995743)
6453
_viz_config[state].size=(0.090430043145977, 0.06306305640443133)
65-
_viz_config[stim_radar].pos=(0.08183101895225513, 1.073876379822259)
66-
_viz_config[stim_radar].size=(0.075181055478224, 0.04560083920433565)
67-
_viz_config[value].pos=(0.49391319325690364, 0.437965664396405)
68-
_viz_config[value].size=(0.090430043145977, 0.06306305640443133)
54+
_viz_config[stim_radar].pos=(-0.05698205482833595, 0.3237229331568389)
55+
_viz_config[stim_radar].size=(0.1, 0.1)
56+
_viz_config[value].pos=(0.30774080340442267, 0.31389243854735405)
57+
_viz_config[value].size=(0.1, 0.1)

td_grid.py

-4
Original file line numberDiff line numberDiff line change
@@ -533,9 +533,7 @@ def take_action(self, action_idx, epsilon=0.1):
533533
action_idx = np.random.choice(np.arange(4))
534534

535535
x_pos, y_pos = self.compute_position(action_idx)
536-
print('TYPE')
537536
self.agent.set_position(x_pos, y_pos)
538-
print('TEST')
539537

540538
return action_idx
541539

@@ -565,9 +563,7 @@ def step(self, t, x):
565563

566564
qs = self.output[8:]
567565
idx = np.argmax(qs)
568-
print('TEST')
569566
self.current_action_index = self.take_action(idx)
570-
print('PASSED')
571567

572568
# then on next step store new qvalues
573569
elif int(t * 1000) % self.stepsize == 1:

utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,4 @@ def step(self, t, x):
102102
self.output = np.concatenate(
103103
(c_output, f_output, qvalues))
104104

105-
return self.output
105+
return self.output

0 commit comments

Comments
 (0)