diff --git a/test/test_envs.py b/test/test_envs.py index e7a4337229c..704b389e133 100644 --- a/test/test_envs.py +++ b/test/test_envs.py @@ -3418,13 +3418,18 @@ class TestLibThreading: def test_num_threads(self): gc.collect() num_threads = torch.get_num_threads() + main_pid = os.getpid() try: # Wrap the env factory to check thread count inside the subprocess. # The env is created AFTER torch.set_num_threads() is called in the worker. + # Note: the factory is also called in the main process to get metadata, + # so we only check thread count when running in a subprocess. def make_env(): - assert ( - torch.get_num_threads() == 3 - ), f"Expected 3 threads, got {torch.get_num_threads()}" + if os.getpid() != main_pid: + # Only check thread count in subprocess, not during metadata extraction + assert ( + torch.get_num_threads() == 3 + ), f"Expected 3 threads, got {torch.get_num_threads()}" return ContinuousActionVecMockEnv() env = ParallelEnv(2, make_env, num_sub_threads=3, num_threads=7)