From 1daeb5735077e157d6cc57722d2cca9ac8a4c427 Mon Sep 17 00:00:00 2001 From: Baruch Tabanpour Date: Mon, 15 Sep 2025 10:22:08 -0700 Subject: [PATCH 1/4] fix for rsl-rl training with new tensordict in rsl-rl-lib>=3.0.0 --- .gitignore | 1 + learning/train_rsl_rl.py | 6 ++++++ mujoco_playground/_src/wrapper_torch.py | 23 ++++++++++------------- pyproject.toml | 1 + 4 files changed, 18 insertions(+), 13 deletions(-) diff --git a/.gitignore b/.gitignore index e2f0df855..649f62137 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ __pycache__ MUJOCO_LOG.TXT mujoco_menagerie checkpoints/ +logs diff --git a/learning/train_rsl_rl.py b/learning/train_rsl_rl.py index 96b7dcccf..71228eef8 100644 --- a/learning/train_rsl_rl.py +++ b/learning/train_rsl_rl.py @@ -167,6 +167,12 @@ def render_callback(_, state): # Build RSL-RL config train_cfg = get_rl_config(_ENV_NAME.value) + obs_size = raw_env.observation_size + if isinstance(obs_size, dict): + train_cfg.obs_groups = {"policy": ["state"], "critic": ["privileged_state"]} + else: + train_cfg.obs_groups = {"policy": ["state"], "critic": ["state"]} + # Overwrite default config with flags train_cfg.seed = _SEED.value train_cfg.run_name = exp_name diff --git a/mujoco_playground/_src/wrapper_torch.py b/mujoco_playground/_src/wrapper_torch.py index 3d70bff99..88ac4342e 100644 --- a/mujoco_playground/_src/wrapper_torch.py +++ b/mujoco_playground/_src/wrapper_torch.py @@ -17,6 +17,7 @@ from collections import deque import functools import os +from typing import Any, cast import jax import numpy as np @@ -31,6 +32,7 @@ torch = None from mujoco_playground._src import wrapper +from tensordict import TensorDict def _jax_to_torch(tensor): @@ -158,8 +160,10 @@ def step(self, action): if self.asymmetric_obs: obs = _jax_to_torch(self.env_state.obs["state"]) critic_obs = _jax_to_torch(self.env_state.obs["privileged_state"]) + obs = {"state": obs, "privileged_state": critic_obs} else: obs = _jax_to_torch(self.env_state.obs) + obs = {"state": obs} reward = _jax_to_torch(self.env_state.reward) done = _jax_to_torch(self.env_state.done) info = self.env_state.info @@ -187,6 +191,7 @@ def step(self, action): if k not in info_ret["log"]: info_ret["log"][k] = _jax_to_torch(v).float().mean().item() + obs = TensorDict(cast(dict[str, Any], obs), batch_size=[self.num_envs]) return obs, reward, done, info_ret def reset(self): @@ -195,23 +200,15 @@ def reset(self): if self.asymmetric_obs: obs = _jax_to_torch(self.env_state.obs["state"]) - # critic_obs = jax_to_torch(self.env_state.obs["privileged_state"]) + critic_obs = _jax_to_torch(self.env_state.obs["privileged_state"]) + obs = {"state": obs, "privileged_state": critic_obs} else: obs = _jax_to_torch(self.env_state.obs) - return obs - - def reset_with_critic_obs(self): - self.env_state = self.reset_fn(self.key_reset) - obs = _jax_to_torch(self.env_state.obs["state"]) - critic_obs = _jax_to_torch(self.env_state.obs["privileged_state"]) - return obs, critic_obs + obs = {"state": obs} + return TensorDict(cast(dict[str, Any], obs), batch_size=[self.num_envs]) def get_observations(self): - if self.asymmetric_obs: - obs, critic_obs = self.reset_with_critic_obs() - return obs, {"observations": {"critic": critic_obs}} - else: - return self.reset(), {"observations": {}} + return self.reset() def render(self, mode="human"): # pylint: disable=unused-argument if self.render_callback is not None: diff --git a/pyproject.toml b/pyproject.toml index 1365ddf8a..03412e37d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "tqdm", "warp-lang>=1.9.0.dev", "wandb", + "rsl-rl-lib>=3.0.0", ] keywords = ["mjx", "mujoco", "sim2real", "reinforcement learning"] From 79840577f3479aec5825bc6683d4a5a41e4cb99a Mon Sep 17 00:00:00 2001 From: Baruch Tabanpour Date: Mon, 15 Sep 2025 10:32:41 -0700 Subject: [PATCH 2/4] small fix --- mujoco_playground/_src/wrapper_torch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mujoco_playground/_src/wrapper_torch.py b/mujoco_playground/_src/wrapper_torch.py index 88ac4342e..8aadc43c5 100644 --- a/mujoco_playground/_src/wrapper_torch.py +++ b/mujoco_playground/_src/wrapper_torch.py @@ -17,7 +17,7 @@ from collections import deque import functools import os -from typing import Any, cast +from typing import Any import jax import numpy as np @@ -191,7 +191,7 @@ def step(self, action): if k not in info_ret["log"]: info_ret["log"][k] = _jax_to_torch(v).float().mean().item() - obs = TensorDict(cast(dict[str, Any], obs), batch_size=[self.num_envs]) + obs = TensorDict(obs, batch_size=[self.num_envs]) return obs, reward, done, info_ret def reset(self): @@ -205,7 +205,7 @@ def reset(self): else: obs = _jax_to_torch(self.env_state.obs) obs = {"state": obs} - return TensorDict(cast(dict[str, Any], obs), batch_size=[self.num_envs]) + return TensorDict(obs, batch_size=[self.num_envs]) def get_observations(self): return self.reset() From d87cb048b5b73e43204754335e8e1b9f3aff1f4a Mon Sep 17 00:00:00 2001 From: Baruch Tabanpour Date: Mon, 15 Sep 2025 10:39:33 -0700 Subject: [PATCH 3/4] deps --- pyproject.toml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 03412e37d..a693e088e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,8 +35,6 @@ dependencies = [ "orbax-checkpoint>=0.11.22", "tqdm", "warp-lang>=1.9.0.dev", - "wandb", - "rsl-rl-lib>=3.0.0", ] keywords = ["mjx", "mujoco", "sim2real", "reinforcement learning"] @@ -76,9 +74,14 @@ dev = [ "pylint", "pytest-xdist", ] +learning = [ + "rsl-rl-lib>=3.0.0", + "wandb", +] all = [ "playground[dev]", "playground[notebooks]", + "playground[learning]", ] [tool.hatch.metadata] From 9b6d015b85af168c3dd7b8d922b02c4b1ed15809 Mon Sep 17 00:00:00 2001 From: Baruch Tabanpour Date: Mon, 15 Sep 2025 10:42:50 -0700 Subject: [PATCH 4/4] import --- mujoco_playground/_src/wrapper_torch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mujoco_playground/_src/wrapper_torch.py b/mujoco_playground/_src/wrapper_torch.py index 8aadc43c5..566a7d605 100644 --- a/mujoco_playground/_src/wrapper_torch.py +++ b/mujoco_playground/_src/wrapper_torch.py @@ -32,7 +32,10 @@ torch = None from mujoco_playground._src import wrapper -from tensordict import TensorDict +try: + from tensordict import TensorDict +except ImportError: + TensorDict = None def _jax_to_torch(tensor):