Skip to content

Commit 83dd454

Browse files
committed
Formatting
1 parent 72f467e commit 83dd454

File tree

9 files changed

+316
-374
lines changed

9 files changed

+316
-374
lines changed

mujoco_playground/_src/dm_control_suite/__init__.py

Lines changed: 45 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,21 @@
4040
"AcrobotSwingupSparse": partial(acrobot.Balance, sparse=True),
4141
"BallInCup": ball_in_cup.BallInCup,
4242
"CartpoleBalance": partial(cartpole.Balance, swing_up=False, sparse=False),
43-
"CartpoleBalanceSparse": partial(cartpole.Balance, swing_up=False, sparse=True),
43+
"CartpoleBalanceSparse": partial(
44+
cartpole.Balance, swing_up=False, sparse=True
45+
),
4446
"CartpoleSwingup": partial(cartpole.Balance, swing_up=True, sparse=False),
45-
"CartpoleSwingupSparse": partial(cartpole.Balance, swing_up=True, sparse=True),
47+
"CartpoleSwingupSparse": partial(
48+
cartpole.Balance, swing_up=True, sparse=True
49+
),
4650
"CheetahRun": cheetah.Run,
4751
"FingerSpin": finger.Spin,
48-
"FingerTurnEasy": partial(finger.Turn, target_radius=finger.EASY_TARGET_SIZE),
49-
"FingerTurnHard": partial(finger.Turn, target_radius=finger.HARD_TARGET_SIZE),
52+
"FingerTurnEasy": partial(
53+
finger.Turn, target_radius=finger.EASY_TARGET_SIZE
54+
),
55+
"FingerTurnHard": partial(
56+
finger.Turn, target_radius=finger.HARD_TARGET_SIZE
57+
),
5058
"FishSwim": fish.Swim,
5159
"HopperHop": partial(hopper.Hopper, hopping=True),
5260
"HopperStand": partial(hopper.Hopper, hopping=False),
@@ -99,54 +107,54 @@
99107

100108

101109
def __getattr__(name):
102-
if name == "ALL_ENVS":
103-
return tuple(_envs.keys())
104-
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
110+
if name == "ALL_ENVS":
111+
return tuple(_envs.keys())
112+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
105113

106114

107115
def register_environment(
108116
env_name: str,
109117
env_class: Type[mjx_env.MjxEnv],
110118
cfg_class: Callable[[], config_dict.ConfigDict],
111119
) -> None:
112-
"""Register a new environment.
120+
"""Register a new environment.
113121
114-
Args:
115-
env_name: The name of the environment.
116-
env_class: The environment class.
117-
cfg_class: The default configuration
118-
"""
119-
_envs[env_name] = env_class
120-
_cfgs[env_name] = cfg_class
122+
Args:
123+
env_name: The name of the environment.
124+
env_class: The environment class.
125+
cfg_class: The default configuration
126+
"""
127+
_envs[env_name] = env_class
128+
_cfgs[env_name] = cfg_class
121129

122130

123131
def get_default_config(env_name: str) -> config_dict.ConfigDict:
124-
"""Get the default configuration for an environment."""
125-
if env_name not in _cfgs:
126-
raise ValueError(
127-
f"Env '{env_name}' not found in default configs. Available configs:"
128-
f" {list(_cfgs.keys())}"
129-
)
130-
return _cfgs[env_name]()
132+
"""Get the default configuration for an environment."""
133+
if env_name not in _cfgs:
134+
raise ValueError(
135+
f"Env '{env_name}' not found in default configs. Available configs:"
136+
f" {list(_cfgs.keys())}"
137+
)
138+
return _cfgs[env_name]()
131139

132140

133141
def load(
134142
env_name: str,
135143
config: Optional[config_dict.ConfigDict] = None,
136144
config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
137145
) -> mjx_env.MjxEnv:
138-
"""Get an environment instance with the given configuration.
139-
140-
Args:
141-
env_name: The name of the environment.
142-
config: The configuration to use. If not provided, the default
143-
configuration is used.
144-
config_overrides: A dictionary of overrides for the configuration.
145-
146-
Returns:
147-
An instance of the environment.
148-
"""
149-
if env_name not in _envs:
150-
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
151-
config = config or get_default_config(env_name)
152-
return _envs[env_name](config=config, config_overrides=config_overrides)
146+
"""Get an environment instance with the given configuration.
147+
148+
Args:
149+
env_name: The name of the environment.
150+
config: The configuration to use. If not provided, the default
151+
configuration is used.
152+
config_overrides: A dictionary of overrides for the configuration.
153+
154+
Returns:
155+
An instance of the environment.
156+
"""
157+
if env_name not in _envs:
158+
raise ValueError(f"Env '{env_name}' not found. Available envs: {_cfgs.keys()}")
159+
config = config or get_default_config(env_name)
160+
return _envs[env_name](config=config, config_overrides=config_overrides)

mujoco_playground/_src/locomotion/apollo/base.py

Lines changed: 30 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,17 @@
1616

1717
from typing import Any, Dict, Optional, Union
1818

19-
from etils import epath
2019
import jax
2120
import jax.numpy as jp
22-
from ml_collections import config_dict
2321
import mujoco
24-
from mujoco import mjx
2522
import numpy as np
23+
from etils import epath
24+
from ml_collections import config_dict
25+
from mujoco import mjx
2626

2727
from mujoco_playground._src import mjx_env
28-
from mujoco_playground._src.collision import geoms_colliding
2928
from mujoco_playground._src.locomotion.apollo import constants as consts
29+
from mujoco_playground._src.collision import geoms_colliding
3030

3131

3232
def get_assets() -> Dict[str, bytes]:
@@ -46,15 +46,15 @@ class ApolloEnv(mjx_env.MjxEnv):
4646
"""Base class for Apollo environments."""
4747

4848
def __init__(
49-
self,
50-
xml_path: str,
51-
config: config_dict.ConfigDict,
52-
config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
49+
self,
50+
xml_path: str,
51+
config: config_dict.ConfigDict,
52+
config_overrides: Optional[Dict[str, Union[str, int, list[Any]]]] = None,
5353
) -> None:
5454
super().__init__(config, config_overrides)
5555

5656
self._mj_model = mujoco.MjModel.from_xml_string(
57-
epath.Path(xml_path).read_text(), assets=get_assets()
57+
epath.Path(xml_path).read_text(), assets=get_assets()
5858
)
5959
self._mj_model.opt.timestep = self.sim_dt
6060

@@ -66,9 +66,7 @@ def __init__(
6666

6767
self._init_q = jp.array(self._mj_model.keyframe("knees_bent").qpos)
6868
self._default_ctrl = jp.array(self._mj_model.keyframe("knees_bent").ctrl)
69-
self._default_pose = jp.array(
70-
self._mj_model.keyframe("knees_bent").qpos[7:]
71-
)
69+
self._default_pose = jp.array(self._mj_model.keyframe("knees_bent").qpos[7:])
7270
self._actuator_torques = self.mj_model.jnt_actfrcrange[1:, 1]
7371

7472
# Body IDs.
@@ -77,64 +75,52 @@ def __init__(
7775
# Geom IDs.
7876
self._floor_geom_id = self._mj_model.geom("floor").id
7977
self._left_feet_geom_id = np.array(
80-
[self._mj_model.geom(name).id for name in consts.LEFT_FEET_GEOMS]
78+
[self._mj_model.geom(name).id for name in consts.LEFT_FEET_GEOMS]
8179
)
8280
self._right_feet_geom_id = np.array(
83-
[self._mj_model.geom(name).id for name in consts.RIGHT_FEET_GEOMS]
81+
[self._mj_model.geom(name).id for name in consts.RIGHT_FEET_GEOMS]
8482
)
8583
self._left_hand_geom_id = self._mj_model.geom("collision_l_hand_plate").id
8684
self._right_hand_geom_id = self._mj_model.geom("collision_r_hand_plate").id
8785
self._left_foot_geom_id = self._mj_model.geom("collision_l_sole").id
8886
self._right_foot_geom_id = self._mj_model.geom("collision_r_sole").id
89-
self._left_shin_geom_id = self._mj_model.geom(
90-
"collision_capsule_body_l_shin"
91-
).id
92-
self._right_shin_geom_id = self._mj_model.geom(
93-
"collision_capsule_body_r_shin"
94-
).id
95-
self._left_thigh_geom_id = self._mj_model.geom(
96-
"collision_capsule_body_l_thigh"
97-
).id
98-
self._right_thigh_geom_id = self._mj_model.geom(
99-
"collision_capsule_body_r_thigh"
100-
).id
87+
self._left_shin_geom_id = self._mj_model.geom("collision_capsule_body_l_shin").id
88+
self._right_shin_geom_id = self._mj_model.geom("collision_capsule_body_r_shin").id
89+
self._left_thigh_geom_id = self._mj_model.geom("collision_capsule_body_l_thigh").id
90+
self._right_thigh_geom_id = self._mj_model.geom("collision_capsule_body_r_thigh").id
10191

10292
# Site IDs.
10393
self._imu_site_id = self._mj_model.site("imu").id
10494
self._feet_site_id = np.array(
105-
[self._mj_model.site(name).id for name in consts.FEET_SITES]
95+
[self._mj_model.site(name).id for name in consts.FEET_SITES]
10696
)
10797

10898
# Sensor readings.
10999

110100
def get_gravity(self, data: mjx.Data) -> jax.Array:
111101
"""Return the gravity vector in the world frame."""
112-
return mjx_env.get_sensor_data(
113-
self.mj_model, data, f"{consts.GRAVITY_SENSOR}"
114-
)
102+
return mjx_env.get_sensor_data(self.mj_model, data, f"{consts.GRAVITY_SENSOR}")
115103

116104
def get_global_linvel(self, data: mjx.Data) -> jax.Array:
117105
"""Return the linear velocity of the robot in the world frame."""
118106
return mjx_env.get_sensor_data(
119-
self.mj_model, data, f"{consts.GLOBAL_LINVEL_SENSOR}"
107+
self.mj_model, data, f"{consts.GLOBAL_LINVEL_SENSOR}"
120108
)
121109

122110
def get_global_angvel(self, data: mjx.Data) -> jax.Array:
123111
"""Return the angular velocity of the robot in the world frame."""
124112
return mjx_env.get_sensor_data(
125-
self.mj_model, data, f"{consts.GLOBAL_ANGVEL_SENSOR}"
113+
self.mj_model, data, f"{consts.GLOBAL_ANGVEL_SENSOR}"
126114
)
127115

128116
def get_local_linvel(self, data: mjx.Data) -> jax.Array:
129117
"""Return the linear velocity of the robot in the local frame."""
130-
return mjx_env.get_sensor_data(
131-
self.mj_model, data, f"{consts.LOCAL_LINVEL_SENSOR}"
132-
)
118+
return mjx_env.get_sensor_data(self.mj_model, data, f"{consts.LOCAL_LINVEL_SENSOR}")
133119

134120
def get_accelerometer(self, data: mjx.Data) -> jax.Array:
135121
"""Return the accelerometer readings in the local frame."""
136122
return mjx_env.get_sensor_data(
137-
self.mj_model, data, f"{consts.ACCELEROMETER_SENSOR}"
123+
self.mj_model, data, f"{consts.ACCELEROMETER_SENSOR}"
138124
)
139125

140126
def get_gyro(self, data: mjx.Data) -> jax.Array:
@@ -143,14 +129,18 @@ def get_gyro(self, data: mjx.Data) -> jax.Array:
143129

144130
def get_feet_ground_contacts(self, data: mjx.Data) -> jax.Array:
145131
"""Return an array indicating whether each foot is in contact with the ground."""
146-
left_feet_contact = jp.array([
132+
left_feet_contact = jp.array(
133+
[
147134
geoms_colliding(data, geom_id, self._floor_geom_id)
148135
for geom_id in self._left_feet_geom_id
149-
])
150-
right_feet_contact = jp.array([
136+
]
137+
)
138+
right_feet_contact = jp.array(
139+
[
151140
geoms_colliding(data, geom_id, self._floor_geom_id)
152141
for geom_id in self._right_feet_geom_id
153-
])
142+
]
143+
)
154144
return jp.hstack([jp.any(left_feet_contact), jp.any(right_feet_contact)])
155145

156146
# Accessors.

mujoco_playground/_src/locomotion/apollo/constants.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,18 @@
2525

2626
def task_to_xml(task_name: str) -> epath.Path:
2727
return {
28-
"flat_terrain": FEET_ONLY_FLAT_TERRAIN_XML,
28+
"flat_terrain": FEET_ONLY_FLAT_TERRAIN_XML,
2929
}[task_name]
3030

3131

3232
FEET_SITES = [
33-
"l_foot",
34-
"r_foot",
33+
"l_foot",
34+
"r_foot",
3535
]
3636

3737
HAND_SITES = [
38-
"left_palm",
39-
"right_palm",
38+
"left_palm",
39+
"right_palm",
4040
]
4141

4242
LEFT_FEET_GEOMS = ["collision_l_sole"]

0 commit comments

Comments
 (0)