Skip to content

Commit 7a9f211

Browse files
committed
Merge branch 'add-quadruped' of github.com:yardenas/mujoco_playground into add-quadruped
2 parents c9d44b2 + 4e31e49 commit 7a9f211

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+654
-263
lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@
22

33
All notable changes to this project will be documented in this file.
44

5+
## Next release
6+
7+
- Pass through the [MuJoCo Warp](https://github.com/google-deepmind/mujoco_warp)
8+
(MjWarp) implementation to MJX, so that MuJoCo Playground environments can
9+
train with MuJoCo Warp! DM Control Suite and Locomotion environments now
10+
support MjWarp. You can select the implementation via the config
11+
override
12+
`registry.load('CartpoleBalance', config_overrides={'impl': 'warp'})`.
13+
514
## [0.0.5] - 2025-06-23
615

716
- Change `light_directional` to `light_type` following MuJoCo API change from version 3.3.2 to 3.3.3. Fixes https://github.com/google-deepmind/mujoco_playground/issues/142.

mujoco_playground/_src/collision.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,29 @@
1717
from typing import Any, Tuple
1818

1919
import jax
20-
import jax.numpy as jnp
20+
import jax.numpy as jp
2121
from mujoco import mjx
22+
from mujoco.mjx._src import types
2223

2324

2425
def get_collision_info(
2526
contact: Any, geom1: int, geom2: int
2627
) -> Tuple[jax.Array, jax.Array]:
2728
"""Get the distance and normal of the collision between two geoms."""
28-
mask = (jnp.array([geom1, geom2]) == contact.geom).all(axis=1)
29-
mask |= (jnp.array([geom2, geom1]) == contact.geom).all(axis=1)
30-
idx = jnp.where(mask, contact.dist, 1e4).argmin()
29+
mask = (jp.array([geom1, geom2]) == contact.geom).all(axis=1)
30+
mask |= (jp.array([geom2, geom1]) == contact.geom).all(axis=1)
31+
idx = jp.where(mask, contact.dist, 1e4).argmin()
3132
dist = contact.dist[idx] * mask[idx]
3233
normal = (dist < 0) * contact.frame[idx, 0, :3]
3334
return dist, normal
3435

3536

3637
def geoms_colliding(state: mjx.Data, geom1: int, geom2: int) -> jax.Array:
3738
"""Return True if the two geoms are colliding."""
39+
# if not isinstance(state._impl, types.DataJAX):
40+
# raise NotImplementedError(
41+
# "`geoms_colliding` only implemented for JAX MJX backend."
42+
# )
43+
if not isinstance(state._impl, types.DataJAX):
44+
return jp.array(False)
3845
return get_collision_info(state._impl.contact, geom1, geom2)[0] < 0 # pylint: disable=protected-access

mujoco_playground/_src/dm_control_suite/acrobot.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ def default_config() -> config_dict.ConfigDict:
3636
episode_length=1000,
3737
action_repeat=1,
3838
vision=False,
39+
impl="jax",
40+
nconmax=0,
41+
njmax=0,
3942
)
4043

4144

@@ -62,7 +65,7 @@ def __init__(
6265
_XML_PATH.read_text(), self._model_assets
6366
)
6467
self._mj_model.opt.timestep = self.sim_dt
65-
self._mjx_model = mjx.put_model(self._mj_model)
68+
self._mjx_model = mjx.put_model(self._mj_model, impl=self._config.impl)
6669
self._post_init()
6770

6871
def _post_init(self) -> None:
@@ -78,7 +81,14 @@ def reset(self, rng: jax.Array) -> mjx_env.State:
7881
qpos = jax.random.uniform(
7982
rng1, (self.mjx_model.nq,), minval=-jp.pi, maxval=jp.pi
8083
)
81-
data = mjx_env.init(self.mjx_model, qpos=qpos)
84+
data = mjx_env.make_data(
85+
self.mj_model,
86+
qpos=qpos,
87+
impl=self.mjx_model.impl.value,
88+
nconmax=self._config.nconmax,
89+
njmax=self._config.njmax,
90+
)
91+
data = mjx.forward(self.mjx_model, data)
8292

8393
metrics = {
8494
"distance": jp.zeros(()),

mujoco_playground/_src/dm_control_suite/ball_in_cup.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ def default_config() -> config_dict.ConfigDict:
3535
episode_length=1000,
3636
action_repeat=1,
3737
vision=False,
38+
impl="jax",
39+
nconmax=10_000,
40+
njmax=25,
3841
)
3942

4043

@@ -58,7 +61,7 @@ def __init__(
5861
_XML_PATH.read_text(), self._model_assets
5962
)
6063
self._mj_model.opt.timestep = self.sim_dt
61-
self._mjx_model = mjx.put_model(self._mj_model)
64+
self._mjx_model = mjx.put_model(self._mj_model, impl=self._config.impl)
6265
self._post_init()
6366

6467
def _post_init(self) -> None:
@@ -69,7 +72,13 @@ def _post_init(self) -> None:
6972
self._ball_size = self._mj_model.geom_size[geom_id, 0]
7073

7174
def reset(self, rng: jax.Array) -> mjx_env.State:
72-
data = mjx_env.init(self.mjx_model)
75+
data = mjx_env.make_data(
76+
self.mj_model,
77+
impl=self.mjx_model.impl.value,
78+
nconmax=self._config.nconmax,
79+
njmax=self._config.njmax,
80+
)
81+
data = mjx.forward(self.mjx_model, data)
7382

7483
metrics = {}
7584
info = {"rng": rng}

mujoco_playground/_src/dm_control_suite/cartpole.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ def default_config() -> config_dict.ConfigDict:
5151
action_repeat=1,
5252
vision=False,
5353
vision_config=default_vision_config(),
54+
impl="jax",
55+
nconmax=0,
56+
njmax=2,
5457
)
5558

5659

@@ -95,7 +98,7 @@ def __init__(
9598
_XML_PATH.read_text(), self._model_assets
9699
)
97100
self._mj_model.opt.timestep = self.sim_dt
98-
self._mjx_model = mjx.put_model(self._mj_model)
101+
self._mjx_model = mjx.put_model(self._mj_model, impl=self._config.impl)
99102
self._post_init()
100103

101104
if self._vision:
@@ -129,7 +132,7 @@ def _post_init(self) -> None:
129132
self._hinge_1_qposadr = self._mj_model.jnt_qposadr[hinge_1_jid]
130133

131134
def _reset_swing_up(self, rng: jax.Array) -> jax.Array:
132-
rng, rng1, rng2, rng3 = jax.random.split(rng, 4)
135+
_, rng1, rng2, rng3 = jax.random.split(rng, 4)
133136

134137
qpos = jp.zeros(self.mjx_model.nq)
135138
qpos = qpos.at[self._slider_qposadr].set(0.01 * jax.random.normal(rng1))
@@ -163,7 +166,15 @@ def reset(self, rng: jax.Array) -> mjx_env.State:
163166
rng, rng1 = jax.random.split(rng, 2)
164167
qvel = 0.01 * jax.random.normal(rng1, (self.mjx_model.nv,))
165168

166-
data = mjx_env.init(self.mjx_model, qpos=qpos, qvel=qvel)
169+
data = mjx_env.make_data(
170+
self.mj_model,
171+
qpos=qpos,
172+
qvel=qvel,
173+
impl=self.mjx_model.impl.value,
174+
nconmax=self._config.nconmax,
175+
njmax=self._config.njmax,
176+
)
177+
data = mjx.forward(self.mjx_model, data)
167178

168179
metrics = {
169180
"reward/upright": jp.zeros(()),

mujoco_playground/_src/dm_control_suite/cheetah.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def default_config() -> config_dict.ConfigDict:
3838
episode_length=1000,
3939
action_repeat=1,
4040
vision=False,
41+
impl="jax",
42+
nconmax=100_000,
43+
njmax=100,
4144
)
4245

4346

@@ -61,7 +64,7 @@ def __init__(
6164
_XML_PATH.read_text(), self._model_assets
6265
)
6366
self._mj_model.opt.timestep = self.sim_dt
64-
self._mjx_model = mjx.put_model(self._mj_model)
67+
self._mjx_model = mjx.put_model(self._mj_model, impl=self._config.impl)
6568
self._post_init()
6669

6770
def _post_init(self) -> None:
@@ -81,7 +84,14 @@ def reset(self, rng: jax.Array) -> mjx_env.State:
8184
)
8285
)
8386

84-
data = mjx_env.init(self.mjx_model, qpos=qpos)
87+
data = mjx_env.make_data(
88+
self.mj_model,
89+
qpos=qpos,
90+
impl=self.mjx_model.impl.value,
91+
nconmax=self._config.nconmax,
92+
njmax=self._config.njmax,
93+
)
94+
data = mjx.forward(self.mjx_model, data)
8595

8696
# Stabilize.
8797
data = mjx_env.step(self.mjx_model, data, jp.zeros(self.mjx_model.nu), 200)

mujoco_playground/_src/dm_control_suite/finger.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def default_config() -> config_dict.ConfigDict:
4848
episode_length=1000,
4949
action_repeat=1,
5050
vision=False,
51+
impl="jax",
52+
nconmax=25_000,
53+
njmax=25,
5154
)
5255

5356

@@ -93,7 +96,7 @@ def __init__(
9396
self._model_assets = common.get_assets()
9497
self._mj_model = _make_spin_model(_XML_PATH, self._model_assets)
9598
self._mj_model.opt.timestep = self.sim_dt
96-
self._mjx_model = mjx.put_model(self._mj_model)
99+
self._mjx_model = mjx.put_model(self._mj_model, impl=self._config.impl)
97100
self._post_init()
98101

99102
def _post_init(self) -> None:
@@ -108,7 +111,13 @@ def reset(self, rng: jax.Array) -> mjx_env.State:
108111
)
109112
qpos = qpos.at[2].set(jax.random.uniform(rng1, minval=-jp.pi, maxval=jp.pi))
110113

111-
data = mjx_env.init(self.mjx_model, qpos)
114+
data = mjx_env.make_data(
115+
self.mj_model,
116+
qpos=qpos,
117+
impl=self.mjx_model.impl.value,
118+
nconmax=self._config.nconmax,
119+
njmax=self._config.njmax,
120+
)
112121

113122
metrics = {}
114123
info = {"rng": rng}

mujoco_playground/_src/dm_control_suite/fish.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ def default_config() -> config_dict.ConfigDict:
4646
episode_length=1000,
4747
action_repeat=1,
4848
vision=False,
49+
impl="jax",
50+
nconmax=0,
51+
njmax=25,
4952
)
5053

5154

@@ -69,7 +72,7 @@ def __init__(
6972
_XML_PATH.read_text(), self._model_assets
7073
)
7174
self._mj_model.opt.timestep = self.sim_dt
72-
self._mjx_model = mjx.put_model(self._mj_model)
75+
self._mjx_model = mjx.put_model(self._mj_model, impl=self._config.impl)
7376
self._post_init()
7477

7578
def _post_init(self) -> None:
@@ -104,7 +107,14 @@ def reset(self, rng: jax.Array) -> mjx_env.State:
104107
)
105108
)
106109

107-
data = mjx_env.init(self.mjx_model, qpos=qpos)
110+
data = mjx_env.make_data(
111+
self.mj_model,
112+
qpos=qpos,
113+
impl=self.mjx_model.impl.value,
114+
nconmax=self._config.nconmax,
115+
njmax=self._config.njmax,
116+
)
117+
data = mjx.forward(self.mjx_model, data)
108118

109119
# Randomize target position.
110120
xyz = jax.random.uniform(

mujoco_playground/_src/dm_control_suite/hopper.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ def default_config() -> config_dict.ConfigDict:
4141
episode_length=1000,
4242
action_repeat=1,
4343
vision=False,
44+
impl="jax",
45+
nconmax=50_000,
46+
njmax=50,
4447
)
4548

4649

@@ -78,7 +81,7 @@ def __init__(
7881
_XML_PATH.read_text(), self._model_assets
7982
)
8083
self._mj_model.opt.timestep = self.sim_dt
81-
self._mjx_model = mjx.put_model(self._mj_model)
84+
self._mjx_model = mjx.put_model(self._mj_model, impl=self._config.impl)
8285
self._post_init()
8386

8487
def _post_init(self) -> None:
@@ -103,7 +106,14 @@ def reset(self, rng: jax.Array) -> mjx_env.State:
103106
)
104107
)
105108

106-
data = mjx_env.init(self.mjx_model, qpos=qpos)
109+
data = mjx_env.make_data(
110+
self.mj_model,
111+
qpos=qpos,
112+
impl=self.mjx_model.impl.value,
113+
nconmax=self._config.nconmax,
114+
njmax=self._config.njmax,
115+
)
116+
data = mjx.forward(self.mjx_model, data)
107117

108118
metrics = {k: jp.zeros(()) for k in self._metric_keys}
109119
info = {"rng": rng}

mujoco_playground/_src/dm_control_suite/humanoid.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ def default_config() -> config_dict.ConfigDict:
4242
episode_length=1000,
4343
action_repeat=1,
4444
vision=False,
45+
impl="jax",
46+
nconmax=200_000,
47+
njmax=250,
4548
)
4649

4750

@@ -72,7 +75,7 @@ def __init__(
7275
_XML_PATH.read_text(), self._model_assets
7376
)
7477
self._mj_model.opt.timestep = self.sim_dt
75-
self._mjx_model = mjx.put_model(self._mj_model)
78+
self._mjx_model = mjx.put_model(self._mj_model, impl=self._config.impl)
7679
self._post_init()
7780

7881
def _post_init(self) -> None:
@@ -88,7 +91,13 @@ def _post_init(self) -> None:
8891
def reset(self, rng: jax.Array) -> mjx_env.State:
8992
# TODO(kevin): Add non-penetrating joint randomization.
9093

91-
data = mjx_env.init(self.mjx_model)
94+
data = mjx_env.make_data(
95+
self.mj_model,
96+
impl=self.mjx_model.impl.value,
97+
nconmax=self._config.nconmax,
98+
njmax=self._config.njmax,
99+
)
100+
data = mjx.forward(self.mjx_model, data)
92101

93102
metrics = {
94103
"reward/standing": jp.zeros(()),

0 commit comments

Comments
 (0)