Add muzero unplugged codes (#10)
* Update readme and dependency

* Add MCTS to benchmark MZU

* Minor addition

* Include search depth into cfg

* update

* lint
lkevinzc committed Jul 17, 2023
1 parent 0f6ea65 commit 42e7800
Showing 12 changed files with 368 additions and 176 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/check.yml
@@ -37,6 +37,9 @@ jobs:
python-version: 3.8.13
- uses: ./.github/actions/cache
- name: Install
run: pip install -e .
run: |
pip install -e .
pip install dopamine-rl==3.1.2
pip install chex==0.1.5
- name: Lint
run: make checks
23 changes: 19 additions & 4 deletions INSTALL.md
@@ -16,9 +16,24 @@
2. Clone this repository and install it in develop mode:
```console
pip install -e .

# Install the following packages separately due to version conflicts.
pip install dopamine-rl==3.1.2
pip install chex==0.1.5
```
3. [Install the ROM for Atari](https://github.com/openai/atari-py#roms).
4. (Optional) Download **BSuite** [datasets](https://drive.google.com/file/d/1FWexoOphUgBaWTWtY9VR43N90z9A6FvP/view?usp=sharing) if you are running BSuite experiments; **Atari** datasets will be automatically downloaded from [TFDS](https://www.tensorflow.org/datasets/catalog/rlu_atari). The dataset path is defined in `experiment/*/config.py`.
4. Download the datasets:
1. **BSuite** datasets ([drive](https://drive.google.com/file/d/1FWexoOphUgBaWTWtY9VR43N90z9A6FvP/view?usp=sharing)) if you are running BSuite experiments;
2. **Atari** datasets will be automatically downloaded from [TFDS](https://www.tensorflow.org/datasets/catalog/rlu_atari) when starting the experiment. The dataset path is defined in `experiment/*/config.py`. Alternatively, you can download them in advance with the following script:
```python
from rosmo.data.rlu_atari import create_atari_ds_loader
create_atari_ds_loader(
env_name="Pong", # Change this to the game you need.
run_number=1, # Keep this fixed.
dataset_dir="/path/to/download",
)
```

### TPU

@@ -31,12 +46,12 @@ pip install "jax[tpu]==0.3.6" -f https://storage.googleapis.com/jax-releases/lib

### GPU

We also conducted verification experiments on 4 Tesla-V100 GPUs to ensure our algorithm's reproducibility on different platforms. To install the same version of Jax as ours:
We also conducted verification experiments on 4 Tesla-A100 GPUs to ensure our algorithm's reproducibility on different platforms. To install the same version of Jax as ours:
```console
pip uninstall jax jaxlib libtpu-nightly libtpu -y

# jax-0.3.6 jaxlib-0.3.5+cuda11.cudnn82
pip install --upgrade "jax[cuda]==0.3.6" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# jax-0.3.25 jaxlib-0.3.25+cuda11.cudnn82
pip install --upgrade "jax[cuda]==0.3.25" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
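
Not part of the commit: after installation, a quick check like the following should confirm that JAX can see the GPUs.
```console
python -c "import jax; print(jax.local_devices())"
```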

### Test
18 changes: 13 additions & 5 deletions README.md
@@ -37,14 +37,16 @@

## Introduction

This repository contains the implementation of ROSMO, a **R**egularized **O**ne-**S**tep **M**odel-based algorithm for **O**ffline-RL, introduced in our paper "Efficient Offline Policy Optimization with a Learned Model". We provide the training codes for both Atari and BSuite experiments, and have made the reproduced results publicly available at [W&B](https://wandb.ai/lkevinzc/rosmo).
This repository contains the implementation of ROSMO, a **R**egularized **O**ne-**S**tep **M**odel-based algorithm for **O**ffline-RL, introduced in our paper "Efficient Offline Policy Optimization with a Learned Model". We provide the training code for both Atari and BSuite experiments, and have made the reproduced results on `Atari MsPacman` publicly available at [W&B](https://wandb.ai/lkevinzc/rosmo-public).

## Installation
Please follow the [installation guide](INSTALL.md).

## Usage
### BSuite

To run the BSuite experiments, please ensure you have downloaded the [datasets](https://drive.google.com/file/d/1FWexoOphUgBaWTWtY9VR43N90z9A6FvP/view?usp=sharing) and placed them in the directory defined by `CONFIG.data_dir` in `experiment/bsuite/config.py`.

1. Debug run.
```console
python experiment/bsuite/main.py -exp_id test -env cartpole
@@ -56,13 +58,19 @@ python experiment/bsuite/main.py -exp_id test -env cartpole -nodebug -use_wb -us

### Atari

1. Train with exact policy target.
The following commands show how to train 1) a ROSMO agent, 2) its sampling variant, and 3) an MZU agent on the game `MsPacman`.

1. Train ROSMO with exact policy target.
```console
python experiment/atari/main.py -exp_id rosmo -env MsPacman -nodebug -use_wb -user ${WB_USER}
```
2. Train ROSMO with sampled policy target (N=4).
```console
python experiment/atari/main.py -exp_id test -env MsPacman -nodebug -use_wb -user ${WB_USER}
python experiment/atari/main.py -exp_id rosmo-sample-4 -sampling -env MsPacman -nodebug -use_wb -user ${WB_USER}
```
2. Train with sampled policy target (N=4).
1. Train MuZero Unplugged as a benchmark (N=20).
```console
python experiment/atari/main.py -exp_id test-sample-4 -sampling -env MsPacman -nodebug -use_wb -user ${WB_USER}
python experiment/atari/main.py -exp_id mzu-sample-20 -algo mzu -num_simulations 20 -env MsPacman -nodebug -use_wb -user ${WB_USER}
```
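
Not part of the commit: before launching the full runs above, a quick local sanity check in debug mode should be possible with an invocation along these lines (hypothetical example, assuming debug mode is the default as in the BSuite instructions):
```console
python experiment/atari/main.py -exp_id debug -algo mzu -env MsPacman
```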

## Citation
2 changes: 2 additions & 0 deletions experiment/atari/config.py
@@ -30,8 +30,10 @@ def get_config(game_name: str) -> Dict:
config["seed"] = FLAGS.seed
config["benchmark"] = "atari"
config["sampling"] = FLAGS.sampling
config["mcts"] = FLAGS.algo == "mzu"
config["game_name"] = game_name
config["num_simulations"] = FLAGS.num_simulations
config["search_depth"] = FLAGS.search_depth or FLAGS.num_simulations
config["batch_size"] = 16 if FLAGS.debug else config["batch_size"]
exp_full_name = f"{FLAGS.exp_id}_{game_name}_" + generate_id()
config["exp_full_name"] = exp_full_name
17 changes: 12 additions & 5 deletions experiment/atari/main.py
@@ -13,6 +13,7 @@
# limitations under the License.

"""Atari experiment entry."""

import os
import pickle
import random
@@ -57,7 +58,13 @@

flags.DEFINE_boolean("sampling", False, "Whether to sample policy target.")
flags.DEFINE_integer("num_simulations", 4, "Simulation budget.")

flags.DEFINE_enum("algo", "rosmo", ["rosmo", "mzu"], "Algorithm to use.")
flags.DEFINE_integer(
"search_depth",
0,
"Depth of Monte-Carlo Tree Search (only for mzu), \
defaults to num_simulations.",
)

# ===== Learner. ===== #
def get_learner(config, networks, data_iterator, logger) -> RosmoLearner:
@@ -71,7 +78,7 @@ def get_learner(config, networks, data_iterator, logger) -> RosmoLearner:
return learner


# ===== Eval Actor-Env Loop. ===== #
# ===== Eval Actor-Env Loop & Observer. ===== #
def get_actor_env_eval_loop(
config, networks, environment, observers, logger
) -> Tuple[RosmoEvalActor, EnvironmentLoop]:
@@ -98,7 +105,7 @@ def get_env_loop_observers() -> List[ExtendedEnvLoopObserver]:
return observers


# ===== Environment & Dataloader ===== #
# ===== Environment & Dataloader. ===== #
def get_env_data_loader(config) -> Tuple[dm_env.Environment, Iterator]:
"""Get environment and trajectory data loader."""
trajectory_length = config["unroll_steps"] + config["td_steps"] + 1
@@ -134,7 +141,7 @@ def transform_timesteps(steps: Dict[str, np.ndarray]) -> ActorOutput:
return environment, iterator


# ===== Network ===== #
# ===== Network. ===== #
def get_networks(config, environment) -> Networks:
"""Get environment-specific networks."""
environment_spec = make_environment_spec(environment)
@@ -177,7 +184,7 @@ def get_logger_fn(
def main(_):
"""Main program."""
platform = jax.lib.xla_bridge.get_backend().platform
num_devices = jax.lib.xla_bridge.device_count()
num_devices = jax.device_count()
logging.warn(f"Compute platform: {platform} with {num_devices} devices.")
logging.info(f"Debug mode: {FLAGS.debug}")
random.seed(FLAGS.seed)
1 change: 1 addition & 0 deletions experiment/bsuite/config.py
@@ -29,6 +29,7 @@ def get_config(game_name: str) -> Dict:
config = deepcopy(CONFIG)
config["seed"] = FLAGS.seed
config["benchmark"] = "bsuite"
config["mcts"] = FLAGS.algo == "mzu"
config["game_name"] = game_name
config["batch_size"] = 16 if FLAGS.debug else config.batch_size
exp_full_name = f"{FLAGS.exp_id}_{game_name}_" + generate_id()
2 changes: 1 addition & 1 deletion experiment/bsuite/main.py
@@ -114,7 +114,7 @@ def main(_):
np.random.seed(FLAGS.seed)

platform = jax.lib.xla_bridge.get_backend().platform
num_devices = jax.lib.xla_bridge.device_count()
num_devices = jax.device_count()
logging.warn(f"Compute platform: {platform} with {num_devices} devices.")

# ===== Setup. ===== #
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -26,8 +26,7 @@ classifiers = [
dependencies = [
"dm-acme==0.4.0",
"dm-launchpad-nightly==0.3.0.dev20220321",
"dm-haiku==0.0.7",
"dopamine-rl==3.1.2",
"dm-haiku==0.0.9",
"gym==0.17.2",
"gin-config==0.3.0",
"rlax==0.1.4",
@@ -42,6 +41,7 @@ dependencies = [
"mujoco-py<2.2,>=2.1",
"bsuite==0.3.5",
"viztracer==0.15.6",
"mctx==0.0.2",
]
dynamic = ["version"]

41 changes: 29 additions & 12 deletions rosmo/agent/actor.py
@@ -25,7 +25,8 @@
from acme import types
from acme.jax import networks as networks_lib

from rosmo.agent.learning import one_step_improve, root_unroll
from rosmo.agent.improvement_op import mcts_improve, one_step_improve
from rosmo.agent.learning import root_unroll
from rosmo.agent.network import Networks
from rosmo.agent.type import AgentOutput, Params
from rosmo.type import ActorOutput, Array
@@ -50,8 +51,10 @@ def __init__(

num_bins = config["num_bins"]
discount_factor = config["discount_factor"]
use_mcts = config["mcts"]
sampling = config.get("sampling", False)
num_simulations = config.get("num_simulations", -1)
search_depth = config.get("search_depth", num_simulations)

def root_step(
rng_key: chex.PRNGKey,
@@ -70,25 +73,39 @@
)
improve_key, sample_key = jax.random.split(rng_key)

agent_out: AgentOutput = jax.tree_map(
lambda t: t.squeeze(axis=0), agent_out
) # Squeeze the dummy time dimension.
if not sampling:
logging.info("[Actor] Using onestep improvement.")
improved_policy, _ = one_step_improve(
self._networks,
if use_mcts:
logging.info("[Actor] Using MCTS planning.")
mcts_out = mcts_improve(
networks,
improve_key,
params,
agent_out,
num_bins,
discount_factor,
num_simulations,
sampling,
search_depth,
)
action = mcts_out.action
else:
logging.info("[Actor] Using policy.")
improved_policy = jax.nn.softmax(agent_out.policy_logits)
action = rlax.categorical_sample(sample_key, improved_policy)
agent_out = jax.tree_map(
lambda t: t.squeeze(axis=0), agent_out
) # Squeeze the dummy time dimension.
if not sampling:
logging.info("[Actor] Using onestep improvement.")
improved_policy, _ = one_step_improve(
self._networks,
improve_key,
params,
agent_out,
num_bins,
discount_factor,
num_simulations,
sampling,
)
else:
logging.info("[Actor] Using policy.")
improved_policy = jax.nn.softmax(agent_out.policy_logits)
action = rlax.categorical_sample(sample_key, improved_policy)
return action, agent_out

def batch_step(
(Diffs for the remaining 3 changed files are not shown here.)
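
The new `mcts_improve` helper imported in `actor.py` lives in `rosmo/agent/improvement_op.py`, presumably one of the files not shown above. As a rough, self-contained illustration (not the repository's implementation), the toy below shows how an MCTS policy-improvement step can be built with `mctx`, the dependency this commit adds to `pyproject.toml`, including the `max_depth` cap that corresponds to the new `search_depth` setting. All shapes and constants here are illustrative.
```python
import jax
import jax.numpy as jnp
import mctx

batch, num_actions, embed_dim = 1, 4, 8
rng_key = jax.random.PRNGKey(0)
params = ()  # a real agent would pass its network parameters here

# Root statistics, as produced by the representation/prediction networks.
root = mctx.RootFnOutput(
    prior_logits=jnp.zeros([batch, num_actions]),  # uniform prior for this toy
    value=jnp.zeros([batch]),
    embedding=jnp.zeros([batch, embed_dim]),       # latent root state
)

def recurrent_fn(params, rng_key, action, embedding):
    # Stand-in for the learned dynamics/prediction model.
    out = mctx.RecurrentFnOutput(
        reward=jnp.zeros([batch]),
        discount=jnp.full([batch], 0.997),
        prior_logits=jnp.zeros([batch, num_actions]),
        value=jnp.zeros([batch]),
    )
    return out, embedding

policy_output = mctx.muzero_policy(
    params,
    rng_key,
    root,
    recurrent_fn,
    num_simulations=20,
    max_depth=5,  # analogous to setting search_depth below num_simulations
)
print(policy_output.action, policy_output.action_weights)
```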
