Description
Describe the bug
The issue is simple: I am unable to register save/load hooks on replay buffers, even though the implementation looks straightforward. It seems like an obvious problem, but I couldn't find it mentioned anywhere, despite a clear discrepancy between the documentation and the implementation of hooks.
To Reproduce
The error occurs even with the example code provided in the TED2Flat documentation:
```python
import tempfile

from tensordict import TensorDict
from torchrl.collectors import SyncDataCollector
from torchrl.data import ReplayBuffer, TED2Flat, LazyMemmapStorage
from torchrl.envs import GymEnv
import torch


def main():
    env = GymEnv("CartPole-v1")
    env.set_seed(0)
    torch.manual_seed(0)
    collector = SyncDataCollector(env, policy=env.rand_step, total_frames=200, frames_per_batch=200)
    rb = ReplayBuffer(storage=LazyMemmapStorage(200))
    rb.register_save_hook(TED2Flat())
    with tempfile.TemporaryDirectory() as tmpdir:
        for i, data in enumerate(collector):
            rb.extend(data)
            rb.dumps(tmpdir)
        # load the data to represent it
        td = TensorDict.load(tmpdir + "/storage/")
        print(td)


if __name__ == "__main__":
    main()
```

```
Traceback (most recent call last):
  File "...", line 25, in <module>
    main()
  File "...", line 14, in main
    rb.register_save_hook(TED2Flat())
  File "...", line 702, in register_save_hook
    self._storage.register_save_hook(hook)
AttributeError: 'LazyMemmapStorage' object has no attribute 'register_save_hook'
```

Expected behavior
I would expect the hook to be registered. A quick look at the source code shows that ReplayBuffer delegates hook registration to its storage:

```python
def register_save_hook(self, hook: Callable[[Any], Any]):
    """Registers a save hook for the storage.

    .. note:: Hooks are currently not serialized when saving a replay buffer: they must
        be manually re-initialized every time the buffer is created.

    """
    self._storage.register_save_hook(hook)
```

However, I could not find any hook-related method (such as `register_save_hook`) in any Storage implementation, unless it is implemented in some other way that I missed.
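To isolate the failure from the collector and Gym dependencies, the delegation pattern quoted above can be modeled in plain Python. The classes below (`HypotheticalStorage`, `SketchReplayBuffer`) are hypothetical stand-ins, not torchrl classes; the only assumption is that the storage, like the report describes, defines no `register_save_hook` method.

```python
from typing import Any, Callable


class HypotheticalStorage:
    """Hypothetical stand-in for a storage class that has no hook API."""

    def __init__(self, max_size: int) -> None:
        self.max_size = max_size


class SketchReplayBuffer:
    """Hypothetical stand-in mirroring the delegation in ReplayBuffer.register_save_hook."""

    def __init__(self, storage: Any) -> None:
        self._storage = storage

    def register_save_hook(self, hook: Callable[[Any], Any]) -> None:
        # Same one-line delegation as the quoted method: the buffer assumes
        # its storage knows how to register hooks.
        self._storage.register_save_hook(hook)


def reproduce() -> str:
    """Try to register a no-op hook and return the resulting error message."""
    rb = SketchReplayBuffer(HypotheticalStorage(200))
    try:
        rb.register_save_hook(lambda data: data)
    except AttributeError as err:
        return str(err)
    return "no error"


if __name__ == "__main__":
    print(reproduce())
```

Because the buffer forwards the call without checking, the error surfaces as an `AttributeError` naming the storage class, which matches the shape of the traceback in the report.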
Screenshots
N/A
System info
Everything runs in a Conda environment on an Ubuntu 20.04 machine, with dependencies installed via pip inside that environment.
Python version: 3.10
torchrl version reported by pip: 0.10.1 (note that `torchrl.__version__` returns 0.0.0+unknown)
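Since the module attribute and pip disagree, the version recorded in the installed distribution's metadata can be read directly with the standard-library `importlib.metadata`, independently of what the imported module reports. The helper name `installed_version` is hypothetical; this is only a sketch for cross-checking an environment.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the version from the installed distribution's metadata, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Compare this value with what `torchrl.__version__` reports after import.
    print("torchrl (metadata):", installed_version("torchrl"))
```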
Reason and Possible fixes
Again, unless Storages handle hook registration in some non-obvious way that I missed, no Storage implementation provides `register_save_hook`, so the delegation in `ReplayBuffer.register_save_hook` fails with the `AttributeError` shown above.
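One possible direction is to have the storage own its hook list and apply the hooks when it is dumped. The sketch below uses hypothetical classes (`HookableStorage`, `PatchedReplayBuffer`), not torchrl's real Storage API, and assumes a save hook is simply a callable applied to the data just before it is written; torchrl's actual hook signature and dump path may differ.

```python
from typing import Any, Callable, List


class HookableStorage:
    """Hypothetical storage that keeps its own save hooks (not torchrl's class)."""

    def __init__(self, max_size: int) -> None:
        self.max_size = max_size
        self._data: List[Any] = []
        self._save_hooks: List[Callable[[Any], Any]] = []

    def register_save_hook(self, hook: Callable[[Any], Any]) -> None:
        self._save_hooks.append(hook)

    def extend(self, items: List[Any]) -> None:
        # Keep at most max_size items, dropping the oldest ones first.
        combined = self._data + list(items)
        start = max(0, len(combined) - self.max_size)
        self._data = combined[start:]

    def dumps(self) -> Any:
        # Apply every registered hook, in registration order, to the data
        # that would be written; here the result is returned instead of saved.
        data: Any = list(self._data)
        for hook in self._save_hooks:
            data = hook(data)
        return data


class PatchedReplayBuffer:
    """Hypothetical buffer whose delegation now succeeds."""

    def __init__(self, storage: HookableStorage) -> None:
        self._storage = storage

    def register_save_hook(self, hook: Callable[[Any], Any]) -> None:
        # Same delegation as before; it works because the storage defines the method.
        self._storage.register_save_hook(hook)

    def extend(self, items: List[Any]) -> None:
        self._storage.extend(items)

    def dumps(self) -> Any:
        return self._storage.dumps()


if __name__ == "__main__":
    rb = PatchedReplayBuffer(HookableStorage(3))
    rb.register_save_hook(lambda d: [x * 10 for x in d])
    rb.extend([1, 2, 3, 4])
    print(rb.dumps())
```

Putting the hook list on a shared storage base class would let every storage type support the delegation uniformly, rather than relying on each subclass to implement it.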
Checklist
- I have checked that there is no similar issue in the repo (required) (somehow)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)