diff --git a/mujoco_playground/_src/dm_control_suite/__init__.py b/mujoco_playground/_src/dm_control_suite/__init__.py
index a45f2948f..6ddfe20ad 100644
--- a/mujoco_playground/_src/dm_control_suite/__init__.py
+++ b/mujoco_playground/_src/dm_control_suite/__init__.py
@@ -30,6 +30,7 @@
 from mujoco_playground._src.dm_control_suite import humanoid
 from mujoco_playground._src.dm_control_suite import pendulum
 from mujoco_playground._src.dm_control_suite import point_mass
+from mujoco_playground._src.dm_control_suite import quadruped
 from mujoco_playground._src.dm_control_suite import reacher
 from mujoco_playground._src.dm_control_suite import swimmer
 from mujoco_playground._src.dm_control_suite import walker
@@ -62,6 +63,8 @@
     "HumanoidRun": partial(humanoid.Humanoid, move_speed=humanoid.RUN_SPEED),
     "PendulumSwingup": pendulum.SwingUp,
     "PointMass": point_mass.PointMass,
+    "QuadrupedWalk": partial(quadruped.Quadruped, desired_speed=quadruped.WALK_SPEED),
+    "QuadrupedRun": partial(quadruped.Quadruped, desired_speed=quadruped.RUN_SPEED),
     "ReacherEasy": partial(reacher.Reacher, target_size=reacher.BIG_TARGET),
     "ReacherHard": partial(reacher.Reacher, target_size=reacher.SMALL_TARGET),
     "SwimmerSwimmer6": partial(swimmer.Swim, n_links=6),
@@ -91,6 +94,8 @@
     "HumanoidWalk": humanoid.default_config,
     "PendulumSwingup": pendulum.default_config,
     "PointMass": point_mass.default_config,
+    "QuadrupedWalk": quadruped.default_config,
+    "QuadrupedRun": quadruped.default_config,
     "ReacherEasy": reacher.default_config,
     "ReacherHard": reacher.default_config,
     "SwimmerSwimmer6": swimmer.default_config,
diff --git a/mujoco_playground/_src/dm_control_suite/quadruped.py b/mujoco_playground/_src/dm_control_suite/quadruped.py
new file mode 100644
index 000000000..b1e17ca63
--- /dev/null
+++ b/mujoco_playground/_src/dm_control_suite/quadruped.py
@@ -0,0 +1,210 @@
+# Copyright 2025 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Quadruped environment."""
+
+from itertools import product
+from typing import Any, Dict, Optional, Union
+
+import jax
+import jax.numpy as jp
+from ml_collections import config_dict
+import mujoco
+from mujoco import mjx
+
+from mujoco_playground._src import mjx_env
+from mujoco_playground._src import reward
+from mujoco_playground._src.dm_control_suite import common
+
+_XML_PATH = mjx_env.ROOT_PATH / "dm_control_suite" / "xmls" / "quadruped.xml"
+
+WALK_SPEED = 0.5
+RUN_SPEED = 5.0
+
+
+def default_config() -> config_dict.ConfigDict:
+  return config_dict.create(
+      ctrl_dt=0.02,
+      sim_dt=0.005,
+      episode_length=1000,
+      action_repeat=1,
+      sample_initial_position=False,
+      vision=False,
+  )
+
+
+def _find_non_contacting_height(
+    mjx_model, data, orientation, x_pos=0.0, y_pos=0.0
+):
+  """Raises the root in 1 cm steps until the body is contact-free."""
+
+  def body_fn(state):
+    z_pos, num_contacts, num_attempts, _ = state
+    qpos = data.qpos.at[:3].set(jp.array([x_pos, y_pos, z_pos]))
+    qpos = qpos.at[3:7].set(jp.array(orientation))
+    ndata = data.replace(qpos=qpos)
+    ndata = mjx.forward(mjx_model, ndata)
+    num_contacts = ndata.ncon
+    z_pos += 0.01
+    num_attempts += 1
+    return (z_pos, num_contacts, num_attempts, ndata)
+
+  initial_state = (0.0, 1, 0, data)  # z_pos, num_contacts, num_attempts, data
+  *_, num_attempts, ndata = jax.lax.while_loop(
+      lambda state: jp.greater(state[1], 0) & jp.less_equal(state[2], 10000),
+      body_fn,
+      initial_state,
+  )
+  # Fall back to the original data if no contact-free height was found.
+  ndata = jax.tree_util.tree_map(
+      lambda x, y: jp.where(jp.less(num_attempts, 10000), x, y), ndata, data
+  )
+  return ndata
+
+
+class Quadruped(mjx_env.MjxEnv):
+  """Quadruped environment."""
+
+  def __init__(
+      self,
+      desired_speed: float,
+      config: config_dict.ConfigDict = default_config(),
+      config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
+  ):
+    super().__init__(config, config_overrides)
+    if self._config.vision:
+      raise NotImplementedError("Vision not implemented for Quadruped.")
+    self._desired_speed = desired_speed
+    self._xml_path = _XML_PATH.as_posix()
+    self._mj_model = mujoco.MjModel.from_xml_string(
+        _XML_PATH.read_text(), common.get_assets()
+    )
+    self._mj_model.opt.timestep = self.sim_dt
+    self._mjx_model = mjx.put_model(self._mj_model)
+    self._post_init()
+
+  def _post_init(self):
+    self._force_torque_names = [
+        f"{f}_toe_{pos}_{side}"
+        for (f, pos, side) in product(
+            ("force", "torque"), ("front", "back"), ("left", "right")
+        )
+    ]
+    self._torso_id = self._mj_model.body("torso").id
+
+  def reset(self, rng: jax.Array) -> mjx_env.State:
+    data = mjx_env.init(self.mjx_model)
+    if self._config.sample_initial_position:
+      data = _find_non_contacting_height(
+          self.mjx_model,
+          data=mjx_env.make_data(self._mj_model),
+          orientation=(0.0, 0.0, 0.0, 1.0),
+          x_pos=0.0,
+          y_pos=0.0,
+      )
+    metrics = {"reward/upright": jp.zeros(()), "reward/move": jp.zeros(())}
+    info = {"rng": rng}
+    reward, done = jp.zeros(2)
+    obs = self._get_obs(data, info)
+    return mjx_env.State(data, obs, reward, done, metrics, info)
+
+  def step(self, state: mjx_env.State, action: jax.Array) -> mjx_env.State:
+    # Rescale the action from [-1, 1] to the actuator control range.
+    lower, upper = (
+        self._mj_model.actuator_ctrlrange[:, 0],
+        self._mj_model.actuator_ctrlrange[:, 1],
+    )
+    action = (action + 1.0) / 2.0 * (upper - lower) + lower
+    data = mjx_env.step(self.mjx_model, state.data, action, self.n_substeps)
+    reward = self._get_reward(data, action, state.info, state.metrics)
+    obs = self._get_obs(data, state.info)
+    done = jp.isnan(data.qpos).any() | jp.isnan(data.qvel).any()
+    done = done.astype(float)
+    return mjx_env.State(data, obs, reward, done, state.metrics, state.info)
+
+  def _get_obs(self, data: mjx.Data, info: dict[str, Any]) -> jax.Array:
+    del info
+    ego = self._egocentric_state(data)
+    torso_vel = self._torso_velocity(data)
+    upright = self._torso_upright(data)
+    imu = self._imu(data)
+    force_torque = self._force_torque(data)
+    return jp.hstack((ego, torso_vel, upright, imu, force_torque))
+
+  def _get_reward(
+      self,
+      data: mjx.Data,
+      action: jax.Array,
+      info: dict[str, Any],
+      metrics: dict[str, Any],
+  ) -> jax.Array:
+    del info, action
+    move_reward = reward.tolerance(
+        self._torso_velocity(data)[0],
+        bounds=(self._desired_speed, float("inf")),
+        sigmoid="linear",
+        margin=self._desired_speed,
+        value_at_margin=0.5,
+    )
+    upright_reward = self._upright_reward(data)
+    metrics["reward/move"] = move_reward
+    metrics["reward/upright"] = upright_reward
+    return move_reward * upright_reward
+
+  def _upright_reward(self, data: mjx.Data) -> jax.Array:
+    upright = self._torso_upright(data)
+    return reward.tolerance(
+        upright,
+        bounds=(1, float("inf")),
+        sigmoid="linear",
+        margin=2,
+        value_at_margin=0,
+    )
+
+  def _egocentric_state(self, data: mjx.Data) -> jax.Array:
+    """Returns the joint positions, joint velocities, and actuator state."""
+    return jp.hstack((data.qpos[7:], data.qvel[7:], data.act))
+
+  def _torso_upright(self, data: mjx.Data) -> jax.Array:
+    """Returns the projection of the torso z-axis onto the world z-axis."""
+    return data.xmat[self._torso_id, 2, 2]
+
+  def _torso_velocity(self, data: mjx.Data) -> jax.Array:
+    """Returns the torso velocity from the velocimeter sensor."""
+    return mjx_env.get_sensor_data(self.mj_model, data, "velocimeter")
+
+  def _imu(self, data: mjx.Data) -> jax.Array:
+    """Returns the concatenated gyroscope and accelerometer readings."""
+    gyro = mjx_env.get_sensor_data(self.mj_model, data, "imu_gyro")
+    accelerometer = mjx_env.get_sensor_data(self.mj_model, data, "imu_accel")
+    return jp.hstack((gyro, accelerometer))
+
+  def _force_torque(self, data: mjx.Data) -> jax.Array:
+    """Returns the concatenated toe force and torque sensor readings."""
+    return jp.hstack([
+        mjx_env.get_sensor_data(self.mj_model, data, name)
+        for name in self._force_torque_names
+    ])
+
+  @property
+  def xml_path(self) -> str:
+    return self._xml_path
+
+  @property
+  def action_size(self) -> int:
+    return self.mjx_model.nu
+
+  @property
+  def mj_model(self) -> mujoco.MjModel:
+    return self._mj_model
+
+  @property
+  def mjx_model(self) -> mjx.Model:
+    return self._mjx_model
diff --git a/mujoco_playground/_src/dm_control_suite/xmls/quadruped.xml b/mujoco_playground/_src/dm_control_suite/xmls/quadruped.xml
new file mode 100644
index 000000000..4f9eb4ba7
--- /dev/null
+++ b/mujoco_playground/_src/dm_control_suite/xmls/quadruped.xml
@@ -0,0 +1,231 @@
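Not part of the patch: a minimal usage sketch of the two registry entries added above. It assumes the top-level `registry` module exposed by mujoco_playground picks up entries from the suite's `_envs`/`_cfgs` dicts touched in this diff; verify the exact entry point against your checkout.

import jax
import jax.numpy as jp

from mujoco_playground import registry

# "QuadrupedWalk" and "QuadrupedRun" are the names registered by this patch.
env = registry.load("QuadrupedWalk")
state = jax.jit(env.reset)(jax.random.PRNGKey(0))

# Policies emit actions in [-1, 1]; Quadruped.step rescales them to the
# actuator ctrlrange before stepping the physics.
action = jp.zeros(env.action_size)
state = jax.jit(env.step)(state, action)
print(state.obs.shape, state.reward, state.done)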