diff --git a/test/test_envs.py b/test/test_envs.py index 7c9a619e7d5..92b3f67d11f 100644 --- a/test/test_envs.py +++ b/test/test_envs.py @@ -179,7 +179,6 @@ def check_no_lingering_multiprocessing_resources(request): _make_envs, CARTPOLE_VERSIONED, check_rollout_consistency_multikey_env, - decorate_thread_sub_func, get_default_devices, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, @@ -191,7 +190,6 @@ def check_no_lingering_multiprocessing_resources(request): _make_envs, CARTPOLE_VERSIONED, check_rollout_consistency_multikey_env, - decorate_thread_sub_func, get_default_devices, HALFCHEETAH_VERSIONED, PENDULUM_VERSIONED, @@ -3405,25 +3403,23 @@ class TestLibThreading: ) def test_num_threads(self): gc.collect() - from torchrl.envs import batched_envs - - _run_worker_pipe_shared_mem_save = batched_envs._run_worker_pipe_shared_mem - batched_envs._run_worker_pipe_shared_mem = decorate_thread_sub_func( - batched_envs._run_worker_pipe_shared_mem, num_threads=3 - ) num_threads = torch.get_num_threads() try: - env = ParallelEnv( - 2, ContinuousActionVecMockEnv, num_sub_threads=3, num_threads=7 - ) + # Wrap the env factory to check thread count inside the subprocess. + # The env is created AFTER torch.set_num_threads() is called in the worker. + def make_env(): + assert ( + torch.get_num_threads() == 3 + ), f"Expected 3 threads, got {torch.get_num_threads()}" + return ContinuousActionVecMockEnv() + + env = ParallelEnv(2, make_env, num_sub_threads=3, num_threads=7) # We could test that the number of threads isn't changed until we start the procs. # Even though it's unlikely that we have 7 threads, we still disable this for safety # assert torch.get_num_threads() != 7 env.rollout(3) assert torch.get_num_threads() == 7 finally: - # reset vals - batched_envs._run_worker_pipe_shared_mem = _run_worker_pipe_shared_mem_save torch.set_num_threads(num_threads) @pytest.mark.skipif(