|
2 | 2 | "cells": [
|
3 | 3 | {
|
4 | 4 | "cell_type": "markdown",
|
5 |   | - "id": "096c868b-ef91-42df-903f-a0a046ccad95",
  | 5 | + "id": "0",
6 | 6 | "metadata": {},
|
7 | 7 | "source": [
|
8 | 8 | "# Solving a Brax environment using EvoTorch\n",
|
|
27 | 27 | {
|
28 | 28 | "cell_type": "code",
|
29 | 29 | "execution_count": null,
|
30 |   | - "id": "c786f423-220e-4c86-8e06-05edbfd45c42",
  | 30 | + "id": "1",
31 | 31 | "metadata": {},
|
32 | 32 | "outputs": [],
|
33 | 33 | "source": [
|
|
42 | 42 | },
|
43 | 43 | {
|
44 | 44 | "cell_type": "markdown",
|
45 |   | - "id": "8b547977-1478-4b82-95c4-12a79769336d",
  | 45 | + "id": "2",
46 | 46 | "metadata": {},
|
47 | 47 | "source": [
|
48 | 48 | "We now check if CUDA is available. If it is, we prepare a configuration which will tell `VecGymNE` to use a single GPU both for the population and for the fitness evaluation operations. If CUDA is not available, we will instead turn to actor-based parallelization on the CPU to boost the performance."
|
|
51 | 51 | {
|
52 | 52 | "cell_type": "code",
|
53 | 53 | "execution_count": null,
|
54 |   | - "id": "d45edbf5-b0e9-43a0-869a-5146321ddd4e",
  | 54 | + "id": "3",
55 | 55 | "metadata": {},
|
56 | 56 | "outputs": [],
|
57 | 57 | "source": [
|
|
90 | 90 | },
|
91 | 91 | {
|
92 | 92 | "cell_type": "markdown",
|
93 |   | - "id": "f1387654-13b4-4367-9d3e-ffda6b3aaf5d",
  | 93 | + "id": "4",
94 | 94 | "metadata": {},
|
95 | 95 | "source": [
|
96 | 96 | "We now define our policy. The policy can be expressed as a string, as an instance of `torch.nn.Module`, or as a subclass of `torch.nn.Module`."
|
|
99 | 99 | {
|
100 | 100 | "cell_type": "code",
|
101 | 101 | "execution_count": null,
|
102 |   | - "id": "4d62b5e2-6471-4c86-b028-7997426270f8",
  | 102 | + "id": "5",
103 | 103 | "metadata": {},
|
104 | 104 | "outputs": [],
|
105 | 105 | "source": [
|
|
162 | 162 | },
|
163 | 163 | {
|
164 | 164 | "cell_type": "markdown",
|
165 |   | - "id": "166e7f59-b727-44b4-9f60-dc7cbec9943b",
  | 165 | + "id": "6",
166 | 166 | "metadata": {},
|
167 | 167 | "source": [
|
168 | 168 | "Below, we instantiate our `VecGymNE` problem."
|
|
171 | 171 | {
|
172 | 172 | "cell_type": "code",
|
173 | 173 | "execution_count": null,
|
174 |   | - "id": "56b5f554-cdec-40a5-8e21-fa09ea53e047",
  | 174 | + "id": "7",
175 | 175 | "metadata": {},
|
176 | 176 | "outputs": [],
|
177 | 177 | "source": [
|
|
207 | 207 | },
|
208 | 208 | {
|
209 | 209 | "cell_type": "markdown",
|
210 |   | - "id": "bce02d7c-400c-4c22-9bb8-aa70fa4b1da2",
  | 210 | + "id": "8",
211 | 211 | "metadata": {},
|
212 | 212 | "source": [
|
213 | 213 | "---\n",
|
|
220 | 220 | },
|
221 | 221 | {
|
222 | 222 | "cell_type": "markdown",
|
223 |   | - "id": "95417793-3835-47b1-b10a-7f36e78fa3ad",
  | 223 | + "id": "9",
224 | 224 | "metadata": {},
|
225 | 225 | "source": [
|
226 | 226 | "Initialize a PGPE to work on the problem.\n",
|
|
233 | 233 | {
|
234 | 234 | "cell_type": "code",
|
235 | 235 | "execution_count": null,
|
236 |   | - "id": "bce9f851-68aa-4e67-9dbb-2474a5ebd4cf",
  | 236 | + "id": "10",
237 | 237 | "metadata": {},
|
238 | 238 | "outputs": [],
|
239 | 239 | "source": [
|
|
257 | 257 | },
|
258 | 258 | {
|
259 | 259 | "cell_type": "markdown",
|
260 |   | - "id": "da60f156-6756-41a4-b261-82ee62d7f7cb",
  | 260 | + "id": "11",
261 | 261 | "metadata": {},
|
262 | 262 | "source": [
|
263 | 263 | "We register two loggers for our PGPE instance.\n",
|
|
269 | 269 | {
|
270 | 270 | "cell_type": "code",
|
271 | 271 | "execution_count": null,
|
272 |   | - "id": "91270ba2-ce78-43e7-bf01-20c94b0529c3",
  | 272 | + "id": "12",
273 | 273 | "metadata": {},
|
274 | 274 | "outputs": [],
|
275 | 275 | "source": [
|
|
279 | 279 | },
|
280 | 280 | {
|
281 | 281 | "cell_type": "markdown",
|
282 |   | - "id": "7b4d16c1-078c-4d7d-bd6f-e3ff28173667",
  | 282 | + "id": "13",
283 | 283 | "metadata": {},
|
284 | 284 | "source": [
|
285 | 285 | "We are now ready to start the evolutionary search."
|
|
288 | 288 | {
|
289 | 289 | "cell_type": "code",
|
290 | 290 | "execution_count": null,
|
291 |   | - "id": "0a1a84a0-47ea-4592-bd37-5e96fc8f6e54",
  | 291 | + "id": "14",
292 | 292 | "metadata": {},
|
293 | 293 | "outputs": [],
|
294 | 294 | "source": [
|
|
298 | 298 | {
|
299 | 299 | "cell_type": "code",
|
300 | 300 | "execution_count": null,
|
301 |   | - "id": "8491f968-4f43-4df6-aac0-a09c756185da",
  | 301 | + "id": "15",
302 | 302 | "metadata": {},
|
303 | 303 | "outputs": [],
|
304 | 304 | "source": [
|
|
310 | 310 | {
|
311 | 311 | "cell_type": "code",
|
312 | 312 | "execution_count": null,
|
313 |   | - "id": "5731e007-55b3-49ef-9285-9d9137232c9d",
  | 313 | + "id": "16",
314 | 314 | "metadata": {},
|
315 | 315 | "outputs": [],
|
316 | 316 | "source": [
|
|
319 | 319 | },
|
320 | 320 | {
|
321 | 321 | "cell_type": "markdown",
|
322 |   | - "id": "0efa9df1-c978-4c2a-a528-c98e761caec7",
  | 322 | + "id": "17",
323 | 323 | "metadata": {},
|
324 | 324 | "source": [
|
325 | 325 | "Now, we receive our trained policy as a torch module."
|
|
328 | 328 | {
|
329 | 329 | "cell_type": "code",
|
330 | 330 | "execution_count": null,
|
331 |   | - "id": "8ea4d211-08c2-4a59-ab84-23988646895e",
  | 331 | + "id": "18",
332 | 332 | "metadata": {},
|
333 | 333 | "outputs": [],
|
334 | 334 | "source": [
|
|
339 | 339 | },
|
340 | 340 | {
|
341 | 341 | "cell_type": "markdown",
|
342 |   | - "id": "e7c7581d-24ba-4f06-88e2-41bb3274cd37",
  | 342 | + "id": "19",
343 | 343 | "metadata": {},
|
344 | 344 | "source": [
|
345 | 345 | "---\n",
|
|
352 | 352 | {
|
353 | 353 | "cell_type": "code",
|
354 | 354 | "execution_count": null,
|
355 |   | - "id": "c99f0ea4-7638-4e8a-a0a2-d9f483c8c096",
  | 355 | + "id": "20",
356 | 356 | "metadata": {},
|
357 | 357 | "outputs": [],
|
358 | 358 | "source": [
|
|
386 | 386 | },
|
387 | 387 | {
|
388 | 388 | "cell_type": "markdown",
|
389 |   | - "id": "19d38fe1-9d6f-40ec-926d-3c4066a4b66a",
  | 389 | + "id": "21",
390 | 390 | "metadata": {},
|
391 | 391 | "source": [
|
392 | 392 | "Below, we define a utility function named `use_policy(...)`.\n",
|
|
406 | 406 | {
|
407 | 407 | "cell_type": "code",
|
408 | 408 | "execution_count": null,
|
409 |   | - "id": "e8c6554d-4bdb-49b9-8c2e-956bce6ddb8a",
  | 409 | + "id": "22",
410 | 410 | "metadata": {},
|
411 | 411 | "outputs": [],
|
412 | 412 | "source": [
|
|
430 | 430 | },
|
431 | 431 | {
|
432 | 432 | "cell_type": "markdown",
|
433 |   | - "id": "28c044fe-4639-411c-be78-ac09cbe5e78f",
  | 433 | + "id": "23",
434 | 434 | "metadata": {},
|
435 | 435 | "source": [
|
436 | 436 | "We now initialize a new instance of our Brax environment, and trigger the JIT compilation of its `reset` and `step` methods."
|
|
439 | 439 | {
|
440 | 440 | "cell_type": "code",
|
441 | 441 | "execution_count": null,
|
442 |   | - "id": "c1251e82-1d0f-4e43-a6c5-1ec4c0275dab",
  | 442 | + "id": "24",
443 | 443 | "metadata": {},
|
444 | 444 | "outputs": [],
|
445 | 445 | "source": [
|
|
454 | 454 | },
|
455 | 455 | {
|
456 | 456 | "cell_type": "markdown",
|
457 |   | - "id": "55229cc2-aad5-4c78-b095-010a538adb40",
  | 457 | + "id": "25",
458 | 458 | "metadata": {},
|
459 | 459 | "source": [
|
460 | 460 | "Below we run our policy and collect the states of the episodes."
|
|
463 | 463 | {
|
464 | 464 | "cell_type": "code",
|
465 | 465 | "execution_count": null,
|
466 |   | - "id": "220ec837-1d1a-401e-a144-5516bdb3e493",
  | 466 | + "id": "26",
467 | 467 | "metadata": {},
|
468 | 468 | "outputs": [],
|
469 | 469 | "source": [
|
|
489 | 489 | },
|
490 | 490 | {
|
491 | 491 | "cell_type": "markdown",
|
492 |   | - "id": "c9bbe53a-40f5-4bf5-a063-47f1d88032d6",
  | 492 | + "id": "27",
493 | 493 | "metadata": {},
|
494 | 494 | "source": [
|
495 | 495 | "Length of the episode and the total reward:"
|
|
498 | 498 | {
|
499 | 499 | "cell_type": "code",
|
500 | 500 | "execution_count": null,
|
501 |   | - "id": "80345e9c-0694-413d-81c2-205dfacdfb5a",
  | 501 | + "id": "28",
502 | 502 | "metadata": {},
|
503 | 503 | "outputs": [],
|
504 | 504 | "source": [
|
|
507 | 507 | },
|
508 | 508 | {
|
509 | 509 | "cell_type": "markdown",
|
510 |   | - "id": "8b6ec424-e6cc-453a-93fd-eb07e44c1bd6",
  | 510 | + "id": "29",
511 | 511 | "metadata": {},
|
512 | 512 | "source": [
|
513 | 513 | "Visualization of the policy:"
|
|
516 | 516 | {
|
517 | 517 | "cell_type": "code",
|
518 | 518 | "execution_count": null,
|
519 |   | - "id": "7ec60419-1ad0-4f19-bc1e-f2048577ea29",
  | 519 | + "id": "30",
520 | 520 | "metadata": {},
|
521 | 521 | "outputs": [],
|
522 | 522 | "source": [
|
|
531 | 531 | {
|
532 | 532 | "cell_type": "code",
|
533 | 533 | "execution_count": null,
|
534 |   | - "id": "a07c70f6-2c93-43a1-b4c3-edd3f395302a",
  | 534 | + "id": "31",
535 | 535 | "metadata": {},
|
536 | 536 | "outputs": [],
|
537 | 537 | "source": [
|
|