1616
1717from typing import Any , Dict , Optional , Union
1818
19- from etils import epath
2019import jax
2120import jax .numpy as jp
22- from ml_collections import config_dict
2321import mujoco
24- from mujoco import mjx
2522import numpy as np
23+ from etils import epath
24+ from ml_collections import config_dict
25+ from mujoco import mjx
2626
2727from mujoco_playground ._src import mjx_env
28- from mujoco_playground ._src .collision import geoms_colliding
2928from mujoco_playground ._src .locomotion .apollo import constants as consts
29+ from mujoco_playground ._src .collision import geoms_colliding
3030
3131
3232def get_assets () -> Dict [str , bytes ]:
@@ -46,15 +46,15 @@ class ApolloEnv(mjx_env.MjxEnv):
4646 """Base class for Apollo environments."""
4747
4848 def __init__ (
49- self ,
50- xml_path : str ,
51- config : config_dict .ConfigDict ,
52- config_overrides : Optional [Dict [str , Union [str , int , list [Any ]]]] = None ,
49+ self ,
50+ xml_path : str ,
51+ config : config_dict .ConfigDict ,
52+ config_overrides : Optional [Dict [str , Union [str , int , list [Any ]]]] = None ,
5353 ) -> None :
5454 super ().__init__ (config , config_overrides )
5555
5656 self ._mj_model = mujoco .MjModel .from_xml_string (
57- epath .Path (xml_path ).read_text (), assets = get_assets ()
57+ epath .Path (xml_path ).read_text (), assets = get_assets ()
5858 )
5959 self ._mj_model .opt .timestep = self .sim_dt
6060
@@ -66,9 +66,7 @@ def __init__(
6666
6767 self ._init_q = jp .array (self ._mj_model .keyframe ("knees_bent" ).qpos )
6868 self ._default_ctrl = jp .array (self ._mj_model .keyframe ("knees_bent" ).ctrl )
69- self ._default_pose = jp .array (
70- self ._mj_model .keyframe ("knees_bent" ).qpos [7 :]
71- )
69+ self ._default_pose = jp .array (self ._mj_model .keyframe ("knees_bent" ).qpos [7 :])
7270 self ._actuator_torques = self .mj_model .jnt_actfrcrange [1 :, 1 ]
7371
7472 # Body IDs.
@@ -77,64 +75,52 @@ def __init__(
7775 # Geom IDs.
7876 self ._floor_geom_id = self ._mj_model .geom ("floor" ).id
7977 self ._left_feet_geom_id = np .array (
80- [self ._mj_model .geom (name ).id for name in consts .LEFT_FEET_GEOMS ]
78+ [self ._mj_model .geom (name ).id for name in consts .LEFT_FEET_GEOMS ]
8179 )
8280 self ._right_feet_geom_id = np .array (
83- [self ._mj_model .geom (name ).id for name in consts .RIGHT_FEET_GEOMS ]
81+ [self ._mj_model .geom (name ).id for name in consts .RIGHT_FEET_GEOMS ]
8482 )
8583 self ._left_hand_geom_id = self ._mj_model .geom ("collision_l_hand_plate" ).id
8684 self ._right_hand_geom_id = self ._mj_model .geom ("collision_r_hand_plate" ).id
8785 self ._left_foot_geom_id = self ._mj_model .geom ("collision_l_sole" ).id
8886 self ._right_foot_geom_id = self ._mj_model .geom ("collision_r_sole" ).id
89- self ._left_shin_geom_id = self ._mj_model .geom (
90- "collision_capsule_body_l_shin"
91- ).id
92- self ._right_shin_geom_id = self ._mj_model .geom (
93- "collision_capsule_body_r_shin"
94- ).id
95- self ._left_thigh_geom_id = self ._mj_model .geom (
96- "collision_capsule_body_l_thigh"
97- ).id
98- self ._right_thigh_geom_id = self ._mj_model .geom (
99- "collision_capsule_body_r_thigh"
100- ).id
87+ self ._left_shin_geom_id = self ._mj_model .geom ("collision_capsule_body_l_shin" ).id
88+ self ._right_shin_geom_id = self ._mj_model .geom ("collision_capsule_body_r_shin" ).id
89+ self ._left_thigh_geom_id = self ._mj_model .geom ("collision_capsule_body_l_thigh" ).id
90+ self ._right_thigh_geom_id = self ._mj_model .geom ("collision_capsule_body_r_thigh" ).id
10191
10292 # Site IDs.
10393 self ._imu_site_id = self ._mj_model .site ("imu" ).id
10494 self ._feet_site_id = np .array (
105- [self ._mj_model .site (name ).id for name in consts .FEET_SITES ]
95+ [self ._mj_model .site (name ).id for name in consts .FEET_SITES ]
10696 )
10797
10898 # Sensor readings.
10999
110100 def get_gravity (self , data : mjx .Data ) -> jax .Array :
111101 """Return the gravity vector in the world frame."""
112- return mjx_env .get_sensor_data (
113- self .mj_model , data , f"{ consts .GRAVITY_SENSOR } "
114- )
102+ return mjx_env .get_sensor_data (self .mj_model , data , f"{ consts .GRAVITY_SENSOR } " )
115103
116104 def get_global_linvel (self , data : mjx .Data ) -> jax .Array :
117105 """Return the linear velocity of the robot in the world frame."""
118106 return mjx_env .get_sensor_data (
119- self .mj_model , data , f"{ consts .GLOBAL_LINVEL_SENSOR } "
107+ self .mj_model , data , f"{ consts .GLOBAL_LINVEL_SENSOR } "
120108 )
121109
122110 def get_global_angvel (self , data : mjx .Data ) -> jax .Array :
123111 """Return the angular velocity of the robot in the world frame."""
124112 return mjx_env .get_sensor_data (
125- self .mj_model , data , f"{ consts .GLOBAL_ANGVEL_SENSOR } "
113+ self .mj_model , data , f"{ consts .GLOBAL_ANGVEL_SENSOR } "
126114 )
127115
128116 def get_local_linvel (self , data : mjx .Data ) -> jax .Array :
129117 """Return the linear velocity of the robot in the local frame."""
130- return mjx_env .get_sensor_data (
131- self .mj_model , data , f"{ consts .LOCAL_LINVEL_SENSOR } "
132- )
118+ return mjx_env .get_sensor_data (self .mj_model , data , f"{ consts .LOCAL_LINVEL_SENSOR } " )
133119
134120 def get_accelerometer (self , data : mjx .Data ) -> jax .Array :
135121 """Return the accelerometer readings in the local frame."""
136122 return mjx_env .get_sensor_data (
137- self .mj_model , data , f"{ consts .ACCELEROMETER_SENSOR } "
123+ self .mj_model , data , f"{ consts .ACCELEROMETER_SENSOR } "
138124 )
139125
140126 def get_gyro (self , data : mjx .Data ) -> jax .Array :
@@ -143,14 +129,18 @@ def get_gyro(self, data: mjx.Data) -> jax.Array:
143129
144130 def get_feet_ground_contacts (self , data : mjx .Data ) -> jax .Array :
145131 """Return an array indicating whether each foot is in contact with the ground."""
146- left_feet_contact = jp .array ([
132+ left_feet_contact = jp .array (
133+ [
147134 geoms_colliding (data , geom_id , self ._floor_geom_id )
148135 for geom_id in self ._left_feet_geom_id
149- ])
150- right_feet_contact = jp .array ([
136+ ]
137+ )
138+ right_feet_contact = jp .array (
139+ [
151140 geoms_colliding (data , geom_id , self ._floor_geom_id )
152141 for geom_id in self ._right_feet_geom_id
153- ])
142+ ]
143+ )
154144 return jp .hstack ([jp .any (left_feet_contact ), jp .any (right_feet_contact )])
155145
156146 # Accessors.
0 commit comments