diff --git a/test/test_envs.py b/test/test_envs.py index 92b3f67d11f..fab8e48fe36 100644 --- a/test/test_envs.py +++ b/test/test_envs.py @@ -1268,6 +1268,32 @@ def test_serial_for_single(self, maybe_fork_ParallelEnv, start_method): finally: env.close(raise_if_closed=False) + def test_lambda_wrapping(self, maybe_fork_ParallelEnv): + """Test that ParallelEnv automatically wraps lambda functions with EnvCreator. + + Lambda functions cannot be pickled with standard pickle (required for spawn + start method), but EnvCreator uses cloudpickle which can handle them. + This test verifies that lambda functions work correctly with ParallelEnv. + """ + # Test single lambda function + env = maybe_fork_ParallelEnv(2, lambda: ContinuousActionVecMockEnv()) + try: + rollout = env.rollout(3) + assert rollout.shape[0] == 2 + assert rollout.shape[1] == 3 + finally: + env.close(raise_if_closed=False) + + # Test list of lambda functions (heterogeneous envs) + env1 = lambda: ContinuousActionVecMockEnv() + env2 = lambda: ContinuousActionVecMockEnv() + env = maybe_fork_ParallelEnv(2, [env1, env2]) + try: + rollout = env.rollout(3) + assert rollout.shape[0] == 2 + finally: + env.close(raise_if_closed=False) + @pytest.mark.parametrize("num_parallel_env", [1, 10]) @pytest.mark.parametrize("env_batch_size", [[], (32,), (32, 1)]) def test_env_with_batch_size(