Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new Fetch-v3 and HandReacher-v2 environments (Fix reproducibility issues) #208

Merged
merged 14 commits into from
May 29, 2024
Merged
7 changes: 2 additions & 5 deletions gymnasium_robotics/envs/fetch/fetch_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,11 +373,8 @@ def _render_callback(self):
self._mujoco.mj_forward(self.model, self.data)

def _reset_sim(self):
self.data.time = self.initial_time
self.data.qpos[:] = np.copy(self.initial_qpos)
self.data.qvel[:] = np.copy(self.initial_qvel)
if self.model.na != 0:
self.data.act[:] = None
# Reset buffers for joint states, actuators, warm-start, control buffers etc.
self._mujoco.mj_resetData(self.model, self.data)

# Randomize start position of object.
if self.has_object:
Expand Down
8 changes: 2 additions & 6 deletions gymnasium_robotics/envs/robot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,8 @@ def _initialize_simulation(self):
self.initial_qvel = np.copy(self.data.qvel)

def _reset_sim(self):
self.data.time = self.initial_time
self.data.qpos[:] = np.copy(self.initial_qpos)
self.data.qvel[:] = np.copy(self.initial_qvel)
if self.model.na != 0:
self.data.act[:] = None

# Reset buffers for joint states, warm-start, control buffers etc.
mujoco.mj_resetData(self.model, self.data)
mujoco.mj_forward(self.model, self.data)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mj_forward should also be removed (but kept for fetch_env because it changes the qpos)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests still succeed even if it is removed for fetch_env, although Mujoco should not reflect the changes in qpos in the positions of the links. Do you think it's worth adding tests that catch this or is that overkill?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not 100% sure that removing mj_forward after moving the position of an object (qpos) will not result in bugs, so it is better to just keep it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left it in. Let me know if that addresses all your remarks for the PR

return super()._reset_sim()

Expand Down
99 changes: 99 additions & 0 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,105 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
env_2.close()


@pytest.mark.parametrize(
    "env_spec", non_mujoco_py_env_specs, ids=[env.id for env in non_mujoco_py_env_specs]
)
def test_same_env_determinism_rollout(env_spec: EnvSpec):
    """Run two rollouts with a single environment and assert equality.

    This test runs two rollouts of NUM_STEPS steps with one environment
    reset with the same seed and asserts that:

    - observations after the reset are the same
    - same actions are sampled by the environment
    - observations are contained in the observation space
    - obs, rew, terminated, truncated and info are equal between the two rollouts

    Note:
        We exclude mujoco_py environments because they are deprecated and their implementation is
        frozen at this point. They are affected by a subtle bug in their reset method producing
        slightly different results for the same seed on subsequent resets of the same environment.
        This will not be fixed and tests are expected to fail.
    """
    # Don't check rollout equality if it's a nondeterministic environment.
    if env_spec.nondeterministic is True:
        return

    env = env_spec.make(disable_env_checker=True)

    # Close the environment even when one of the assertions below fails.
    try:
        rollout_keys = (
            "observations",
            "actions",
            "rewards",
            "terminated",
            "truncated",
            "infos",
        )
        rollout_1 = {key: [] for key in rollout_keys}
        rollout_2 = {key: [] for key in rollout_keys}

        # Run two rollouts of the same environment instance
        for rollout in (rollout_1, rollout_2):
            # Reset the environment with the same seed for both rollouts
            obs, info = env.reset(seed=SEED)
            env.action_space.seed(SEED)
            rollout["observations"].append(obs)
            rollout["infos"].append(info)

            for _ in range(NUM_STEPS):
                action = env.action_space.sample()

                obs, rew, terminated, truncated, info = env.step(action)
                rollout["observations"].append(obs)
                rollout["actions"].append(action)
                rollout["rewards"].append(rew)
                rollout["terminated"].append(terminated)
                rollout["truncated"].append(truncated)
                rollout["infos"].append(info)
                if terminated or truncated:
                    env.reset(seed=SEED)

        for time_step, (obs_1, obs_2) in enumerate(
            zip(rollout_1["observations"], rollout_2["observations"])
        ):
            # -1 because of the initial observation stored on reset
            time_step = "initial" if time_step == 0 else time_step - 1
            assert_equals(obs_1, obs_2, f"[{time_step}] ")
            assert env.observation_space.contains(
                obs_1
            )  # obs_2 verified by previous assertion
        for time_step, (rew_1, rew_2) in enumerate(
            zip(rollout_1["rewards"], rollout_2["rewards"])
        ):
            assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
        for time_step, (terminated_1, terminated_2) in enumerate(
            zip(rollout_1["terminated"], rollout_2["terminated"])
        ):
            assert (
                terminated_1 == terminated_2
            ), f"[{time_step}] terminated 1={terminated_1}, terminated 2={terminated_2}"
        for time_step, (truncated_1, truncated_2) in enumerate(
            zip(rollout_1["truncated"], rollout_2["truncated"])
        ):
            assert (
                truncated_1 == truncated_2
            ), f"[{time_step}] truncated 1={truncated_1}, truncated 2={truncated_2}"
        for time_step, (info_1, info_2) in enumerate(
            zip(rollout_1["infos"], rollout_2["infos"])
        ):
            # -1 because of the initial info stored on reset
            time_step = "initial" if time_step == 0 else time_step - 1
            assert_equals(info_1, info_2, f"[{time_step}] ")
    finally:
        env.close()


@pytest.mark.parametrize(
"spec", non_mujoco_py_env_specs, ids=[spec.id for spec in non_mujoco_py_env_specs]
)
Expand Down
Loading