     _check_for_faulty_process,
     _make_ordinal_device,
     logger as torchrl_logger,
+    rl_warnings,
     VERBOSE,
 )
 from torchrl.data.tensor_specs import Composite, NonTensor
@@ -152,29 +153,50 @@ def __call__(cls, *args, **kwargs):
         # Wrap lambda functions with EnvCreator so they can be pickled for
         # multiprocessing with the spawn start method. Lambda functions cannot
         # be serialized with standard pickle, but EnvCreator uses cloudpickle.
+        auto_wrap_envs = kwargs.pop("auto_wrap_envs", True)
+
+        def _warn_lambda():
+            if rl_warnings():
+                warnings.warn(
+                    "A lambda function was passed to ParallelEnv and will be wrapped "
+                    "in an EnvCreator. This causes the environment to be instantiated "
+                    "in the main process to extract metadata. Consider using "
+                    "functools.partial instead, which is natively serializable and "
+                    "avoids this overhead. To suppress this warning, set the "
+                    "RL_WARNINGS=0 environment variable.",
+                    category=UserWarning,
+                    stacklevel=4,
+                )
+
         def _wrap_lambdas(create_env_fn):
             if callable(create_env_fn) and _is_unpicklable_lambda(create_env_fn):
+                _warn_lambda()
                 return EnvCreator(create_env_fn)
             if isinstance(create_env_fn, Sequence):
                 # Reuse EnvCreator for identical function objects to preserve
                 # _single_task detection (e.g., when [lambda_fn] * 3 is passed)
                 wrapped = {}
                 result = []
+                warned = False
                 for fn in create_env_fn:
                     if _is_unpicklable_lambda(fn):
                         fn_id = id(fn)
                         if fn_id not in wrapped:
+                            if not warned:
+                                _warn_lambda()
+                                warned = True
                             wrapped[fn_id] = EnvCreator(fn)
                         result.append(wrapped[fn_id])
                     else:
                         result.append(fn)
                 return result
             return create_env_fn

-        if "create_env_fn" in kwargs:
-            kwargs["create_env_fn"] = _wrap_lambdas(kwargs["create_env_fn"])
-        elif len(args) >= 2:
-            args = (args[0], _wrap_lambdas(args[1])) + args[2:]
+        if auto_wrap_envs:
+            if "create_env_fn" in kwargs:
+                kwargs["create_env_fn"] = _wrap_lambdas(kwargs["create_env_fn"])
+            elif len(args) >= 2:
+                args = (args[0], _wrap_lambdas(args[1])) + args[2:]

         return super().__call__(*args, **kwargs)

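A minimal sketch of the resulting behavior from the caller's side. The environment (GymEnv with "Pendulum-v1") and the worker counts are illustrative choices, not part of this change:

    import functools

    from torchrl.envs import GymEnv, ParallelEnv

    if __name__ == "__main__":  # guard required with the spawn start method
        # A lambda cannot be pickled by standard pickle, so it is auto-wrapped in
        # an EnvCreator (cloudpickle-based): one environment is instantiated in
        # the main process to extract metadata, and a UserWarning is emitted once
        # (set RL_WARNINGS=0 to silence it).
        env_from_lambda = ParallelEnv(2, lambda: GymEnv("Pendulum-v1"))

        # A functools.partial pickles natively, so it is passed through
        # unchanged: no EnvCreator wrapping and no warning.
        env_from_partial = ParallelEnv(2, functools.partial(GymEnv, "Pendulum-v1"))
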
@@ -241,6 +263,20 @@ class BatchedEnvBase(EnvBase):
         daemon (bool, optional): whether the processes should be daemonized.
             This is only applicable to parallel environments such as :class:`~torchrl.envs.ParallelEnv`.
             Defaults to ``False``.
+        auto_wrap_envs (bool, optional): if ``True`` (default), lambda functions passed as
+            ``create_env_fn`` are automatically wrapped in an :class:`~torchrl.envs.EnvCreator`
+            so that they can be pickled for multiprocessing with the ``spawn`` start method.
+            This wrapping causes the environment to be instantiated once in the main process
+            (to extract metadata) before the workers are started.
+            If this is undesirable, set ``auto_wrap_envs=False`` and make sure the callables
+            you pass are serializable (e.g., use :func:`functools.partial` instead of lambdas).
+            This parameter only affects :class:`~torchrl.envs.ParallelEnv`.
+            Defaults to ``True``.
+
+            .. note::
+                For :class:`~torchrl.envs.ParallelEnv`, it is recommended to use :func:`functools.partial`
+                instead of lambda functions when possible, as ``partial`` objects are natively serializable
+                and avoid the overhead of :class:`~torchrl.envs.EnvCreator` wrapping.

     .. note::
         One can pass keyword arguments to each sub-environments using the following

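A hedged sketch of opting out of the automatic wrapping via the auto_wrap_envs flag documented above; the explicit EnvCreator wrapping and the environment name are illustrative assumptions:

    from torchrl.envs import EnvCreator, GymEnv, ParallelEnv

    if __name__ == "__main__":
        # Wrap the factory explicitly (or pass any picklable callable) and
        # disable auto-wrapping so ParallelEnv uses the callable as-is.
        # Note that EnvCreator still instantiates one environment in the main
        # process to extract metadata.
        make_env = EnvCreator(lambda: GymEnv("Pendulum-v1"))
        env = ParallelEnv(4, make_env, auto_wrap_envs=False)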