diff --git a/test/test_transforms.py b/test/test_transforms.py index 32030687ea3..d2b4b8e3b61 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -342,7 +342,7 @@ def test_trans_serial_env_check(self): def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - maybe_fork_ParallelEnv(2, lambda: ContinuousActionVecMockEnv()), + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), BinarizeReward(), ) try: @@ -842,7 +842,7 @@ def test_trans_serial_env_check(self): def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - maybe_fork_ParallelEnv(2, lambda: ContinuousActionVecMockEnv()), + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), CatFrames(dim=-1, N=3, in_keys=["observation"]), ) try: @@ -1481,7 +1481,7 @@ def test_trans_parallel_env_check(self, model, device): tensor_pixels_keys=tensor_pixels_key, ) transformed_env = TransformedEnv( - ParallelEnv(2, lambda: DiscreteActionConvMockEnvNumpy().to(device)), r3m + ParallelEnv(2, partial(DiscreteActionConvMockEnvNumpy, device=device)), r3m ) try: check_env_specs(transformed_env) @@ -1639,7 +1639,9 @@ def test_r3m_parallel(self, model, device): out_keys=out_keys, tensor_pixels_keys=tensor_pixels_key, ) - base_env = ParallelEnv(4, lambda: DiscreteActionConvMockEnvNumpy().to(device)) + base_env = ParallelEnv( + 4, partial(DiscreteActionConvMockEnvNumpy, device=device) + ) transformed_env = TransformedEnv(base_env, r3m) td = transformed_env.reset() assert td.device == device @@ -8503,7 +8505,7 @@ def test_trans_serial_env_check(self): def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - maybe_fork_ParallelEnv(2, lambda: ContinuousActionVecMockEnv()), + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), TimeMaxPool( in_keys=["observation"], T=3, @@ -8830,7 +8832,7 @@ def test_trans_parallel_env_check(self, model, device): tensor_pixels_keys=tensor_pixels_key, ) transformed_env = TransformedEnv( - ParallelEnv(2, lambda: DiscreteActionConvMockEnvNumpy().to(device)), vip + ParallelEnv(2, partial(DiscreteActionConvMockEnvNumpy, device=device)), vip ) try: check_env_specs(transformed_env) @@ -9056,7 +9058,9 @@ def test_transform_env(self, model, device): out_keys=out_keys, tensor_pixels_keys=tensor_pixels_key, ) - base_env = ParallelEnv(4, lambda: DiscreteActionConvMockEnvNumpy().to(device)) + base_env = ParallelEnv( + 4, partial(DiscreteActionConvMockEnvNumpy, device=device) + ) transformed_env = TransformedEnv(base_env, vip) td = transformed_env.reset() assert td.device == device @@ -9093,7 +9097,9 @@ def test_vip_parallel_reward(self, model, device, dtype_fixture): # noqa out_keys=out_keys, tensor_pixels_keys=tensor_pixels_key, ) - base_env = ParallelEnv(4, lambda: DiscreteActionConvMockEnvNumpy().to(device)) + base_env = ParallelEnv( + 4, partial(DiscreteActionConvMockEnvNumpy, device=device) + ) transformed_env = TransformedEnv(base_env, vip) tensordict_reset = TensorDict( {"goal_image": torch.randint(0, 255, (4, 7, 7, 3), dtype=torch.uint8)}, @@ -9296,7 +9302,7 @@ def test_trans_parallel_env_check(self, device): model_name="default", ) transformed_env = TransformedEnv( - ParallelEnv(2, lambda: DiscreteActionConvMockEnvNumpy().to(device)), vc1 + ParallelEnv(2, partial(DiscreteActionConvMockEnvNumpy, device=device)), vc1 ) try: check_env_specs(transformed_env) @@ -9452,7 +9458,9 @@ def test_transform_env(self, device, del_keys): del_keys=del_keys, model_name="default", ) - base_env = ParallelEnv(4, lambda: DiscreteActionConvMockEnvNumpy().to(device)) + base_env = ParallelEnv( + 4, partial(DiscreteActionConvMockEnvNumpy, device=device) + ) transformed_env = TransformedEnv(base_env, vc1) td = transformed_env.reset() assert td.device == device @@ -14763,7 +14771,7 @@ def test_trans_serial_env_check(self): def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): env = TransformedEnv( - maybe_fork_ParallelEnv(2, lambda: ContinuousActionVecMockEnv()), + maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), Timer(), ) try: diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index bcc85bac498..e51cd26f551 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1450,7 +1450,6 @@ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta): def _start_workers(self) -> None: import torchrl - from torchrl.envs.env_creator import EnvCreator self._timeout = 10.0 self.BATCHED_PIPE_TIMEOUT = torchrl._utils.BATCHED_PIPE_TIMEOUT