
Commit 4e31e49

Merge branch 'google-deepmind:main' into add-quadruped

2 parents 9abc68d + d886c80


62 files changed: +974 additions, -369 deletions

.github/workflows/pypi.yml

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+name: Upload Python Package
+
+on:
+  release:
+    types: [created]
+
+jobs:
+  deploy:
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python
+      uses: actions/setup-python@v4
+      with:
+        python-version: "3.10"
+    - name: Install dependencies
+      run: |
+        pip install uv
+        uv pip install --system -e ".[dev]"
+        uv pip install --system build twine
+    - name: Build and publish
+      env:
+        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
+        TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
+      run: |
+        python -m build
+        twine upload --username $TWINE_USERNAME --password $TWINE_PASSWORD dist/*

CHANGELOG.md

Lines changed: 18 additions & 0 deletions
@@ -2,6 +2,24 @@
 
 All notable changes to this project will be documented in this file.
 
+## Next release
+
+- Pass through the [MuJoCo Warp](https://github.com/google-deepmind/mujoco_warp)
+  (MjWarp) implementation to MJX, so that MuJoCo Playground environments can
+  train with MuJoCo Warp! DM Control Suite and Locomotion environments now
+  support MjWarp. You can pass through the implementation via the config
+  override
+  `registry.load('CartpoleBalance', config_overrides={'impl': 'warp'})`.
+
+## [0.0.5] - 2025-06-23
+
+- Change `light_directional` to `light_type` following MuJoCo API change from version 3.3.2 to 3.3.3. Fixes https://github.com/google-deepmind/mujoco_playground/issues/142.
+- Fix bug in `get_qpos_ids`.
+- Implement `render` in Wrapper.
+- Fix https://github.com/google-deepmind/mujoco_playground/issues/123.
+- Fix https://github.com/google-deepmind/mujoco_playground/issues/126.
+- Fix https://github.com/google-deepmind/mujoco_playground/issues/41.
 
 ## [0.0.4] - 2025-02-07
 
 ### Added
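
As a quick illustration of the override described in this entry, here is a minimal sketch assuming the standard `registry.load` / `reset` / `step` environment API; the rollout length and the zero action are purely illustrative:

```python
import jax
import jax.numpy as jp

from mujoco_playground import registry

# Load a DM Control Suite environment on the MuJoCo Warp backend.
env = registry.load('CartpoleBalance', config_overrides={'impl': 'warp'})

rng = jax.random.PRNGKey(0)
state = env.reset(rng)
for _ in range(10):
  # Step with a zero action just to exercise the Warp-backed pipeline.
  state = env.step(state, jp.zeros(env.action_size))
print(state.reward)
```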

README.md

Lines changed: 18 additions & 0 deletions
@@ -63,6 +63,24 @@ For vision-based environments, please refer to the installation instructions in
 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/training_vision_1.ipynb) | Training CartPole from Vision |
 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/training_vision_2.ipynb) | Robotic Manipulation from Vision |
 
+## Running from CLI
+> [!IMPORTANT]
+> Assumes installation from source.
+
+For basic usage, navigate to the repo's directory and run:
+```bash
+python learning/train_jax_ppo.py --env_name CartpoleBalance
+```
+
+### Training Visualization
+
+To interactively view trajectories throughout training with [rscope](https://github.com/Andrew-Luo1/rscope/tree/main), install it (`pip install rscope`) and run:
+
+```
+python learning/train_jax_ppo.py --env_name PandaPickCube --rscope_envs 16 --run_evals=False --deterministic_rscope=True
+# In a separate terminal
+python -m rscope
+```
 
 ## FAQ
 
learning/train_jax_ppo.py

Lines changed: 91 additions & 11 deletions
@@ -132,6 +132,33 @@
     "policy_obs_key", "state", "Policy obs key"
 )
 _VALUE_OBS_KEY = flags.DEFINE_string("value_obs_key", "state", "Value obs key")
+_RSCOPE_ENVS = flags.DEFINE_integer(
+    "rscope_envs",
+    None,
+    "Number of parallel environment rollouts to save for the rscope viewer",
+)
+_DETERMINISTIC_RSCOPE = flags.DEFINE_boolean(
+    "deterministic_rscope",
+    True,
+    "Run deterministic rollouts for the rscope viewer",
+)
+_RUN_EVALS = flags.DEFINE_boolean(
+    "run_evals",
+    True,
+    "Run evaluation rollouts between policy updates.",
+)
+_LOG_TRAINING_METRICS = flags.DEFINE_boolean(
+    "log_training_metrics",
+    False,
+    "Whether to log training metrics and callback to progress_fn. Significantly"
+    " slows down training if too frequent.",
+)
+_TRAINING_METRICS_STEPS = flags.DEFINE_integer(
+    "training_metrics_steps",
+    1_000_000,
+    "Number of steps between logging training metrics. Increase if training"
+    " experiences slowdown.",
+)
 
 
 def get_rl_config(env_name: str) -> config_dict.ConfigDict:
@@ -151,6 +178,24 @@ def get_rl_config(env_name: str) -> config_dict.ConfigDict:
   raise ValueError(f"Env {env_name} not found in {registry.ALL_ENVS}.")
 
 
+def rscope_fn(full_states, obs, rew, done):
+  """
+  All arrays are of shape (unroll_length, rscope_envs, ...)
+  full_states: dict with keys 'qpos', 'qvel', 'time', 'metrics'
+  obs: nd.array or dict obs based on env configuration
+  rew: nd.array rewards
+  done: nd.array done flags
+  """
+  # Calculate cumulative rewards per episode, stopping at first done flag
+  done_mask = jp.cumsum(done, axis=0)
+  valid_rewards = rew * (done_mask == 0)
+  episode_rewards = jp.sum(valid_rewards, axis=0)
+  print(
+      "Collected rscope rollouts with reward"
+      f" {episode_rewards.mean():.3f} +- {episode_rewards.std():.3f}"
+  )
+
+
 def main(argv):
   """Run training and evaluation for the specified environment."""
 
@@ -209,11 +254,16 @@ def main(argv):
     ppo_params.network_factory.policy_obs_key = _POLICY_OBS_KEY.value
   if _VALUE_OBS_KEY.present:
     ppo_params.network_factory.value_obs_key = _VALUE_OBS_KEY.value
-
   if _VISION.value:
     env_cfg.vision = True
     env_cfg.vision_config.render_batch_size = ppo_params.num_envs
   env = registry.load(_ENV_NAME.value, config=env_cfg)
+  if _RUN_EVALS.present:
+    ppo_params.run_evals = _RUN_EVALS.value
+  if _LOG_TRAINING_METRICS.present:
+    ppo_params.log_training_metrics = _LOG_TRAINING_METRICS.value
+  if _TRAINING_METRICS_STEPS.present:
+    ppo_params.training_metrics_steps = _TRAINING_METRICS_STEPS.value
 
   print(f"Environment Config:\n{env_cfg}")
   print(f"PPO Training Parameters:\n{ppo_params}")
@@ -268,13 +318,6 @@
   with open(ckpt_path / "config.json", "w", encoding="utf-8") as fp:
     json.dump(env_cfg.to_dict(), fp, indent=4)
 
-  # Define policy parameters function for saving checkpoints
-  def policy_params_fn(current_step, make_policy, params):  # pylint: disable=unused-argument
-    orbax_checkpointer = ocp.PyTreeCheckpointer()
-    save_args = orbax_utils.save_args_from_target(params)
-    path = ckpt_path / f"{current_step}"
-    orbax_checkpointer.save(path, params, force=True, save_args=save_args)
-
   training_params = dict(ppo_params)
   if "network_factory" in training_params:
     del training_params["network_factory"]
@@ -319,9 +362,9 @@ def policy_params_fn(current_step, make_policy, params): # pylint: disable=unus
       ppo.train,
       **training_params,
       network_factory=network_factory,
-      policy_params_fn=policy_params_fn,
       seed=_SEED.value,
       restore_checkpoint_path=restore_checkpoint_path,
+      save_checkpoint_path=ckpt_path,
       wrap_env_fn=None if _VISION.value else wrapper.wrap_for_brax_training,
       num_eval_envs=num_eval_envs,
   )
@@ -341,18 +384,55 @@ def progress(num_steps, metrics):
     for key, value in metrics.items():
       writer.add_scalar(key, value, num_steps)
     writer.flush()
-
-    print(f"{num_steps}: reward={metrics['eval/episode_reward']:.3f}")
+    if _RUN_EVALS.value:
+      print(f"{num_steps}: reward={metrics['eval/episode_reward']:.3f}")
+    if _LOG_TRAINING_METRICS.value:
+      if "episode/sum_reward" in metrics:
+        print(
+            f"{num_steps}: mean episode"
+            f" reward={metrics['episode/sum_reward']:.3f}"
+        )
 
   # Load evaluation environment
   eval_env = (
       None if _VISION.value else registry.load(_ENV_NAME.value, config=env_cfg)
   )
 
+  policy_params_fn = lambda *args: None
+  if _RSCOPE_ENVS.value:
+    # Interactive visualisation of policy checkpoints
+    from rscope import brax as rscope_utils
+
+    if not _VISION.value:
+      rscope_env = registry.load(_ENV_NAME.value, config=env_cfg)
+      rscope_env = wrapper.wrap_for_brax_training(
+          rscope_env,
+          episode_length=ppo_params.episode_length,
+          action_repeat=ppo_params.action_repeat,
+          randomization_fn=training_params.get("randomization_fn"),
+      )
+    else:
+      rscope_env = env
+
+    rscope_handle = rscope_utils.BraxRolloutSaver(
+        rscope_env,
+        ppo_params,
+        _VISION.value,
+        _RSCOPE_ENVS.value,
+        _DETERMINISTIC_RSCOPE.value,
+        jax.random.PRNGKey(_SEED.value),
+        rscope_fn,
+    )
+
+    def policy_params_fn(current_step, make_policy, params):  # pylint: disable=unused-argument
+      rscope_handle.set_make_policy(make_policy)
+      rscope_handle.dump_rollout(params)
+
   # Train or load the model
   make_inference_fn, params, _ = train_fn(  # pylint: disable=no-value-for-parameter
       environment=env,
       progress_fn=progress,
+      policy_params_fn=policy_params_fn,
       eval_env=None if _VISION.value else eval_env,
   )
 
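To make the reward masking in `rscope_fn` above concrete, here is a small standalone sketch with made-up shapes and values; it shows that the `cumsum` over done flags zeroes out every reward from the first termination step onward:

```python
import jax.numpy as jp

# Fake rollout: 5 unroll steps x 2 rscope environments (values are illustrative).
rew = jp.ones((5, 2))
done = jp.array([
    [0.0, 0.0],
    [0.0, 1.0],  # the second env terminates at step 1
    [0.0, 0.0],
    [0.0, 0.0],
    [0.0, 0.0],
])

done_mask = jp.cumsum(done, axis=0)     # nonzero once an env has terminated
valid_rewards = rew * (done_mask == 0)  # rewards from the done step onward are dropped
episode_rewards = jp.sum(valid_rewards, axis=0)
print(episode_rewards)  # [5. 1.]
```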

mujoco_playground/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
 from mujoco_playground._src.mjx_env import render_array
 from mujoco_playground._src.mjx_env import State
 from mujoco_playground._src.mjx_env import step
+
 # pylint: enable=g-importing-member
 
 __all__ = [

mujoco_playground/_src/collision.py

Lines changed: 12 additions & 5 deletions
@@ -17,22 +17,29 @@
 from typing import Any, Tuple
 
 import jax
-import jax.numpy as jnp
+import jax.numpy as jp
 from mujoco import mjx
+from mujoco.mjx._src import types
 
 
 def get_collision_info(
     contact: Any, geom1: int, geom2: int
 ) -> Tuple[jax.Array, jax.Array]:
   """Get the distance and normal of the collision between two geoms."""
-  mask = (jnp.array([geom1, geom2]) == contact.geom).all(axis=1)
-  mask |= (jnp.array([geom2, geom1]) == contact.geom).all(axis=1)
-  idx = jnp.where(mask, contact.dist, 1e4).argmin()
+  mask = (jp.array([geom1, geom2]) == contact.geom).all(axis=1)
+  mask |= (jp.array([geom2, geom1]) == contact.geom).all(axis=1)
+  idx = jp.where(mask, contact.dist, 1e4).argmin()
   dist = contact.dist[idx] * mask[idx]
   normal = (dist < 0) * contact.frame[idx, 0, :3]
   return dist, normal
 
 
 def geoms_colliding(state: mjx.Data, geom1: int, geom2: int) -> jax.Array:
   """Return True if the two geoms are colliding."""
-  return get_collision_info(state.contact, geom1, geom2)[0] < 0
+  # if not isinstance(state._impl, types.DataJAX):
+  #   raise NotImplementedError(
+  #       "`geoms_colliding` only implemented for JAX MJX backend."
+  #   )
+  if not isinstance(state._impl, types.DataJAX):
+    return jp.array(False)
+  return get_collision_info(state._impl.contact, geom1, geom2)[0] < 0  # pylint: disable=protected-access
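
A hedged usage sketch of the updated helper, assuming a recent MJX version and an MJCF file with two named geoms; "scene.xml", "ball", and "floor" below are placeholders, not files or names from this repo:

```python
import mujoco
from mujoco import mjx

from mujoco_playground._src import collision

mj_model = mujoco.MjModel.from_xml_path("scene.xml")  # placeholder model
mjx_model = mjx.put_model(mj_model)                   # default JAX backend
data = mjx.forward(mjx_model, mjx.make_data(mjx_model))

geom_a = mj_model.geom("ball").id   # placeholder geom names
geom_b = mj_model.geom("floor").id
# Returns a boolean jax.Array; on a non-JAX backend the fallback above
# reports False instead of raising.
print(collision.geoms_colliding(data, geom_a, geom_b))
```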

mujoco_playground/_src/dm_control_suite/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -155,6 +155,8 @@ def load(
     An instance of the environment.
   """
   if env_name not in _envs:
-    raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
+    raise ValueError(
+        f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}"
+    )
   config = config or get_default_config(env_name)
   return _envs[env_name](config=config, config_overrides=config_overrides)

mujoco_playground/_src/dm_control_suite/acrobot.py

Lines changed: 14 additions & 3 deletions
@@ -36,6 +36,9 @@ def default_config() -> config_dict.ConfigDict:
       episode_length=1000,
       action_repeat=1,
       vision=False,
+      impl="jax",
+      nconmax=0,
+      njmax=0,
   )
 
 
@@ -57,11 +60,12 @@ def __init__(
     self._margin = 0.0 if sparse else 1.0
 
     self._xml_path = _XML_PATH.as_posix()
+    self._model_assets = common.get_assets()
     self._mj_model = mujoco.MjModel.from_xml_string(
-        _XML_PATH.read_text(), common.get_assets()
+        _XML_PATH.read_text(), self._model_assets
     )
     self._mj_model.opt.timestep = self.sim_dt
-    self._mjx_model = mjx.put_model(self._mj_model)
+    self._mjx_model = mjx.put_model(self._mj_model, impl=self._config.impl)
     self._post_init()
 
   def _post_init(self) -> None:
@@ -77,7 +81,14 @@ def reset(self, rng: jax.Array) -> mjx_env.State:
     qpos = jax.random.uniform(
         rng1, (self.mjx_model.nq,), minval=-jp.pi, maxval=jp.pi
     )
-    data = mjx_env.init(self.mjx_model, qpos=qpos)
+    data = mjx_env.make_data(
+        self.mj_model,
+        qpos=qpos,
+        impl=self.mjx_model.impl.value,
+        nconmax=self._config.nconmax,
+        njmax=self._config.njmax,
+    )
+    data = mjx.forward(self.mjx_model, data)
 
     metrics = {
         "distance": jp.zeros(()),

mujoco_playground/_src/dm_control_suite/ball_in_cup.py

Lines changed: 13 additions & 3 deletions
@@ -35,6 +35,9 @@ def default_config() -> config_dict.ConfigDict:
       episode_length=1000,
       action_repeat=1,
       vision=False,
+      impl="jax",
+      nconmax=10_000,
+      njmax=25,
   )
 
 
@@ -53,11 +56,12 @@ def __init__(
     )
 
     self._xml_path = _XML_PATH.as_posix()
+    self._model_assets = common.get_assets()
     self._mj_model = mujoco.MjModel.from_xml_string(
-        _XML_PATH.read_text(), common.get_assets()
+        _XML_PATH.read_text(), self._model_assets
     )
     self._mj_model.opt.timestep = self.sim_dt
-    self._mjx_model = mjx.put_model(self._mj_model)
+    self._mjx_model = mjx.put_model(self._mj_model, impl=self._config.impl)
     self._post_init()
 
   def _post_init(self) -> None:
@@ -68,7 +72,13 @@ def _post_init(self) -> None:
     self._ball_size = self._mj_model.geom_size[geom_id, 0]
 
   def reset(self, rng: jax.Array) -> mjx_env.State:
-    data = mjx_env.init(self.mjx_model)
+    data = mjx_env.make_data(
+        self.mj_model,
+        impl=self.mjx_model.impl.value,
+        nconmax=self._config.nconmax,
+        njmax=self._config.njmax,
+    )
+    data = mjx.forward(self.mjx_model, data)
 
     metrics = {}
     info = {"rng": rng}
