Pong results do not match paper #138

Open
George614 opened this issue Jun 5, 2024 · 5 comments
Comments

@George614

Hi Danijar,

Thanks for sharing this amazing repo and creating a robust model-based RL algorithm! I've been playing with the replay buffer and trying to reproduce some of the results. I ran the code on Pong with the command python dreamerv3/main.py --logdir ./logdir/uniform_pong --configs atari --task atari_pong --run.train_ratio 32, using the default configuration, on Ubuntu 22.04 LTS with an RTX 3090 GPU. Somehow, the agent does not learn the Pong task within 400K env steps (the budget reported in the first version of the paper), and I'm not sure what went wrong. I've tried the default uniform replay (cyan curve in the figure), a mixed replay (gray curve) with a ratio of (0.5, 0.3, 0.2), and uniform replay with compute_dtype: float16 (magenta curve), since I had seen some warnings from CUDA and XLA.
[Screenshot: training curves for the three replay configurations]

Here are the package versions that I installed:

python 3.11.9
jax 0.4.28
jax-cuda12-pjrt 0.4.28
jax-cuda12-plugin 0.4.28
jaxlib 0.4.28+cuda12.cudnn89
ale-py 0.8.1
gymnasium 0.29.1
tensorflow-cpu 2.16.1
tensorflow-probability 0.24.0

Please let me know if anything was not set up properly. Thank you!

@IcarusWizard

As far as I know, run.train_ratio should be 1024 for the Atari100k benchmark.
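For example, keeping everything else from your original command the same:

python dreamerv3/main.py --logdir ./logdir/uniform_pong --configs atari --task atari_pong --run.train_ratio 1024

I believe the repo's configs.yaml also ships an atari100k preset, which may be worth checking against.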

@NonsansWD

Hey,
First of all, I think the first comment on this is right: you should increase the train_ratio. That was confusing for me at first too, but it should solve the issue. A quick off-topic question though: I see you are running pretty recent versions of tensorflow-cpu as well as JAX. Did you run into any issues where the pip installation stated that jax requires ml_dtypes >= 4.0 and tensorflow requires that library to be version 3.2?

@George614
Author

> As far as I know, run.train_ratio should be 1024 for the Atari100k benchmark.

Thanks @IcarusWizard and @NonsansWD, I'll try your suggestions!

@George614
Author

> A quick off-topic question though: Did you run into any issues where the pip installation stated that jax requires ml_dtypes >= 4.0 and tensorflow requires that library to be version 3.2?

I have not run into that particular issue. I'd suggest installing tensorflow-cpu first (maybe a less recent version) and then installing JAX.
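For example (the version pins here are hypothetical, just to illustrate the order; pick ones that match your CUDA setup):

pip install tensorflow-cpu==2.15.1
pip install -U "jax[cuda12]"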

@NonsansWD

> I have not run into that particular issue. I'd suggest installing tensorflow-cpu first (maybe a less recent version) and then installing JAX.

Alright, good to know. In the end I was able to fix my issue and everything works fine. The only problem I'm left with: I just realized that the resulting folder called "replay" does not contain raw frames but instead a lot of data like rewards and so on. Do you by any chance know a way of obtaining a video of the agent's steps, so I can watch it play without too much effort? I feel like I'm missing something, because I also don't know where to get these wonderful score plots. Or do I have to construct that plot myself with matplotlib? Sorry for going off topic.
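In case it helps: a minimal matplotlib sketch for building the score plot yourself, assuming your logdir contains a metrics.jsonl file with 'step' and 'episode/score' entries (that's an assumption based on my run; open the file and check which keys your version actually logs):

import json
import pathlib

import matplotlib.pyplot as plt

logdir = pathlib.Path('./logdir/uniform_pong')  # adjust to your run
steps, scores = [], []
for line in (logdir / 'metrics.jsonl').read_text().splitlines():
    record = json.loads(line)
    if 'episode/score' in record:  # only some lines log a finished episode
        steps.append(record['step'])
        scores.append(record['episode/score'])

plt.plot(steps, scores)
plt.xlabel('env steps')
plt.ylabel('episode score')
plt.title('atari_pong')
plt.savefig('scores.png')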
