Reimplementation in RL4CO #58

Open
fedebotu opened this issue Oct 3, 2023 · 5 comments



fedebotu commented Oct 3, 2023

Hi there 👋🏼

First of all, thanks a lot for your library; it has inspired several works in our research group!
We are actively developing RL4CO, a library for all things Reinforcement Learning for Combinatorial Optimization. We started the library by modularizing the Attention Model, which is the basis for several other autoregressive models. We also used some recent software (such as TorchRL, TensorDict, PyTorch Lightning and Hydra) as well as routines such as FlashAttention, and made everything as easy to use as possible in the hope of helping practitioners and researchers.
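For a feel of the API, a training run looks roughly like the minimal sketch below; the exact module paths, class names, and arguments here are indicative and may differ between RL4CO versions:

```python
import lightning as L
from rl4co.envs import TSPEnv
from rl4co.models import AttentionModel

# TSP with 50 nodes; the model bundles the policy, the baseline, and the RL training logic
env = TSPEnv(num_loc=50)
model = AttentionModel(env, baseline="rollout")

# The model is a PyTorch Lightning module, so a standard Lightning Trainer can train it
trainer = L.Trainer(max_epochs=100, accelerator="gpu", devices=1)
trainer.fit(model)
```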

We welcome you to check RL4CO out ^^

wouterkool (Owner) commented

Hi! Thanks for bringing it to my attention; I will definitely check it out! Are you able to reproduce the results from the paper with your implementation (training time, evaluation time, and performance)?


fedebotu commented Oct 5, 2023

Thanks for your quick answer 🚀

  • In terms of performance: yes! While reimplementing the AM, we carefully checked step by step that everything was working; in exactly the same settings as the original code (same parameters, learning rate, batch sizes, validation datasets...) we actually obtained slightly better performance:
    [two screenshots of training curves comparing the RL4CO reimplementation with the original AM]
    The only difference is that the MHA module has its linear bias set to true; the bump in performance over the last 20 epochs comes from a step of the MultiStepLR scheduler.
  • In terms of training time: this depends on the implementation. To be honest, your code was already really well optimized, even though it has been around for a few years! In theory we should be similar or slightly faster, since we use FlashAttention in the encoder. In practice, however, we found that we are slightly slower with the same hyperparameters: on a single RTX 3090 we can train your original AM for TSP50 in 14 hours and ours in 15 hours. The reason is that we use TensorDicts and TorchRL environments; we found the "culprits" to be data loading and, most importantly, the need to re-create TensorDicts with large batch sizes on the fly at each step, which has a pretty big impact on performance (see the first sketch after this list). One simple way to solve this is to refactor the environments in pure PyTorch. Right now we are in contact with the TorchRL team and there have been several updates (just a few hours ago, the stable TensorDict 0.2.0 was released), so we plan to solve this in the next couple of weeks ;) There is another trick that noticeably improves speed: setting mask_inner to False enables the FlashAttention routine during decoding (i.e., at every step), and the training above then takes 12.5 hours, even with the current TensorDict slowdown (see the second sketch after this list). This does, of course, degrade solution quality somewhat, but FlashAttention with masking may be added in the near future, so it holds good promise!
  • In terms of evaluation time: we found our evaluation script to be slightly faster; however, a single forward pass of the model shows a similar trend to the above, so we think there may have been some bottleneck in the eval script.
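To make the two training-time points above concrete, here are two toy sketches (illustrative only, not the actual RL4CO or TorchRL code). The first contrasts re-creating a large-batch TensorDict at every decoding step with updating a single TensorDict in place:

```python
import torch
from tensordict import TensorDict

batch_size, num_loc = 2048, 50
locs = torch.rand(batch_size, num_loc, 2)

# Costly pattern: build a brand-new TensorDict at every decoding step,
# so keys, masks, and metadata are re-allocated num_loc times per rollout.
for step in range(num_loc):
    td = TensorDict(
        {"locs": locs, "action_mask": torch.ones(batch_size, num_loc, dtype=torch.bool)},
        batch_size=[batch_size],
    )

# Cheaper pattern: create the TensorDict once and only overwrite the entries
# that actually change at each step.
td = TensorDict(
    {"locs": locs, "action_mask": torch.ones(batch_size, num_loc, dtype=torch.bool)},
    batch_size=[batch_size],
)
for step in range(num_loc):
    new_mask = td["action_mask"].clone()
    new_mask[:, step] = False     # e.g. mark the node visited at this step
    td["action_mask"] = new_mask  # overwrite an existing key; no new TensorDict is built
```

The second sketch shows the mechanism behind the mask_inner observation: PyTorch's scaled_dot_product_attention can dispatch to the fused FlashAttention kernel when no arbitrary attention mask is passed, while a per-step mask over visited nodes forces a slower fallback (the snippet assumes a CUDA GPU):

```python
import torch
import torch.nn.functional as F

# Toy decoder-style attention: one query (the context) attending over 50 node embeddings
q = torch.randn(512, 8, 1, 64, device="cuda", dtype=torch.float16)
k = torch.randn(512, 8, 50, 64, device="cuda", dtype=torch.float16)
v = torch.randn(512, 8, 50, 64, device="cuda", dtype=torch.float16)

# No mask: eligible for the fused FlashAttention kernel
out_fast = F.scaled_dot_product_attention(q, k, v)

# Arbitrary boolean mask over visited nodes (what inner masking needs):
# the fused kernel is not eligible, so PyTorch falls back to a slower backend.
visited = torch.zeros(512, 1, 1, 50, dtype=torch.bool, device="cuda")
visited[..., :10] = True  # pretend the first 10 nodes were already visited
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=~visited)
```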

We would be more than happy to address your feedback if you check out RL4CO, you may contact us any time 😄

fedebotu (Author) commented

(Late) edit: now we are way more efficient as explained here!

wouterkool (Owner) commented

Great! I have added a link in the readme. However, I wonder if you have also had a look at https://github.com/cpwan/RLOR; they claim an 8x speedup over this repo using PPO.

fedebotu (Author) commented

Yes, we are aware of it!
From our understanding and our testing, their speedup is actually measured as the time to reach a certain cost, as seen in their Table 4: the AM trains on TSP50 to a cost of 5.80 in 24 hours, while their PPO implementation, with some training and testing tricks, takes 3 hours.
So it is not a speedup per se (in fact, since their environment is in NumPy, even though vectorized, data collection is naturally a bottleneck) but rather the time to reach a target performance. Besides, the comparison is between their AM trained with PPO using a larger batch size and learning rate and tested with the "multi-greedy" decoding scheme at inference (what in RL4CO we call multistart, i.e., the POMO decoding scheme that starts decoding from all nodes and keeps the best trajectory, sketched below), and the baseline AM evaluated with plain one-shot greedy decoding.
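For reference, multistart decoding can be sketched as follows; `greedy_rollout` here is a hypothetical helper that greedily decodes a batch from a given start node and returns the tour costs, not an actual function from either repo:

```python
import torch

def multistart_greedy(locs: torch.Tensor, greedy_rollout) -> torch.Tensor:
    """POMO-style multistart: decode greedily once per possible start node
    and keep the best (lowest-cost) tour for each instance in the batch.

    locs: (batch, num_nodes, 2) node coordinates
    greedy_rollout(locs, start) -> (batch,) tour costs  (hypothetical helper)
    """
    num_nodes = locs.size(1)
    costs = torch.stack(
        [greedy_rollout(locs, start) for start in range(num_nodes)], dim=1
    )  # (batch, num_nodes)
    return costs.min(dim=1).values  # best trajectory per instance
```

A one-shot greedy baseline corresponds to a single such rollout, so comparing multistart decoding against it mixes inference budget into the training-time comparison.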
For these reasons, we think the 8x speedup claim is a bit overstated 👀
