Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions torchrl/envs/model_based/dreamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@


class DreamerEnv(ModelBasedEnvBase):
"""Dreamer simulation environment."""
"""Dreamer simulation environment.

This environment is used for imagination rollouts in Dreamer training.
It never terminates (done is always False) since imagination runs for a
fixed horizon. The done-checking methods are overridden to avoid CUDA
synchronization overhead from Python control flow on CUDA tensors.
"""

def __init__(
self,
Expand All @@ -26,11 +32,31 @@ def __init__(
device: DEVICE_TYPING = "cpu",
batch_size: torch.Size | None = None,
):
super().__init__(world_model, device=device, batch_size=batch_size)
super().__init__(
world_model,
device=device,
batch_size=batch_size,
# Skip done validation in reset() — imagination never terminates.
allow_done_after_reset=True,
)
self.obs_decoder = obs_decoder
self.prior_shape = prior_shape
self.belief_shape = belief_shape

def any_done(self, tensordict) -> bool:
    """Report that nothing in the batch is done.

    Imagination rollouts run for a fixed horizon and never terminate, so
    the answer is a constant. This override is intended to avoid the CUDA
    synchronization that a ``done.any()`` check in the parent class would
    trigger.

    Args:
        tensordict: The current rollout data. It is accepted for interface
            compatibility and is neither read nor modified here.

    Returns:
        Always ``False``.
    """
    # Constant result: no tensor is inspected on this path.
    return False

def maybe_reset(self, tensordict):
    """Hand the input back untouched; no partial reset is performed.

    Imagination rollouts do not need partial resets, so this override is a
    no-op. It is intended to avoid the CUDA synchronization caused by the
    done checks in the parent class.

    Args:
        tensordict: The current rollout data.

    Returns:
        The very same ``tensordict`` object that was passed in.
    """
    # Identity pass-through: nothing is inspected, copied, or reset.
    return tensordict

def set_specs_from_env(self, env: EnvBase):
"""Sets the specs of the environment from the specs of the given environment."""
super().set_specs_from_env(env)
Expand Down
Loading