diff --git a/test/_utils_internal.py b/test/_utils_internal.py
index 5efa0592068..19deb33546f 100644
--- a/test/_utils_internal.py
+++ b/test/_utils_internal.py
@@ -22,7 +22,6 @@
 from torch import nn, vmap
 from torchrl._utils import implement_for, logger, RL_WARNINGS, seed_generator
-from torchrl.data.utils import CloudpickleWrapper
 from torchrl.envs import MultiThreadedEnv, ObservationNorm
 from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
 from torchrl.envs.libs.envpool import _has_envpool
 
@@ -594,14 +593,6 @@ def check_rollout_consistency_multikey_env(td: TensorDict, max_steps: int):
     assert (td["next", "nested_2", "reward"][~action_is_count] == 0).all()
 
 
-def decorate_thread_sub_func(func, num_threads):
-    def new_func(*args, **kwargs):
-        assert torch.get_num_threads() == num_threads
-        return func(*args, **kwargs)
-
-    return CloudpickleWrapper(new_func)
-
-
 class LSTMNet(nn.Module):
     """An embedder for an LSTM preceded by an MLP.
 
diff --git a/test/test_collectors.py b/test/test_collectors.py
index 1c52fb37615..86a5aa25848 100644
--- a/test/test_collectors.py
+++ b/test/test_collectors.py
@@ -89,6 +89,7 @@
     SafeModule,
 )
 from torchrl.testing.modules import BiasModule, NonSerializableBiasModule
+from torchrl.testing.mp_helpers import decorate_thread_sub_func
 from torchrl.weight_update import (
     MultiProcessWeightSyncScheme,
     SharedMemWeightSyncScheme,
@@ -99,7 +100,6 @@
     from pytorch.rl.test._utils_internal import (
         CARTPOLE_VERSIONED,
         check_rollout_consistency_multikey_env,
-        decorate_thread_sub_func,
         generate_seeds,
         get_available_devices,
         get_default_devices,
@@ -112,7 +112,6 @@
     from _utils_internal import (
         CARTPOLE_VERSIONED,
         check_rollout_consistency_multikey_env,
-        decorate_thread_sub_func,
         generate_seeds,
         get_available_devices,
         get_default_devices,
diff --git a/torchrl/testing/mp_helpers.py b/torchrl/testing/mp_helpers.py
new file mode 100644
index 00000000000..ffc1b3c5a20
--- /dev/null
+++ b/torchrl/testing/mp_helpers.py
@@ -0,0 +1,15 @@
+from __future__ import annotations
+
+import torch
+
+from torchrl.data.utils import CloudpickleWrapper
+
+
+def decorate_thread_sub_func(func, num_threads):
+    """Decorate a function to assert that the number of threads is correct."""
+
+    def new_func(*args, **kwargs):
+        assert torch.get_num_threads() == num_threads
+        return func(*args, **kwargs)
+
+    return CloudpickleWrapper(new_func)