sai-prasanna/dreamerv2_torch

Mastering Atari with Discrete World Models (PyTorch)

Implementation of the DreamerV2 agent in PyTorch. For the most part it mirrors the structure of the original TensorFlow implementation, and it reuses major parts of the dreamer-torch implementation. On the few environments I tested for 100k steps (cartpole, cheetah), it reaches similar performance curves. Plan2Explore is included but has not been tested yet.

If you find this code useful, please cite it in your paper:

@article{hafner2020dreamerv2,
  title={Mastering Atari with Discrete World Models},
  author={Hafner, Danijar and Lillicrap, Timothy and Norouzi, Mohammad and Ba, Jimmy},
  journal={arXiv preprint arXiv:2010.02193},
  year={2020}
}
@misc{dreamerv2_torch,
  author = {Sai Prasanna},
  title = {Dreamerv2 Pytorch Implementation},
  year = {2022},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/sai-prasanna/dreamerv2_torch}},
}

Method

DreamerV2 is the first world model agent that achieves human-level performance on the Atari benchmark. DreamerV2 also outperforms the final performance of the top model-free agents Rainbow and IQN using the same amount of experience and computation. The implementation in this repository alternates between training the world model, training the policy, and collecting experience, and it runs on a single GPU.
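The alternation between collecting experience and training can be sketched as a plain Python loop. This is a simplified illustration, not the repository's training script: the callables `env_step`, `train_world_model`, and `train_actor_critic` are hypothetical placeholders, and the schedule is only named after the `prefill` and `train_every` configuration keys used in this README.

```python
from typing import Callable


def train_loop(
    env_step: Callable[[bool], None],
    train_world_model: Callable[[], None],
    train_actor_critic: Callable[[], None],
    total_steps: int,
    prefill: int,
    train_every: int,
) -> int:
    """Interleave data collection with training; return the number of updates."""
    updates = 0
    for step in range(total_steps):
        # Act randomly while prefilling the replay buffer, then use the policy.
        env_step(step < prefill)
        # Once prefill is done, train after every `train_every` environment steps.
        if step >= prefill and step % train_every == 0:
            train_world_model()   # fit the world model on replayed sequences
            train_actor_critic()  # update actor and critic on imagined rollouts
            updates += 1
    return updates
```

With `total_steps=100`, `prefill=20`, and `train_every=10`, updates happen at steps 20, 30, ..., 90, so the function returns 8.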

World Model Learning

DreamerV2 learns a model of the environment directly from high-dimensional input images. For this, it predicts ahead using compact learned states. The states consist of a deterministic part and several categorical variables that are sampled. The prior for these categoricals is learned through a KL loss. The world model is learned end-to-end via straight-through gradients, meaning that the gradient with respect to each sampled categorical is passed directly to its class probabilities, as if the sampling step were the identity.
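The straight-through trick can be sketched in a few lines of PyTorch. The helper name `straight_through_sample` is hypothetical (not this repository's API), and the example assumes latent states made of 32 categorical variables with 32 classes each, matching the DreamerV2 defaults `stoch: 32, discrete: 32`.

```python
import torch
import torch.nn.functional as F


def straight_through_sample(logits: torch.Tensor) -> torch.Tensor:
    """Sample one-hot categoricals with straight-through gradients.

    logits: (..., num_vars, num_classes), e.g. 32 variables with 32 classes.
    """
    probs = F.softmax(logits, dim=-1)
    # Draw a hard one-hot sample; this step alone is not differentiable.
    index = torch.distributions.Categorical(logits=logits).sample()
    one_hot = F.one_hot(index, num_classes=logits.shape[-1]).to(probs.dtype)
    # Forward value equals the one-hot sample, while the backward pass
    # routes the sample's gradient straight into the probabilities.
    return one_hot + probs - probs.detach()


logits = torch.randn(4, 32, 32, requires_grad=True)  # batch of 4 latent states
sample = straight_through_sample(logits)
features = sample.flatten(start_dim=-2)  # (4, 1024) stochastic features
```

The flattened one-hot features are what a decoder or the actor and critic would consume; any loss computed from them backpropagates into `logits` even though the forward pass used discrete samples.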

Actor Critic Learning

DreamerV2 learns actor and critic networks from imagined trajectories of latent states. The trajectories start at encoded states of previously encountered sequences. The world model then predicts ahead using the selected actions and its learned state prior. The critic is trained using temporal difference learning, and the actor is trained to maximize the value function via REINFORCE and straight-through gradients.
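In DreamerV2 the temporal difference targets for the critic are λ-returns, which blend one-step bootstrapped targets with longer imagined rollouts. The sketch below computes them backwards over a trajectory; the function name, the `(horizon, batch)` tensor layout, and the default `lam=0.95` are assumptions for illustration, not this repository's exact code.

```python
import torch


def lambda_return(reward, value, discount, bootstrap, lam=0.95):
    """Compute TD(lambda) targets backwards over an imagined trajectory.

    reward, value, discount: (horizon, batch) tensors; value[t] = v(s_t).
    bootstrap: (batch,) value estimate for the state after the last step.
    """
    # Value of the next state for every step, ending with the bootstrap value.
    next_value = torch.cat([value[1:], bootstrap[None]], dim=0)
    returns = []
    last = bootstrap
    for t in reversed(range(reward.shape[0])):
        # Mix the one-step target (1 - lam) with the longer return (lam).
        last = reward[t] + discount[t] * ((1 - lam) * next_value[t] + lam * last)
        returns.append(last)
    # Reverse so that index 0 corresponds to the first imagined step.
    return torch.stack(returns[::-1], dim=0)
```

Setting `lam=0` recovers one-step TD targets, while `lam=1` gives discounted Monte Carlo returns bootstrapped with the final value estimate.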

For more information:

Using the Package

The easiest way to run DreamerV2 on new environments is to install the package via pip3 install git+https://github.com/sai-prasanna/dreamerv2_torch.git. The code automatically detects whether the environment uses discrete or continuous actions. Here is a usage example that trains DreamerV2 on the MiniGrid environment:

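import gym
import gym_minigrid
import dreamerv2_torch.api as dv2

config = dv2.defaults.update({
    'logdir': '~/logdir/minigrid',  # where checkpoints and metrics are written
    'log_every': 1e3,               # log metrics every 1,000 environment steps
    'train_every': 10,              # run a training update every 10 environment steps
    'prefill': 1e5,                 # random-policy steps collected before training starts
    'actor_ent': 3e-3,              # entropy regularization scale for the actor
    'loss_scales.kl': 1.0,          # weight of the KL term in the world model loss
    'discount': 0.99,               # discount factor for imagined returns
}).parse_flags()                    # command-line flags can still override these values

# MiniGrid returns symbolic observations by default; this wrapper
# provides partially observed RGB images instead.
env = gym.make('MiniGrid-DoorKey-6x6-v0')
env = gym_minigrid.wrappers.RGBImgPartialObsWrapper(env)
dv2.train(env, config)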

Manual Instructions

To modify the DreamerV2 agent, clone the repository and follow the instructions below. There is also a Dockerfile available, in case you do not want to install the dependencies on your system.

Get dependencies:

pip3 install torch ruamel.yaml 'gym[atari]' dm_control

Train on Atari:

python3 dreamerv2_torch/train.py --logdir ~/logdir/atari_pong/dreamerv2_torch/1 \
  --configs atari --task atari_pong

Train on DM Control:

python3 dreamerv2_torch/train.py --logdir ~/logdir/dmc_walker_walk/dreamerv2_torch/1 \
  --configs dmc_vision --task dmc_walker_walk

Monitor results:

tensorboard --logdir ~/logdir

Generate plots:

python3 common/plot.py --indir ~/logdir --outdir ~/plots \
  --xaxis step --yaxis eval_return --bins 1e6

Docker Instructions

The Dockerfile lets you run DreamerV2 without installing its dependencies on your system. This requires you to have Docker with GPU access set up.

Check your setup:

docker run -it --rm --gpus all tensorflow/tensorflow:2.4.2-gpu nvidia-smi

Train on Atari:

docker build -t dreamerv2 .
docker run -it --rm --gpus all -v ~/logdir:/logdir dreamerv2 \
  python3 dreamerv2_torch/train.py --logdir /logdir/atari_pong/dreamerv2_torch/1 \
    --configs atari --task atari_pong

Train on DM Control:

docker build -t dreamerv2 . --build-arg MUJOCO_KEY="$(cat ~/.mujoco/mjkey.txt)"
docker run -it --rm --gpus all -v ~/logdir:/logdir dreamerv2 \
  python3 dreamerv2_torch/train.py --logdir /logdir/dmc_walker_walk/dreamerv2_torch/1 \
    --configs dmc_vision --task dmc_walker_walk

Tips

  • Efficient debugging. You can use the debug config as in --configs atari debug. This reduces the batch size, increases the evaluation frequency, and disables tf.function graph compilation for easy line-by-line debugging.

  • Infinite gradient norms. This is normal and described under loss scaling in the mixed precision guide. You can disable mixed precision by passing --precision 32 to the training script. Mixed precision is faster but can in principle cause numerical instabilities.

  • Accessing logged metrics. The metrics are stored in both TensorBoard and JSON lines format. You can directly load them using pandas.read_json(). The plotting script also stores the binned and aggregated metrics of multiple runs into a single JSON file for easy manual plotting.

Differences from the official DreamerV2

For dmc_vision tasks, we use rssm: {hidden: 200, deter: 200, stoch: 50, discrete: 0} as the RSSM settings, which makes the stochastic state a continuous normal distribution. The official DreamerV2 config overrides only rssm: {hidden: 200, deter: 200} and leaves the defaults stoch: 32, discrete: 32 untouched, so the stochastic state is a set of 32 one-hot (categorical) distributions with 32 classes each.
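The practical effect of the two settings on the latent size can be sketched as follows. The helper names are hypothetical, and the Gaussian parameterization in `sample_continuous` (softplus standard deviation plus `min_std=0.1`) is an assumption for illustration rather than this repository's exact code.

```python
import torch
import torch.nn.functional as F


def stoch_size(stoch: int, discrete: int) -> int:
    """Size of the flattened stochastic state for a given rssm setting."""
    # discrete > 0: `stoch` one-hot vectors with `discrete` classes each.
    # discrete == 0: a single `stoch`-dimensional continuous vector.
    return stoch * discrete if discrete else stoch


def sample_continuous(stats: torch.Tensor, min_std: float = 0.1) -> torch.Tensor:
    """Sample a diagonal Gaussian latent from concatenated (mean, raw std) stats."""
    mean, raw_std = stats.chunk(2, dim=-1)
    std = F.softplus(raw_std) + min_std  # keep the standard deviation positive
    # rsample keeps the sample differentiable (reparameterization trick).
    return torch.distributions.Normal(mean, std).rsample()


# Model features concatenate the deterministic and stochastic parts.
this_repo = 200 + stoch_size(stoch=50, discrete=0)   # 250
official = 200 + stoch_size(stoch=32, discrete=32)   # 1224
```

With discrete: 0, gradients flow through the reparameterized Gaussian sample instead of the straight-through estimator used for the categorical latents.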
