From 52f885706eb666b6c711cb0be3c54927a2d32337 Mon Sep 17 00:00:00 2001 From: Bill Huang Date: Tue, 24 Oct 2023 16:32:11 +0800 Subject: [PATCH] doc: integrate examples into docs --- docs/source/example/atari_pong.ipynb | 151 +++++++++ .../custom_algorithm_and_problem.ipynb | 290 ++++++++++++++++++ docs/source/example/gym_classic_control.ipynb | 191 ++++++++++++ docs/source/example/index.md | 10 + docs/source/example/pso_ackley.ipynb | 135 ++++++++ docs/source/index.md | 1 + examples/cmaes_classic_control.py | 76 ----- examples/custom_genetic_algorithm.py | 56 ---- examples/custom_problem_onemax.py | 25 -- examples/pgpe_pong.py | 80 ----- examples/pso_ackley.py | 26 -- 11 files changed, 778 insertions(+), 263 deletions(-) create mode 100644 docs/source/example/atari_pong.ipynb create mode 100644 docs/source/example/custom_algorithm_and_problem.ipynb create mode 100644 docs/source/example/gym_classic_control.ipynb create mode 100644 docs/source/example/index.md create mode 100644 docs/source/example/pso_ackley.ipynb delete mode 100644 examples/cmaes_classic_control.py delete mode 100644 examples/custom_genetic_algorithm.py delete mode 100644 examples/custom_problem_onemax.py delete mode 100644 examples/pgpe_pong.py delete mode 100644 examples/pso_ackley.py diff --git a/docs/source/example/atari_pong.ipynb b/docs/source/example/atari_pong.ipynb new file mode 100644 index 000000000..9d89c06e7 --- /dev/null +++ b/docs/source/example/atari_pong.ipynb @@ -0,0 +1,151 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Solving Pong with EvoX\n", + "\n", + "## Notice\n", + "\n", + "Running this notebook requires installing evox, gymnasium, ale-py as well as Pong's image.\n", + "Due to copyright issues, we cannot distribute the image here. Please follow the instructions [here](https://github.com/Farama-Foundation/AutoROM) to install the image." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from evox import workflows, algorithms, problems\n", + "from evox.monitors import StdSOMonitor\n", + "from evox.utils import TreeAndVector\n", + "import jax\n", + "import jax.numpy as jnp\n", + "from flax import linen as nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# https://docs.ray.io/en/latest/ray-core/examples/plot_pong_example.html\n", + "def pong_preprocess(img):\n", + " # Crop the image.\n", + " img = img[35:195]\n", + " # Downsample by factor of 2.\n", + " img = img[::2, ::2, 0]\n", + " # Erase background (background type 1 and 2).\n", + " img = jnp.where((img == 144) | (img == 109), 0, img)\n", + " # Set everything else (paddles, ball) to 1.\n", + " img = jnp.where(img != 0, 1, img)\n", + " return img" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class PongPolicy(nn.Module):\n", + " \"\"\"A simple model for cartpole\"\"\"\n", + "\n", + " @nn.compact\n", + " def __call__(self, img):\n", + " x = pong_preprocess(img)\n", + " x = x.astype(jnp.float32)\n", + " x = x.reshape(-1)\n", + " x = nn.Dense(128)(x)\n", + " x = nn.relu(x)\n", + " x = nn.Dense(6)(x)\n", + "\n", + " return jnp.argmax(x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "key = jax.random.PRNGKey(42)\n", + "model_key, workflow_key = jax.random.split(key)\n", + "\n", + "model = PongPolicy()\n", + "params = model.init(model_key, jnp.zeros((210, 160, 3)))\n", + "adapter = TreeAndVector(params)\n", + "monitor = StdSOMonitor()\n", + "problem = problems.neuroevolution.Gym(\n", + " env_name=\"ALE/Pong-v5\",\n", + " env_options={\"full_action_space\": False},\n", + " policy=jax.jit(model.apply),\n", + " num_workers=16,\n", + " env_per_worker=4,\n", + " controller_options={\n", + " 
\"num_cpus\": 0,\n", + " \"num_gpus\": 0,\n", + " },\n", + " worker_options={\"num_cpus\": 1, \"num_gpus\": 1 / 16},\n", + " batch_policy=False,\n", + ")\n", + "center = adapter.to_vector(params)\n", + "# create a workflow\n", + "workflow = workflows.StdWorkflow(\n", + " algorithm=algorithms.PGPE(\n", + " optimizer=\"adam\",\n", + " center_init=center,\n", + " pop_size=64,\n", + " ),\n", + " problem=problem,\n", + " pop_transform=adapter.batched_to_tree,\n", + " monitor=monitor,\n", + " opt_direction=\"max\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# init the workflow\n", + "state = workflow.init(workflow_key)\n", + "# run the workflow for 100 steps\n", + "for i in range(100):\n", + " print(monitor.get_best_fitness())\n", + " state = workflow.step(state)\n", + "\n", + "sample_pop, state = workflow.sample(state)\n", + "# the result should be close to 0\n", + "best_fitness = monitor.get_best_fitness()\n", + "print(best_fitness)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/example/custom_algorithm_and_problem.ipynb b/docs/source/example/custom_algorithm_and_problem.ipynb new file mode 100644 index 000000000..280ce54c9 --- /dev/null +++ b/docs/source/example/custom_algorithm_and_problem.ipynb @@ -0,0 +1,290 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# install evox, skip it if you have already installed evox\n", + "try:\n", + " import evox\n", + "except ImportError:\n", + " !pip install 
--disable-pip-version-check --upgrade -q evox\n", + " import evox" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from evox import Algorithm, Problem, State, jit_class, monitors, workflows\n", + "from evox.operators import mutation, crossover, selection\n", + "from jax import random\n", + "import jax.numpy as jnp" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "@jit_class\n", + "class OneMax(Problem):\n", + " def __init__(self) -> None:\n", + " super().__init__()\n", + "\n", + " def evaluate(self, state, bitstrings):\n", + " # bitstrings has shape (pop_size, num_bits)\n", + " # so sum along the axis 1.\n", + " fitness = jnp.sum(bitstrings, axis=1)\n", + " return fitness, state" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "@jit_class\n", + "class CustomGA(Algorithm):\n", + " def __init__(self, pop_size, ndim, flip_prob):\n", + " super().__init__()\n", + " # those are hyperparameters that stay fixed.\n", + " self.pop_size = pop_size\n", + " self.ndim = ndim\n", + " # the probability of fliping each bit\n", + " self.flip_prob = flip_prob\n", + "\n", + " def setup(self, key):\n", + " # initialize the state\n", + " # state are mutable data like the population, offsprings\n", + " # the population is randomly initialized.\n", + " # we don't have any offspring now, but initialize it as a placeholder\n", + " # because jax want static shaped arrays.\n", + " key, subkey = random.split(key)\n", + " pop = random.uniform(subkey, (self.pop_size, self.ndim)) < 0.5\n", + " return State(\n", + " pop=pop,\n", + " offsprings=jnp.empty((self.pop_size * 2, self.ndim)),\n", + " fit=jnp.full((self.pop_size,), jnp.inf),\n", + " key=key,\n", + " )\n", + "\n", + " def ask(self, state):\n", + " key, mut_key, x_key = random.split(state.key, 3)\n", + " # here we do mutation and crossover 
(reproduction)\n", + " # for simplicity, we didn't use any mating selections\n", + " # so the offspring is twice as large as the population\n", + " offsprings = jnp.concatenate(\n", + " (\n", + " mutation.bitflip(mut_key, state.pop, self.flip_prob),\n", + " crossover.one_point(x_key, state.pop),\n", + " ),\n", + " axis=0,\n", + " )\n", + " # return the candidate solution and update the state\n", + " return offsprings, state.update(offsprings=offsprings, key=key)\n", + "\n", + " def tell(self, state, fitness):\n", + " # here we do selection\n", + " merged_pop = jnp.concatenate([state.pop, state.offsprings])\n", + " merged_fit = jnp.concatenate([state.fit, fitness])\n", + " new_pop, new_fit = selection.topk_fit(merged_pop, merged_fit, self.pop_size)\n", + " # replace the old population\n", + " return state.update(pop=new_pop, fit=new_fit)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "algorithm = CustomGA(\n", + " pop_size=128,\n", + " ndim=100,\n", + " flip_prob=0.1,\n", + ")\n", + "problem = OneMax()\n", + "monitor = monitors.StdSOMonitor()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# create a workflow\n", + "workflow = workflows.StdWorkflow(\n", + " algorithm,\n", + " problem,\n", + " monitor,\n", + " record_pop=True,\n", + " opt_direction=\"max\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# init the workflow\n", + "key = random.PRNGKey(42)\n", + "state = workflow.init(key)\n", + "\n", + "# run the workflow for 20 iterations\n", + "for i in range(20):\n", + " state = workflow.step(state)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(-93, dtype=int32)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"monitor.get_best_fitness()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([ True, True, True, True, True, True, False, True, True,\n", + " True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, False, True, True, True, True,\n", + " True, True, True, True, False, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True,\n", + " True, True, False, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, False, True,\n", + " True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, False, True, True, True, True,\n", + " True, True, False, True, True, True, True, True, True,\n", + " True], dtype=bool)" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "monitor.get_best_solution()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# run the workflow for another 20 iterations\n", + "for i in range(20):\n", + " state = workflow.step(state)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(-100, dtype=int32)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "monitor.get_best_fitness()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([ True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True,\n", + " True, True, 
True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True,\n", + " True, True, True, True, True, True, True, True, True,\n", + " True], dtype=bool)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "monitor.get_best_solution()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/example/gym_classic_control.ipynb b/docs/source/example/gym_classic_control.ipynb new file mode 100644 index 000000000..6a439977d --- /dev/null +++ b/docs/source/example/gym_classic_control.ipynb @@ -0,0 +1,191 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# install evox, skip it if you have already installed evox\n", + "try:\n", + " import evox\n", + "except ImportError:\n", + " !pip install --disable-pip-version-check --upgrade -q evox gymnasium flax\n", + " import evox" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from evox import workflows, algorithms, problems\n", + "from evox.monitors import StdSOMonitor\n", + "from evox.utils import TreeAndVector\n", + "import jax\n", + "import jax.numpy as jnp\n", + "from flax import linen as nn" + ] + }, + { + "cell_type": 
"code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "gym_name = \"Pendulum-v1\" # choose a setup\n", + "\n", + "def tanh2(x):\n", + " return 2 * nn.tanh(x)\n", + "\n", + "policy_params = {\n", + " \"Acrobot-v1\": (3, (6,), jnp.argmax),\n", + " \"CartPole-v1\": (2, (4,), jnp.argmax),\n", + " \"MountainCarContinuous-v0\": (1, (2,), nn.tanh),\n", + " \"MountainCar-v0\": (3, (2,), jnp.argmax),\n", + " \"Pendulum-v1\": (1, (3,), tanh2),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# define a policy model\n", + "class ClassicPolicy(nn.Module):\n", + " \"\"\"A simple model for Classic Control problem\"\"\"\n", + "\n", + " @nn.compact\n", + " def __call__(self, x):\n", + " x = x.at[1].multiply(10) # normalization\n", + " x = nn.Dense(16)(x)\n", + " x = nn.relu(x)\n", + " x = nn.Dense(policy_params[gym_name][0])(x)\n", + "\n", + " return policy_params[gym_name][2](x)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-10-24 15:54:46,501\tINFO worker.py:1553 -- Started a local Ray instance.\n" + ] + } + ], + "source": [ + "key = jax.random.PRNGKey(42)\n", + "model_key, workflow_key = jax.random.split(key)\n", + "\n", + "model = ClassicPolicy()\n", + "params = model.init(model_key, jnp.zeros(policy_params[gym_name][1]))\n", + "adapter = TreeAndVector(params)\n", + "monitor = StdSOMonitor()\n", + "problem = problems.neuroevolution.Gym(\n", + " env_name=gym_name,\n", + " policy=jax.jit(model.apply),\n", + " num_workers=16, # adjust according to your need\n", + " env_per_worker=4,\n", + " controller_options={\n", + " \"num_cpus\": 0,\n", + " \"num_gpus\": 0,\n", + " },\n", + " worker_options={\"num_cpus\": 1, \"num_gpus\": 1 / 16},\n", + " batch_policy=False,\n", + ")\n", + "center = adapter.to_vector(params)\n", + "# create a workflow\n", + "workflow = 
workflows.StdWorkflow(\n", + " algorithm=algorithms.CMAES(center_init=center, init_stdev=1, pop_size=64),\n", + " problem=problem,\n", + " pop_transform=adapter.batched_to_tree,\n", + " monitor=monitor,\n", + " opt_direction=\"max\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now run the workflow.\n", + "You may see warnings like\n", + "```\n", + "CUDA backend failed to initialize: Unable to load CUDA.\n", + "```\n", + "This is expected behavior, because we have a controller thread that manages a group of Gym workers,\n", + "and the controller thread does not use GPU.\n", + "\n", + "If the program gets stuck, you may want to check whether `num_workers` is larger than the number of available cores on your computer." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[2m\u001b[36m(Controller pid=641434)\u001b[0m CUDA backend failed to initialize: Unable to load CUDA. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-0.114485\n" + ] + } + ], + "source": [ + "# init the workflow\n", + "state = workflow.init(workflow_key)\n", + "# run the workflow for 100 steps\n", + "for i in range(100):\n", + " state = workflow.step(state)\n", + "\n", + "best_fitness = monitor.get_best_fitness()\n", + "print(best_fitness)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/example/index.md b/docs/source/example/index.md new file mode 100644 index 000000000..e7974c266 --- /dev/null +++ b/docs/source/example/index.md @@ -0,0 +1,10 @@ +# EvoX's examples + +```{toctree} +:maxdepth: 1 + +pso_ackley +custom_algorithm_and_problem +gym_classic_control +atari_pong +``` diff --git a/docs/source/example/pso_ackley.ipynb b/docs/source/example/pso_ackley.ipynb new file mode 100644 index 000000000..1f9a866dd --- /dev/null +++ b/docs/source/example/pso_ackley.ipynb @@ -0,0 +1,135 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# install evox, skip it if you have already installed evox\n", + "try:\n", + " import evox\n", + "except ImportError:\n", + " !pip install --disable-pip-version-check --upgrade -q evox\n", + " import evox" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from evox import algorithms, problems, 
workflows, monitors\n", + "import jax\n", + "import jax.numpy as jnp" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "algorithm = algorithms.PSO(\n", + " lb=jnp.full(shape=(2,), fill_value=-32),\n", + " ub=jnp.full(shape=(2,), fill_value=32),\n", + " pop_size=100,\n", + ")\n", + "problem = problems.numerical.Ackley()\n", + "monitor = monitors.StdSOMonitor()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# create a workflow\n", + "workflow = workflows.StdWorkflow(\n", + " algorithm,\n", + " problem,\n", + " monitor,\n", + " record_pop=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "# init the workflow\n", + "key = jax.random.PRNGKey(42)\n", + "state = workflow.init(key)\n", + "\n", + "# run the workflow for 100 steps\n", + "for i in range(100):\n", + " state = workflow.step(state)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array(0., dtype=float32)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "monitor.get_best_fitness()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([-4.0062014e-07, 5.2837186e-07], dtype=float32)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "monitor.get_best_solution()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } 
+ }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/index.md b/docs/source/index.md index 9d83cd394..88f062510 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -6,6 +6,7 @@ User Guide API reference +Examples ``` ```{eval-rst} diff --git a/examples/cmaes_classic_control.py b/examples/cmaes_classic_control.py deleted file mode 100644 index 7cc3a01c7..000000000 --- a/examples/cmaes_classic_control.py +++ /dev/null @@ -1,76 +0,0 @@ -from evox import workflows, algorithms, problems -from evox.monitors import StdSOMonitor -from evox.utils import TreeAndVector -import jax -import jax.numpy as jnp -from flax import linen as nn - -# change here only -gym_name = "Pendulum-v1" - - -def tanh2(x): - return 2 * nn.tanh(x) - - -policy_params = { - "Acrobot-v1": (3, (6,), jnp.argmax), - "CartPole-v1": (2, (4,), jnp.argmax), - "MountainCarContinuous-v0": (1, (2,), nn.tanh), - "MountainCar-v0": (3, (2,), jnp.argmax), - "Pendulum-v1": (1, (3,), tanh2), -} - - -class ClassicPolicy(nn.Module): - """A simple model for Classic Control problem""" - - @nn.compact - def __call__(self, x): - x = x.at[1].multiply(10) # normalization - x = nn.Dense(16)(x) - x = nn.relu(x) - x = nn.Dense(policy_params[gym_name][0])(x) - - return policy_params[gym_name][2](x) - - -key = jax.random.PRNGKey(42) -model_key, workflow_key = jax.random.split(key) - -model = ClassicPolicy() -params = model.init(model_key, jnp.zeros(policy_params[gym_name][1])) -adapter = TreeAndVector(params) -monitor = StdSOMonitor() -problem = problems.neuroevolution.Gym( - env_name=gym_name, - policy=jax.jit(model.apply), - num_workers=16, - env_per_worker=4, - controller_options={ - "num_cpus": 0, - "num_gpus": 0, - }, - worker_options={"num_cpus": 1, "num_gpus": 1 / 16}, - batch_policy=False, -) -center = adapter.to_vector(params) -# create a workflow -workflow = workflows.StdWorkflow( - algorithm=algorithms.CMAES(init_mean=center, init_stdev=1, pop_size=64), - problem=problem, - 
pop_transform=adapter.batched_to_tree, - monitor=monitor, -) -# init the workflow -state = workflow.init(workflow_key) -# run the workflow for 100 steps -for i in range(100): - print(monitor.get_best_fitness()) - state = workflow.step(state) - -sample_pop, state = workflow.sample(state) -# problem._render(state.get_child_state("problem"), adapter.to_tree(sample_pop[0])) - -min_fitness = monitor.get_best_fitness() -print(min_fitness) diff --git a/examples/custom_genetic_algorithm.py b/examples/custom_genetic_algorithm.py deleted file mode 100644 index 8010f3ee6..000000000 --- a/examples/custom_genetic_algorithm.py +++ /dev/null @@ -1,56 +0,0 @@ -# An example of implementing genetic algorithm that solves OneMax problem in EvoX. -# This algorithm uses binary crossover and bitflip mutation. - -from evox import Algorithm, State, jit_class -from evox.operators import mutation, crossover, selection -from jax import random -import jax.numpy as jnp - - -@jit_class -class ExampleGA(Algorithm): - def __init__(self, pop_size, ndim, flip_prob): - super().__init__() - # those are hyperparameters that stay fixed. - self.pop_size = pop_size - self.ndim = ndim - # the probability of fliping each bit - self.flip_prob = flip_prob - - def setup(self, key): - # initialize the state - # state are mutable data like the population, offsprings - # the population is randomly initialized. - # we don't have any offspring now, but initialize it as a placeholder - # because jax want static shaped arrays. 
- key, subkey = random.split(key) - pop = random.uniform(subkey, (self.pop_size, self.ndim)) < 0.5 - return State( - pop=pop, - offsprings=jnp.empty((self.pop_size * 2, self.ndim)), - fit=jnp.full((self.pop_size,), jnp.inf), - key=key, - ) - - def ask(self, state): - key, mut_key, x_key = random.split(state.key, 3) - # here we do mutation and crossover (reproduction) - # for simplicity, we didn't use any mating selections - # so the offspring is twice as large as the population - offsprings = jnp.concatenate( - ( - mutation.bitflip(mut_key, state.pop, self.flip_prob), - crossover.one_point(x_key, state.pop), - ), - axis=0, - ) - # return the candidate solution and update the state - return offsprings, state.update(offsprings=offsprings, key=key) - - def tell(self, state, fitness): - # here we do selection - merged_pop = jnp.concatenate([state.pop, state.offsprings]) - merged_fit = jnp.concatenate([state.fit, fitness]) - new_pop, new_fit = selection.topk_fit(merged_pop, merged_fit, self.pop_size) - # replace the old population - return state.update(pop=new_pop, fit=new_fit) diff --git a/examples/custom_problem_onemax.py b/examples/custom_problem_onemax.py deleted file mode 100644 index ed3cccb46..000000000 --- a/examples/custom_problem_onemax.py +++ /dev/null @@ -1,25 +0,0 @@ -# An exmaple of implement custom problem in EvoX. -# Here, we implement the OneMax problem, the fitness is defined as the sum of all the digits in a bitstring. -# For example, "100111" -> 4, "000101" -> 2. -# The goal is to find the bitstring that maximize the fitness. -# Since in EvoX, algorithms try to minimize the fitness, so we return the negitive sum as our fitness. - -import jax.numpy as jnp -from evox import Problem, jit_class - - -@jit_class -class OneMax(Problem): - def __init__(self, neg_fitness=True) -> None: - super().__init__() - self.neg_fitess = neg_fitness - - def evaluate(self, state, bitstrings): - # bitstrings has shape (pop_size, num_bits) - # so sum along the axis 1. 
- fitness = jnp.sum(bitstrings, axis=1) - # Since in EvoX, algorithms try to minimize the fitness - # so return the negitive value. - if self.neg_fitess: - fitness = -fitness - return fitness, state diff --git a/examples/pgpe_pong.py b/examples/pgpe_pong.py deleted file mode 100644 index 875b8719c..000000000 --- a/examples/pgpe_pong.py +++ /dev/null @@ -1,80 +0,0 @@ -from evox import workflows, algorithms, problems -from evox.monitors import StdSOMonitor -from evox.utils import TreeAndVector -import jax -import jax.numpy as jnp -from flax import linen as nn - - -# https://docs.ray.io/en/latest/ray-core/examples/plot_pong_example.html -def pong_preprocess(img): - # Crop the image. - img = img[35:195] - # Downsample by factor of 2. - img = img[::2, ::2, 0] - # Erase background (background type 1 and 2). - img = jnp.where((img == 144) | (img == 109), 0, img) - # Set everything else (paddles, ball) to 1. - img = jnp.where(img != 0, 1, img) - return img - - -class PongPolicy(nn.Module): - """A simple model for cartpole""" - - @nn.compact - def __call__(self, img): - x = pong_preprocess(img) - x = x.astype(jnp.float32) - x = x.reshape(-1) - x = nn.Dense(128)(x) - x = nn.relu(x) - x = nn.Dense(6)(x) - - return jnp.argmax(x) - - -key = jax.random.PRNGKey(42) -model_key, workflow_key = jax.random.split(key) - -model = PongPolicy() -params = model.init(model_key, jnp.zeros((210, 160, 3))) -adapter = TreeAndVector(params) -monitor = StdSOMonitor() -problem = problems.neuroevolution.Gym( - env_name="ALE/Pong-v5", - env_options={"full_action_space": False}, - policy=jax.jit(model.apply), - num_workers=16, - env_per_worker=4, - controller_options={ - "num_cpus": 0, - "num_gpus": 0, - }, - worker_options={"num_cpus": 1, "num_gpus": 1 / 16}, - batch_policy=False, -) -center = adapter.to_vector(params) -# create a workflow -workflow = workflows.StdWorkflow( - algorithm=algorithms.PGPE( - optimizer="adam", - center_init=center, - pop_size=64, - ), - problem=problem, - 
pop_transform=adapter.batched_to_tree, - monitor=monitor, -) -# init the workflow -state = workflow.init(workflow_key) -# run the workflow for 100 steps -for i in range(10): - print(monitor.get_best_fitness()) - state = workflow.step(state) - -sample_pop, state = workflow.sample(state) -# problem._render(adapter.to_tree(sample_pop[0]), ale_render_mode="human") -# the result should be close to 0 -min_fitness = monitor.get_best_fitness() -print(min_fitness) diff --git a/examples/pso_ackley.py b/examples/pso_ackley.py deleted file mode 100644 index 20975a777..000000000 --- a/examples/pso_ackley.py +++ /dev/null @@ -1,26 +0,0 @@ -from evox import algorithms, problems, workflows -import jax -import jax.numpy as jnp - -algorithm = algorithms.PSO( - lb=jnp.full(shape=(2,), fill_value=-32), - ub=jnp.full(shape=(2,), fill_value=32), - pop_size=100, -) - -problem = problems.numerical.Ackley() - -# create a workflow - -workflow = workflows.StdWorkflow( - algorithm, - problem, -) - -# init the workflow -key = jax.random.PRNGKey(42) -state = workflow.init(key) - -# run the workflow for 100 steps -for i in range(100): - state = workflow.step(state)