diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index 171ea068428..72b6695eb2f 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -117,11 +117,13 @@ def __init__( device: DEVICE_TYPING = "cpu", batch_size: torch.Size | None = None, run_type_checks: bool = False, + allow_done_after_reset: bool = False, ): super().__init__( device=device, batch_size=batch_size, run_type_checks=run_type_checks, + allow_done_after_reset=allow_done_after_reset, ) self.world_model = world_model.to(self.device) self.world_model_params = params