diff --git a/torchrl/envs/env_creator.py b/torchrl/envs/env_creator.py index cfe2d36101c..caa2d759a8b 100644 --- a/torchrl/envs/env_creator.py +++ b/torchrl/envs/env_creator.py @@ -8,6 +8,7 @@ from collections import OrderedDict from collections.abc import Callable from multiprocessing.sharedctypes import Synchronized +from multiprocessing.synchronize import Lock, RLock import torch from tensordict import TensorDictBase @@ -141,8 +142,12 @@ def meta_data(self, value: EnvMetaData): @staticmethod def _is_mp_value(val): - - return isinstance(val, (Synchronized,)) and hasattr(val, "_obj") + if isinstance(val, (Synchronized,)) and hasattr(val, "_obj"): + return True + # Also check for lock types which need to be shared across processes + if isinstance(val, (Lock, RLock)): + return True + return False @classmethod def _find_mp_values(cls, env_or_transform, values, prefix=()):