Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ __pycache__
MUJOCO_LOG.TXT
mujoco_menagerie
checkpoints/
logs
6 changes: 6 additions & 0 deletions learning/train_rsl_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,12 @@ def render_callback(_, state):
# Build RSL-RL config
train_cfg = get_rl_config(_ENV_NAME.value)

obs_size = raw_env.observation_size
if isinstance(obs_size, dict):
train_cfg.obs_groups = {"policy": ["state"], "critic": ["privileged_state"]}
else:
train_cfg.obs_groups = {"policy": ["state"], "critic": ["state"]}

# Overwrite default config with flags
train_cfg.seed = _SEED.value
train_cfg.run_name = exp_name
Expand Down
23 changes: 10 additions & 13 deletions mujoco_playground/_src/wrapper_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from collections import deque
import functools
import os
from typing import Any

import jax
import numpy as np
Expand All @@ -31,6 +32,7 @@
torch = None

from mujoco_playground._src import wrapper
from tensordict import TensorDict


def _jax_to_torch(tensor):
Expand Down Expand Up @@ -158,8 +160,10 @@ def step(self, action):
if self.asymmetric_obs:
obs = _jax_to_torch(self.env_state.obs["state"])
critic_obs = _jax_to_torch(self.env_state.obs["privileged_state"])
obs = {"state": obs, "privileged_state": critic_obs}
else:
obs = _jax_to_torch(self.env_state.obs)
obs = {"state": obs}
reward = _jax_to_torch(self.env_state.reward)
done = _jax_to_torch(self.env_state.done)
info = self.env_state.info
Expand Down Expand Up @@ -187,6 +191,7 @@ def step(self, action):
if k not in info_ret["log"]:
info_ret["log"][k] = _jax_to_torch(v).float().mean().item()

obs = TensorDict(obs, batch_size=[self.num_envs])
return obs, reward, done, info_ret

def reset(self):
Expand All @@ -195,23 +200,15 @@ def reset(self):

if self.asymmetric_obs:
obs = _jax_to_torch(self.env_state.obs["state"])
# critic_obs = jax_to_torch(self.env_state.obs["privileged_state"])
critic_obs = _jax_to_torch(self.env_state.obs["privileged_state"])
obs = {"state": obs, "privileged_state": critic_obs}
else:
obs = _jax_to_torch(self.env_state.obs)
return obs

def reset_with_critic_obs(self):
self.env_state = self.reset_fn(self.key_reset)
obs = _jax_to_torch(self.env_state.obs["state"])
critic_obs = _jax_to_torch(self.env_state.obs["privileged_state"])
return obs, critic_obs
obs = {"state": obs}
return TensorDict(obs, batch_size=[self.num_envs])

def get_observations(self):
if self.asymmetric_obs:
obs, critic_obs = self.reset_with_critic_obs()
return obs, {"observations": {"critic": critic_obs}}
else:
return self.reset(), {"observations": {}}
return self.reset()

def render(self, mode="human"): # pylint: disable=unused-argument
if self.render_callback is not None:
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ dependencies = [
"orbax-checkpoint>=0.11.22",
"tqdm",
"warp-lang>=1.9.0.dev",
"wandb",
]
keywords = ["mjx", "mujoco", "sim2real", "reinforcement learning"]

Expand Down Expand Up @@ -75,9 +74,14 @@ dev = [
"pylint",
"pytest-xdist",
]
learning = [
"rsl-rl-lib>=3.0.0",
"wandb",
]
all = [
"playground[dev]",
"playground[notebooks]",
"playground[learning]",
]

[tool.hatch.metadata]
Expand Down
Loading