|
2 | 2 | "cells": [
|
3 | 3 | {
|
4 | 4 | "cell_type": "markdown",
|
5 |
| - "id": "fc7b28eb-db2b-4a02-87d4-dfc44a9fecc2", |
| 5 | + "id": "0", |
6 | 6 | "metadata": {},
|
7 | 7 | "source": [
|
8 | 8 | "# Visualization of the `brax` Experiment Results\n",
|
|
13 | 13 | {
|
14 | 14 | "cell_type": "code",
|
15 | 15 | "execution_count": null,
|
16 |
| - "id": "d5350a77-c84b-4d05-9e1c-4a6e0e2f6d0d", |
| 16 | + "id": "1", |
17 | 17 | "metadata": {},
|
18 | 18 | "outputs": [],
|
19 | 19 | "source": [
|
|
24 | 24 | {
|
25 | 25 | "cell_type": "code",
|
26 | 26 | "execution_count": null,
|
27 |
| - "id": "e4c4d0df-0480-4c5f-a4ab-d34727286cc1", |
| 27 | + "id": "2", |
28 | 28 | "metadata": {},
|
29 | 29 | "outputs": [],
|
30 | 30 | "source": [
|
|
34 | 34 | },
|
35 | 35 | {
|
36 | 36 | "cell_type": "markdown",
|
37 |
| - "id": "d64f3d04-680a-43c5-a807-dbaf9f917198", |
| 37 | + "id": "3", |
38 | 38 | "metadata": {},
|
39 | 39 | "source": [
|
40 | 40 | "---"
|
|
43 | 43 | {
|
44 | 44 | "cell_type": "code",
|
45 | 45 | "execution_count": null,
|
46 |
| - "id": "7025f8b3-9982-4196-9790-9e8c1491dc06", |
| 46 | + "id": "4", |
47 | 47 | "metadata": {},
|
48 | 48 | "outputs": [],
|
49 | 49 | "source": [
|
|
55 | 55 | {
|
56 | 56 | "cell_type": "code",
|
57 | 57 | "execution_count": null,
|
58 |
| - "id": "d5bccb74-a64a-4e55-9ae9-9686f772fa60", |
| 58 | + "id": "5", |
59 | 59 | "metadata": {},
|
60 | 60 | "outputs": [],
|
61 | 61 | "source": [
|
|
66 | 66 | {
|
67 | 67 | "cell_type": "code",
|
68 | 68 | "execution_count": null,
|
69 |
| - "id": "0d71c7c3-bb36-47fb-88dc-bd285c54c44a", |
| 69 | + "id": "6", |
70 | 70 | "metadata": {},
|
71 | 71 | "outputs": [],
|
72 | 72 | "source": [
|
|
77 | 77 | {
|
78 | 78 | "cell_type": "code",
|
79 | 79 | "execution_count": null,
|
80 |
| - "id": "54c7bee4-bf5c-445c-b009-40484391461c", |
| 80 | + "id": "7", |
81 | 81 | "metadata": {},
|
82 | 82 | "outputs": [],
|
83 | 83 | "source": [
|
|
89 | 89 | {
|
90 | 90 | "cell_type": "code",
|
91 | 91 | "execution_count": null,
|
92 |
| - "id": "30eca014-be71-4162-b675-0eaf8b037bce", |
| 92 | + "id": "8", |
93 | 93 | "metadata": {},
|
94 | 94 | "outputs": [],
|
95 | 95 | "source": [
|
|
100 | 100 | },
|
101 | 101 | {
|
102 | 102 | "cell_type": "markdown",
|
103 |
| - "id": "78ce1f2f-cb89-432e-ad98-07f7e1d5d457", |
| 103 | + "id": "9", |
104 | 104 | "metadata": {},
|
105 | 105 | "source": [
|
106 | 106 | "Below, we put the values of the center solution into the policy network as parameters:"
|
|
109 | 109 | {
|
110 | 110 | "cell_type": "code",
|
111 | 111 | "execution_count": null,
|
112 |
| - "id": "545d693d-ccd9-4ea7-b068-06b832f32be5", |
| 112 | + "id": "10", |
113 | 113 | "metadata": {},
|
114 | 114 | "outputs": [],
|
115 | 115 | "source": [
|
|
118 | 118 | },
|
119 | 119 | {
|
120 | 120 | "cell_type": "markdown",
|
121 |
| - "id": "e7c7581d-24ba-4f06-88e2-41bb3274cd37", |
| 121 | + "id": "11", |
122 | 122 | "metadata": {},
|
123 | 123 | "source": [
|
124 | 124 | "---\n",
|
|
131 | 131 | {
|
132 | 132 | "cell_type": "code",
|
133 | 133 | "execution_count": null,
|
134 |
| - "id": "8d798a71-003f-4bed-9f12-391dc7b797c0", |
| 134 | + "id": "12", |
135 | 135 | "metadata": {},
|
136 | 136 | "outputs": [],
|
137 | 137 | "source": [
|
|
143 | 143 | {
|
144 | 144 | "cell_type": "code",
|
145 | 145 | "execution_count": null,
|
146 |
| - "id": "30080d4b-227c-4bde-8009-7a97c6e177f1", |
| 146 | + "id": "13", |
147 | 147 | "metadata": {},
|
148 | 148 | "outputs": [],
|
149 | 149 | "source": [
|
|
153 | 153 | {
|
154 | 154 | "cell_type": "code",
|
155 | 155 | "execution_count": null,
|
156 |
| - "id": "c64a6482-0533-40c8-b8f9-1b1ad3ee0019", |
| 156 | + "id": "14", |
157 | 157 | "metadata": {},
|
158 | 158 | "outputs": [],
|
159 | 159 | "source": [
|
|
163 | 163 | {
|
164 | 164 | "cell_type": "code",
|
165 | 165 | "execution_count": null,
|
166 |
| - "id": "ee69ba99-bcb7-4d21-9f5c-6c330cb3ed96", |
| 166 | + "id": "15", |
167 | 167 | "metadata": {},
|
168 | 168 | "outputs": [],
|
169 | 169 | "source": [
|
|
183 | 183 | {
|
184 | 184 | "cell_type": "code",
|
185 | 185 | "execution_count": null,
|
186 |
| - "id": "c99f0ea4-7638-4e8a-a0a2-d9f483c8c096", |
| 186 | + "id": "16", |
187 | 187 | "metadata": {},
|
188 | 188 | "outputs": [],
|
189 | 189 | "source": [
|
|
194 | 194 | },
|
195 | 195 | {
|
196 | 196 | "cell_type": "markdown",
|
197 |
| - "id": "19d38fe1-9d6f-40ec-926d-3c4066a4b66a", |
| 197 | + "id": "17", |
198 | 198 | "metadata": {},
|
199 | 199 | "source": [
|
200 | 200 | "Below, we define a utility function named `use_policy(...)`.\n",
|
|
214 | 214 | {
|
215 | 215 | "cell_type": "code",
|
216 | 216 | "execution_count": null,
|
217 |
| - "id": "e8c6554d-4bdb-49b9-8c2e-956bce6ddb8a", |
| 217 | + "id": "18", |
218 | 218 | "metadata": {},
|
219 | 219 | "outputs": [],
|
220 | 220 | "source": [
|
|
238 | 238 | },
|
239 | 239 | {
|
240 | 240 | "cell_type": "markdown",
|
241 |
| - "id": "28c044fe-4639-411c-be78-ac09cbe5e78f", |
| 241 | + "id": "19", |
242 | 242 | "metadata": {},
|
243 | 243 | "source": [
|
244 | 244 | "We now initialize a new instance of our brax environment, and trigger the jit compilation on its `reset` and `step` methods."
|
|
247 | 247 | {
|
248 | 248 | "cell_type": "code",
|
249 | 249 | "execution_count": null,
|
250 |
| - "id": "c1251e82-1d0f-4e43-a6c5-1ec4c0275dab", |
| 250 | + "id": "20", |
251 | 251 | "metadata": {},
|
252 | 252 | "outputs": [],
|
253 | 253 | "source": [
|
|
258 | 258 | },
|
259 | 259 | {
|
260 | 260 | "cell_type": "markdown",
|
261 |
| - "id": "55229cc2-aad5-4c78-b095-010a538adb40", |
| 261 | + "id": "21", |
262 | 262 | "metadata": {},
|
263 | 263 | "source": [
|
264 | 264 | "Below we run our policy and collect the states of the episodes."
|
|
267 | 267 | {
|
268 | 268 | "cell_type": "code",
|
269 | 269 | "execution_count": null,
|
270 |
| - "id": "220ec837-1d1a-401e-a144-5516bdb3e493", |
| 270 | + "id": "22", |
271 | 271 | "metadata": {},
|
272 | 272 | "outputs": [],
|
273 | 273 | "source": [
|
|
291 | 291 | },
|
292 | 292 | {
|
293 | 293 | "cell_type": "markdown",
|
294 |
| - "id": "c9bbe53a-40f5-4bf5-a063-47f1d88032d6", |
| 294 | + "id": "23", |
295 | 295 | "metadata": {},
|
296 | 296 | "source": [
|
297 | 297 | "Length of the episode and the total reward:"
|
|
300 | 300 | {
|
301 | 301 | "cell_type": "code",
|
302 | 302 | "execution_count": null,
|
303 |
| - "id": "80345e9c-0694-413d-81c2-205dfacdfb5a", |
| 303 | + "id": "24", |
304 | 304 | "metadata": {},
|
305 | 305 | "outputs": [],
|
306 | 306 | "source": [
|
|
310 | 310 | {
|
311 | 311 | "cell_type": "code",
|
312 | 312 | "execution_count": null,
|
313 |
| - "id": "8d2e2506-5bb2-4b1f-bb00-004465008fa3", |
| 313 | + "id": "25", |
314 | 314 | "metadata": {},
|
315 | 315 | "outputs": [],
|
316 | 316 | "source": [
|
|
325 | 325 | },
|
326 | 326 | {
|
327 | 327 | "cell_type": "markdown",
|
328 |
| - "id": "8b6ec424-e6cc-453a-93fd-eb07e44c1bd6", |
| 328 | + "id": "26", |
329 | 329 | "metadata": {},
|
330 | 330 | "source": [
|
331 | 331 | "Visualization of the policy:"
|
|
334 | 334 | {
|
335 | 335 | "cell_type": "code",
|
336 | 336 | "execution_count": null,
|
337 |
| - "id": "29b9c3aa-8068-40bc-9ec2-e5ff759ed22c", |
| 337 | + "id": "27", |
338 | 338 | "metadata": {},
|
339 | 339 | "outputs": [],
|
340 | 340 | "source": [
|
|
349 | 349 | {
|
350 | 350 | "cell_type": "code",
|
351 | 351 | "execution_count": null,
|
352 |
| - "id": "b50e36a7-dd33-4fc3-9f9d-2d7dca93110b", |
| 352 | + "id": "28", |
353 | 353 | "metadata": {},
|
354 | 354 | "outputs": [],
|
355 | 355 | "source": [
|
|
0 commit comments