Patch for torchrl/envs/model_based/dreamer.py (class DreamerEnv).

Summary of what the hunks below do:
  * Hunk 1 replaces the one-line DreamerEnv class docstring with a longer
    one describing its use for imagination rollouts.
  * Hunk 2 rewrites the super().__init__() call in multi-line form and adds
    the keyword argument allow_done_after_reset=True.
  * Hunk 2 also adds two methods: any_done(), which returns False without
    reading its argument, and maybe_reset(), which returns its argument
    unchanged.

NOTE(review): the line breaks and indentation of this patch were restored
from a copy whose whitespace had been collapsed; the hunk line counts match
the headers, but confirm the indentation against the target file before
applying.

diff --git a/torchrl/envs/model_based/dreamer.py b/torchrl/envs/model_based/dreamer.py
index 9228c39aa66..eaeac5dadfd 100644
--- a/torchrl/envs/model_based/dreamer.py
+++ b/torchrl/envs/model_based/dreamer.py
@@ -15,7 +15,13 @@
 
 
 class DreamerEnv(ModelBasedEnvBase):
-    """Dreamer simulation environment."""
+    """Dreamer simulation environment.
+
+    This environment is used for imagination rollouts in Dreamer training.
+    It never terminates (done is always False) since imagination runs for a
+    fixed horizon. The done-checking methods are overridden to avoid CUDA
+    synchronization overhead from Python control flow on CUDA tensors.
+    """
 
     def __init__(
         self,
@@ -26,11 +32,31 @@ def __init__(
         device: DEVICE_TYPING = "cpu",
         batch_size: torch.Size | None = None,
     ):
-        super().__init__(world_model, device=device, batch_size=batch_size)
+        super().__init__(
+            world_model,
+            device=device,
+            batch_size=batch_size,
+            # Skip done validation in reset() — imagination never terminates.
+            allow_done_after_reset=True,
+        )
         self.obs_decoder = obs_decoder
         self.prior_shape = prior_shape
         self.belief_shape = belief_shape
 
+    def any_done(self, tensordict) -> bool:
+        """Returns False — imagination rollouts never terminate.
+
+        Overridden to avoid CUDA sync from `done.any()` in parent class.
+        """
+        return False
+
+    def maybe_reset(self, tensordict):
+        """No-op — imagination rollouts don't need partial resets.
+
+        Overridden to avoid CUDA sync from done checks in parent class.
+        """
+        return tensordict
+
     def set_specs_from_env(self, env: EnvBase):
         """Sets the specs of the environment from the specs of the given environment."""
         super().set_specs_from_env(env)