-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ea275f7
commit 52f8857
Showing
11 changed files
with
778 additions
and
263 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Solving Pong with EvoX\n", | ||
"\n", | ||
"## Notice\n", | ||
"\n", | ||
"Running this notebook requires installing evox, gymnasium, ale-py as well as Pong's image.\n", | ||
"Due to copy right issues, we cannot distribute the image here. Please follow the instructions [here](https://github.com/Farama-Foundation/AutoROM) to install the image." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from evox import workflows, algorithms, problems\n", | ||
"from evox.monitors import StdSOMonitor\n", | ||
"from evox.utils import TreeAndVector\n", | ||
"import jax\n", | ||
"import jax.numpy as jnp\n", | ||
"from flax import linen as nn" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# https://docs.ray.io/en/latest/ray-core/examples/plot_pong_example.html\n", | ||
"def pong_preprocess(img):\n", | ||
" # Crop the image.\n", | ||
" img = img[35:195]\n", | ||
" # Downsample by factor of 2.\n", | ||
" img = img[::2, ::2, 0]\n", | ||
" # Erase background (background type 1 and 2).\n", | ||
" img = jnp.where((img == 144) | (img == 109), 0, img)\n", | ||
" # Set everything else (paddles, ball) to 1.\n", | ||
" img = jnp.where(img != 0, 1, img)\n", | ||
" return img" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class PongPolicy(nn.Module):\n", | ||
" \"\"\"A simple model for cartpole\"\"\"\n", | ||
"\n", | ||
" @nn.compact\n", | ||
" def __call__(self, img):\n", | ||
" x = pong_preprocess(img)\n", | ||
" x = x.astype(jnp.float32)\n", | ||
" x = x.reshape(-1)\n", | ||
" x = nn.Dense(128)(x)\n", | ||
" x = nn.relu(x)\n", | ||
" x = nn.Dense(6)(x)\n", | ||
"\n", | ||
" return jnp.argmax(x)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"key = jax.random.PRNGKey(42)\n", | ||
"model_key, workflow_key = jax.random.split(key)\n", | ||
"\n", | ||
"model = PongPolicy()\n", | ||
"params = model.init(model_key, jnp.zeros((210, 160, 3)))\n", | ||
"adapter = TreeAndVector(params)\n", | ||
"monitor = StdSOMonitor()\n", | ||
"problem = problems.neuroevolution.Gym(\n", | ||
" env_name=\"ALE/Pong-v5\",\n", | ||
" env_options={\"full_action_space\": False},\n", | ||
" policy=jax.jit(model.apply),\n", | ||
" num_workers=16,\n", | ||
" env_per_worker=4,\n", | ||
" controller_options={\n", | ||
" \"num_cpus\": 0,\n", | ||
" \"num_gpus\": 0,\n", | ||
" },\n", | ||
" worker_options={\"num_cpus\": 1, \"num_gpus\": 1 / 16},\n", | ||
" batch_policy=False,\n", | ||
")\n", | ||
"center = adapter.to_vector(params)\n", | ||
"# create a workflow\n", | ||
"workflow = workflows.StdWorkflow(\n", | ||
" algorithm=algorithms.PGPE(\n", | ||
" optimizer=\"adam\",\n", | ||
" center_init=center,\n", | ||
" pop_size=64,\n", | ||
" ),\n", | ||
" problem=problem,\n", | ||
" pop_transform=adapter.batched_to_tree,\n", | ||
" monitor=monitor,\n", | ||
" opt_direction=\"max\",\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# init the workflow\n", | ||
"state = workflow.init(workflow_key)\n", | ||
"# run the workflow for 100 steps\n", | ||
"for i in range(100):\n", | ||
" print(monitor.get_best_fitness())\n", | ||
" state = workflow.step(state)\n", | ||
"\n", | ||
"sample_pop, state = workflow.sample(state)\n", | ||
"# the result should be close to 0\n", | ||
"best_fitness = monitor.get_best_fitness()\n", | ||
"print(best_fitness)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.