28 changes: 14 additions & 14 deletions learning/train_jax_ppo.py
@@ -24,17 +24,16 @@
from absl import app
from absl import flags
from absl import logging
from brax.io import model
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import networks_vision as ppo_networks_vision
from brax.training.agents.ppo import train as ppo
from etils import epath
from flax.training import orbax_utils
import jax
import jax.numpy as jp
import mediapy as media
from ml_collections import config_dict
import mujoco
from orbax import checkpoint as ocp
from tensorboardX import SummaryWriter
import wandb

@@ -73,10 +72,16 @@
_LOAD_CHECKPOINT_PATH = flags.DEFINE_string(
"load_checkpoint_path", None, "Path to load checkpoint from"
)
_SAVE_PARAMS_PATH = flags.DEFINE_string(
"save_params_path", None, "Path to save parameters to"
)
_SUFFIX = flags.DEFINE_string("suffix", None, "Suffix for the experiment name")
_PLAY_ONLY = flags.DEFINE_boolean(
"play_only", False, "If true, only play with the model and do not train"
)
_RENDER_FINAL_POLICY = flags.DEFINE_boolean(
"render_final_policy", True, "If true, render the final policy"
)
_USE_WANDB = flags.DEFINE_boolean(
"use_wandb",
False,
@@ -264,17 +269,6 @@ def main(argv):
ckpt_path.mkdir(parents=True, exist_ok=True)
print(f"Checkpoint path: {ckpt_path}")

# Save environment configuration
with open(ckpt_path / "config.json", "w", encoding="utf-8") as fp:
json.dump(env_cfg.to_dict(), fp, indent=4)

# Define policy parameters function for saving checkpoints
def policy_params_fn(current_step, make_policy, params): # pylint: disable=unused-argument
orbax_checkpointer = ocp.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(params)
path = ckpt_path / f"{current_step}"
orbax_checkpointer.save(path, params, force=True, save_args=save_args)

training_params = dict(ppo_params)
if "network_factory" in training_params:
del training_params["network_factory"]
@@ -319,8 +313,8 @@ def policy_params_fn(current_step, make_policy, params):  # pylint: disable=unused-argument
ppo.train,
**training_params,
network_factory=network_factory,
policy_params_fn=policy_params_fn,
seed=_SEED.value,
save_checkpoint_path=ckpt_path,
restore_checkpoint_path=restore_checkpoint_path,
wrap_env_fn=None if _VISION.value else wrapper.wrap_for_brax_training,
num_eval_envs=num_eval_envs,
@@ -361,6 +355,12 @@ def progress(num_steps, metrics):
print(f"Time to JIT compile: {times[1] - times[0]}")
print(f"Time to train: {times[-1] - times[1]}")

if _SAVE_PARAMS_PATH.value is not None:
model.save_params(epath.Path(_SAVE_PARAMS_PATH.value).resolve(), params)
Collaborator comment: @Andrew-Luo1 would really like to not use the pkl stuff, is this absolutely necessary?

if not _RENDER_FINAL_POLICY.value:
return

print("Starting inference...")

# Create inference function
1 change: 1 addition & 0 deletions mujoco_playground/__init__.py
@@ -25,6 +25,7 @@
from mujoco_playground._src.mjx_env import render_array
from mujoco_playground._src.mjx_env import State
from mujoco_playground._src.mjx_env import step

# pylint: enable=g-importing-member

__all__ = [
4 changes: 3 additions & 1 deletion mujoco_playground/_src/dm_control_suite/__init__.py
@@ -150,6 +150,8 @@ def load(
An instance of the environment.
"""
if env_name not in _envs:
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
raise ValueError(
f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}"
)
config = config or get_default_config(env_name)
return _envs[env_name](config=config, config_overrides=config_overrides)
4 changes: 3 additions & 1 deletion mujoco_playground/_src/locomotion/__init__.py
@@ -174,7 +174,9 @@ def load(
An instance of the environment.
"""
if env_name not in _envs:
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
raise ValueError(
f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}"
)
config = config or get_default_config(env_name)
return _envs[env_name](config=config, config_overrides=config_overrides)

3 changes: 1 addition & 2 deletions mujoco_playground/_src/locomotion/t1/randomize.py
@@ -18,7 +18,6 @@
from mujoco import mjx
import numpy as np


FLOOR_GEOM_ID = 0
TORSO_BODY_ID = 1
ANKLE_JOINT_IDS = np.array([[21, 22, 27, 28]])
@@ -30,7 +29,7 @@ def rand_dynamics(rng):
# Floor friction: =U(0.4, 1.0).
rng, key = jax.random.split(rng)
geom_friction = model.geom_friction.at[FLOOR_GEOM_ID, 0].set(
jax.random.uniform(key, minval=0.2, maxval=.6)
jax.random.uniform(key, minval=0.2, maxval=0.6)
)

rng, key = jax.random.split(rng)
18 changes: 14 additions & 4 deletions mujoco_playground/_src/manipulation/__init__.py
@@ -20,8 +20,10 @@
from mujoco import mjx

from mujoco_playground._src import mjx_env
from mujoco_playground._src.manipulation.aloha import distillation as aloha_distillation
from mujoco_playground._src.manipulation.aloha import handover as aloha_handover
from mujoco_playground._src.manipulation.aloha import single_peg_insertion as aloha_peg
from mujoco_playground._src.manipulation.aloha import peg_insertion as aloha_peg_insertion
from mujoco_playground._src.manipulation.aloha import pick as aloha_pick
from mujoco_playground._src.manipulation.franka_emika_panda import open_cabinet as panda_open_cabinet
from mujoco_playground._src.manipulation.franka_emika_panda import pick as panda_pick
from mujoco_playground._src.manipulation.franka_emika_panda import pick_cartesian as panda_pick_cartesian
@@ -31,7 +33,9 @@

_envs = {
"AlohaHandOver": aloha_handover.HandOver,
"AlohaSinglePegInsertion": aloha_peg.SinglePegInsertion,
"AlohaPick": aloha_pick.Pick,
"AlohaPegInsertion": aloha_peg_insertion.SinglePegInsertion,
"AlohaPegInsertionDistill": aloha_distillation.DistillPegInsertion,
"PandaPickCube": panda_pick.PandaPickCube,
"PandaPickCubeOrientation": panda_pick.PandaPickCubeOrientation,
"PandaPickCubeCartesian": panda_pick_cartesian.PandaPickCubeCartesian,
@@ -43,7 +47,9 @@

_cfgs = {
"AlohaHandOver": aloha_handover.default_config,
"AlohaSinglePegInsertion": aloha_peg.default_config,
"AlohaPick": aloha_pick.default_config,
"AlohaPegInsertion": aloha_peg_insertion.default_config,
"AlohaPegInsertionDistill": aloha_distillation.default_config,
"PandaPickCube": panda_pick.default_config,
"PandaPickCubeOrientation": panda_pick.default_config,
"PandaPickCubeCartesian": panda_pick_cartesian.default_config,
@@ -56,6 +62,8 @@
_randomizer = {
"LeapCubeRotateZAxis": leap_rotate_z.domain_randomize,
"LeapCubeReorient": leap_cube_reorient.domain_randomize,
"AlohaPick": aloha_pick.domain_randomize,
"AlohaPegInsertionDistill": aloha_distillation.domain_randomize,
}


@@ -108,7 +116,9 @@ def load(
An instance of the environment.
"""
if env_name not in _envs:
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
raise ValueError(
f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}"
)
config = config or get_default_config(env_name)
return _envs[env_name](config=config, config_overrides=config_overrides)

84 changes: 84 additions & 0 deletions mujoco_playground/_src/manipulation/aloha/README.md
@@ -0,0 +1,84 @@
### Quickstart


**Pre-requisites**

- *Handover, Pick, Peg Insertion:* The standard Playground setup
- *Behaviour Cloning for Peg Insertion:* Madrona MJX
- *JAX-to-ONNX Conversion:* ONNX, TensorFlow, tf2onnx

```bash
# Train Aloha Handover. Documentation at https://github.com/google-deepmind/mujoco_playground/pull/29
python learning/train_jax_ppo.py --env_name AlohaHandOver
```

```bash
# Plots for pick and peg-insertion at https://github.com/google-deepmind/mujoco_playground/pull/76
cd <PATH_TO_YOUR_CLONE>
export PARAMS_PATH=mujoco_playground/_src/manipulation/aloha/params

# Train a single arm to pick up a cube.
python learning/train_jax_ppo.py --env_name AlohaPick --domain_randomization --norender_final_policy --save_params_path $PARAMS_PATH/AlohaPick.prms
sleep 0.5

# Train a bi-arm policy to insert a peg into a socket. Requires the pick policy trained above.
python learning/train_jax_ppo.py --env_name AlohaPegInsertion --save_params_path $PARAMS_PATH/AlohaPegInsertion.prms
sleep 0.5

# Train a student policy to insert a peg into a socket using *pixel inputs*. Requires the peg-insertion policy trained above.
python mujoco_playground/experimental/bc_peg_insertion.py --domain-randomization --num-evals 0 --print-loss

# Convert checkpoints from the above run to ONNX for easy robot deployment.
# ONNX policies are written to `experimental/jax2onnx/onnx_policies`.
python mujoco_playground/experimental/jax2onnx/aloha_nets_to_onnx.py --checkpoint_path <YOUR_DISTILL_CHECKPOINT_DIR>
```
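
The `--save_params_path` flag writes Brax-format parameters via `brax.io.model.save_params`. Below is a minimal sketch of loading them back for evaluation; the network kwargs must match those used during training (the values shown are illustrative assumptions, not the exact training configuration), and `registry.load` is the standard Playground entry point.

```python
# Sketch only: reload parameters saved via --save_params_path and run one
# control step. Adjust the network settings (hidden layer sizes, observation
# normalization, ...) to match your training run.
import jax
from brax.io import model
from brax.training.acme import running_statistics
from brax.training.agents.ppo import networks as ppo_networks
from mujoco_playground import registry

env = registry.load("AlohaPick")
params = model.load_params(
    "mujoco_playground/_src/manipulation/aloha/params/AlohaPick.prms"
)

ppo_network = ppo_networks.make_ppo_networks(
    env.observation_size,
    env.action_size,
    preprocess_observations_fn=running_statistics.normalize,
)
inference_fn = jax.jit(
    ppo_networks.make_inference_fn(ppo_network)(params, deterministic=True)
)

rng = jax.random.PRNGKey(0)
state = jax.jit(env.reset)(rng)
action, _ = inference_fn(state.obs, rng)
state = jax.jit(env.step)(state, action)
```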

### Sim-to-Real Transfer of a Bi-Arm RL Policy via Pixel-Based Behaviour Cloning

https://github.com/user-attachments/assets/205fe8b9-1773-4715-8025-5de13490d0da

---

**Distillation**

In this module, we demonstrate policy distillation: a straightforward method for deploying a simulation-trained reinforcement learning policy that initially uses privileged state observations (such as object positions). The process involves two steps:

1. **Teacher Policy Training:** A state-based teacher policy is trained using RL.
2. **Student Policy Distillation:** The teacher is then distilled into a student policy via behaviour cloning (BC), where the student learns to map its observations $o_s(x)$ (e.g., exteroceptive RGBD images) to the teacher’s deterministic actions $\pi_t(o_t(x))$. For example, while both policies observe joint angles, the student uses RGBD images, whereas the teacher directly accesses (noisy) object positions. A minimal sketch of this step is shown below.
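
The sketch below illustrates the BC update in JAX/optax. The `teacher_apply` and `student_apply` functions are hypothetical placeholders standing in for the frozen state-based teacher and the pixel-based student; the actual pipeline lives in `mujoco_playground/experimental/bc_peg_insertion.py`.

```python
# Behaviour-cloning step (illustrative sketch, not the actual training code).
import jax
import jax.numpy as jp
import optax


def teacher_apply(params, state_obs):
  # Placeholder for the frozen, state-based teacher policy.
  return jp.tanh(state_obs @ params["w"] + params["b"])


def student_apply(params, rgbd_obs):
  # Placeholder for the pixel-based student; the real one uses a CNN encoder.
  flat = rgbd_obs.reshape((rgbd_obs.shape[0], -1))
  return jp.tanh(flat @ params["w"] + params["b"])


def bc_loss(student_params, teacher_params, state_obs, rgbd_obs):
  # L2 loss between student actions and (frozen) teacher actions.
  teacher_actions = jax.lax.stop_gradient(
      teacher_apply(teacher_params, state_obs)
  )
  student_actions = student_apply(student_params, rgbd_obs)
  return jp.mean(jp.sum((student_actions - teacher_actions) ** 2, axis=-1))


optimizer = optax.adam(3e-4)


@jax.jit
def bc_step(student_params, opt_state, teacher_params, state_obs, rgbd_obs):
  loss, grads = jax.value_and_grad(bc_loss)(
      student_params, teacher_params, state_obs, rgbd_obs
  )
  updates, opt_state = optimizer.update(grads, opt_state, student_params)
  return optax.apply_updates(student_params, updates), opt_state, loss
```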

The distillation process—where the student uses left and right wrist-mounted RGBD cameras for exteroception—takes about **3 minutes** on an RTX4090. This rapid turnaround is due to three factors:

1. [Very fast rendering](https://github.com/google-deepmind/mujoco_playground/blob/main/mujoco_playground/experimental/madrona_benchmarking/figures/cartpole_benchmark_full.png) provided by Madrona MJX.
2. The sample efficiency of behaviour cloning.
3. The use of low-resolution (32×32) rendering, which is sufficient for precise alignment given the wrist camera placement.

For further details on the teacher policy and RGBD sim-to-real techniques, please refer to the [technical report](https://docs.google.com/presentation/d/1v50Vg-SJdy5HV5JmPHALSwph9mcVI2RSPRdrxYR3Bkg/edit?usp=sharing).

---

**A Note on Sample Efficiency**

Behaviour cloning (BC) can be orders of magnitude more sample-efficient than reinforcement learning. In our approach, we use an L2 loss defined as:

$|| \pi_s(o_s(x)) - \pi_t(o_t(x)) ||$

In contrast, the policy gradient in RL generally takes the form:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \left[\sum_t \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, R(\tau)\right]$$

Two key observations highlight why BC’s direct supervision is more efficient:

- **Explicit Loss Signal:** The BC loss compares against the teacher action, giving explicit feedback on how the action should be adjusted. In contrast, the policy gradient only provides directional guidance, instructing the optimizer to increase or decrease an action’s likelihood based solely on its downstream rewards.
- **Per-Dimension Supervision:** While the policy gradient applies a uniform weighting across all action dimensions, BC supplies per-dimension information, making it easier to scale to high-dimensional action spaces; the gradient comparison below makes this concrete.
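
Concretely, taking the squared form of the BC loss above, the gradient with respect to the student parameters $\theta$ is

$$\nabla_\theta\, \tfrac{1}{2}\big\| \pi_s(o_s(x)) - \pi_t(o_t(x)) \big\|^2 = J_{\pi_s}(o_s(x))^\top \big( \pi_s(o_s(x)) - \pi_t(o_t(x)) \big),$$

where $J_{\pi_s}$ is the Jacobian of the student's action with respect to $\theta$. Every action dimension contributes its own signed error term, whereas in the policy-gradient expression all dimensions of $\nabla_\theta \log \pi_\theta(a_t \mid s_t)$ are weighted by the same scalar return $R(\tau)$.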

---

**Frozen Encoders**

*VisionMLP2ChanCIFAR10_OCP* is an Orbax checkpoint of [NatureCNN](https://github.com/google/brax/blob/241f9bc5bbd003f9cfc9ded7613388e2fe125af6/brax/training/networks.py#L153) (AtariCNN) pre-trained on CIFAR10 to achieve over 70% classification accuracy. We omit the supervised training code; see [this tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial5/Inception_ResNet_DenseNet.html) for reference.
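
A quick way to inspect the checkpoint is to restore it as a PyTree with Orbax. A minimal sketch, assuming the checkpoint sits under the Aloha `params` directory (adjust the path to wherever it lives in your clone):

```python
# Sketch: restore the CIFAR10-pretrained encoder weights and inspect their
# shapes before wiring them into the student policy. The checkpoint path is
# an assumption for illustration; Orbax expects an absolute path.
import os

import jax
from orbax import checkpoint as ocp

encoder_params = ocp.PyTreeCheckpointer().restore(
    os.path.abspath(
        "mujoco_playground/_src/manipulation/aloha/params/VisionMLP2ChanCIFAR10_OCP"
    )
)
print(jax.tree_util.tree_map(lambda x: x.shape, encoder_params))
```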

---

**Aloha Deployment Setup**

For deployment, the ONNX policy is executed on the Aloha robot using a custom fork of [OpenPI](https://github.com/Physical-Intelligence/openpi) along with the Interbotix Aloha ROS packages. Acknowledgements to Kevin Zakka, Laura Smith and the Levine Lab for robot deployment setup!
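
Outside of ROS, an exported ONNX policy can be exercised directly with `onnxruntime`. The sketch below uses an assumed file name, and the zero-filled inputs merely stand in for the wrist-camera RGBD images and proprioception that the robot would provide.

```python
# Sketch: run an exported ONNX policy with onnxruntime. The policy file name
# is an assumption; use whatever aloha_nets_to_onnx.py wrote to
# experimental/jax2onnx/onnx_policies.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(
    "mujoco_playground/experimental/jax2onnx/onnx_policies/AlohaPegInsertionDistill.onnx"
)

# Zero-filled dummy inputs; on the robot these come from the cameras and joints.
feeds = {}
for inp in session.get_inputs():
  shape = [1 if isinstance(d, str) or d is None else d for d in inp.shape]
  feeds[inp.name] = np.zeros(shape, dtype=np.float32)

actions = session.run(None, feeds)
print([a.shape for a in actions])
```
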
22 changes: 22 additions & 0 deletions mujoco_playground/_src/manipulation/aloha/aloha_constants.py
@@ -50,3 +50,25 @@
"right/left_finger",
"right/right_finger",
]

LEFT_JOINTS = [
"left/waist",
"left/shoulder",
"left/elbow",
"left/forearm_roll",
"left/wrist_angle",
"left/wrist_rotate",
"left/left_finger",
"left/right_finger",
]

RIGHT_JOINTS = [
"right/waist",
"right/shoulder",
"right/elbow",
"right/forearm_roll",
"right/wrist_angle",
"right/wrist_rotate",
"right/left_finger",
"right/right_finger",
]
3 changes: 3 additions & 0 deletions mujoco_playground/_src/manipulation/aloha/base.py
@@ -37,6 +37,9 @@ def get_assets() -> Dict[str, bytes]:
path = mjx_env.ROOT_PATH / "manipulation" / "aloha" / "xmls"
mjx_env.update_assets(assets, path, "*.xml")
mjx_env.update_assets(assets, path / "assets")
path = mjx_env.ROOT_PATH / "manipulation" / "aloha" / "xmls" / "s2r"
Collaborator comment: no longer needed FWIU

mjx_env.update_assets(assets, path, "*.xml")
mjx_env.update_assets(assets, path / "assets")
return assets

