diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index bd2a698ac4e..0e495be07d7 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1477,6 +1477,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda): env_fun = self.create_env_fn[idx] if not isinstance(env_fun, (EnvCreator, CloudpickleWrapper)): env_fun = CloudpickleWrapper(env_fun) + kwargs[idx].update( { "parent_pipe": parent_pipe, @@ -1486,7 +1487,7 @@ def look_for_cuda(tensor, has_cuda=has_cuda): "has_lazy_inputs": self.has_lazy_inputs, "num_threads": num_sub_threads, "non_blocking": self.non_blocking, - "filter_warnings": torchrl.filter_warnings_subprocess, + "filter_warnings": self._filter_warnings_subprocess(), } ) if self._use_buffers: @@ -1522,6 +1523,11 @@ def look_for_cuda(tensor, has_cuda=has_cuda): self.is_closed = False self.set_spec_lock_() + def _filter_warnings_subprocess(self) -> bool: + from torchrl import filter_warnings_subprocess + + return filter_warnings_subprocess + @_check_start def state_dict(self) -> OrderedDict: state_dict = OrderedDict()