16
16
17
17
from typing import Any , Dict , Optional , Union
18
18
19
- from etils import epath
20
19
import jax
21
20
import jax .numpy as jp
22
- from ml_collections import config_dict
23
21
import mujoco
24
- from mujoco import mjx
25
22
import numpy as np
23
+ from etils import epath
24
+ from ml_collections import config_dict
25
+ from mujoco import mjx
26
26
27
27
from mujoco_playground ._src import mjx_env
28
- from mujoco_playground ._src .collision import geoms_colliding
29
28
from mujoco_playground ._src .locomotion .apollo import constants as consts
29
+ from mujoco_playground ._src .collision import geoms_colliding
30
30
31
31
32
32
def get_assets () -> Dict [str , bytes ]:
@@ -46,15 +46,15 @@ class ApolloEnv(mjx_env.MjxEnv):
46
46
"""Base class for Apollo environments."""
47
47
48
48
def __init__ (
49
- self ,
50
- xml_path : str ,
51
- config : config_dict .ConfigDict ,
52
- config_overrides : Optional [Dict [str , Union [str , int , list [Any ]]]] = None ,
49
+ self ,
50
+ xml_path : str ,
51
+ config : config_dict .ConfigDict ,
52
+ config_overrides : Optional [Dict [str , Union [str , int , list [Any ]]]] = None ,
53
53
) -> None :
54
54
super ().__init__ (config , config_overrides )
55
55
56
56
self ._mj_model = mujoco .MjModel .from_xml_string (
57
- epath .Path (xml_path ).read_text (), assets = get_assets ()
57
+ epath .Path (xml_path ).read_text (), assets = get_assets ()
58
58
)
59
59
self ._mj_model .opt .timestep = self .sim_dt
60
60
@@ -66,9 +66,7 @@ def __init__(
66
66
67
67
self ._init_q = jp .array (self ._mj_model .keyframe ("knees_bent" ).qpos )
68
68
self ._default_ctrl = jp .array (self ._mj_model .keyframe ("knees_bent" ).ctrl )
69
- self ._default_pose = jp .array (
70
- self ._mj_model .keyframe ("knees_bent" ).qpos [7 :]
71
- )
69
+ self ._default_pose = jp .array (self ._mj_model .keyframe ("knees_bent" ).qpos [7 :])
72
70
self ._actuator_torques = self .mj_model .jnt_actfrcrange [1 :, 1 ]
73
71
74
72
# Body IDs.
@@ -77,64 +75,52 @@ def __init__(
77
75
# Geom IDs.
78
76
self ._floor_geom_id = self ._mj_model .geom ("floor" ).id
79
77
self ._left_feet_geom_id = np .array (
80
- [self ._mj_model .geom (name ).id for name in consts .LEFT_FEET_GEOMS ]
78
+ [self ._mj_model .geom (name ).id for name in consts .LEFT_FEET_GEOMS ]
81
79
)
82
80
self ._right_feet_geom_id = np .array (
83
- [self ._mj_model .geom (name ).id for name in consts .RIGHT_FEET_GEOMS ]
81
+ [self ._mj_model .geom (name ).id for name in consts .RIGHT_FEET_GEOMS ]
84
82
)
85
83
self ._left_hand_geom_id = self ._mj_model .geom ("collision_l_hand_plate" ).id
86
84
self ._right_hand_geom_id = self ._mj_model .geom ("collision_r_hand_plate" ).id
87
85
self ._left_foot_geom_id = self ._mj_model .geom ("collision_l_sole" ).id
88
86
self ._right_foot_geom_id = self ._mj_model .geom ("collision_r_sole" ).id
89
- self ._left_shin_geom_id = self ._mj_model .geom (
90
- "collision_capsule_body_l_shin"
91
- ).id
92
- self ._right_shin_geom_id = self ._mj_model .geom (
93
- "collision_capsule_body_r_shin"
94
- ).id
95
- self ._left_thigh_geom_id = self ._mj_model .geom (
96
- "collision_capsule_body_l_thigh"
97
- ).id
98
- self ._right_thigh_geom_id = self ._mj_model .geom (
99
- "collision_capsule_body_r_thigh"
100
- ).id
87
+ self ._left_shin_geom_id = self ._mj_model .geom ("collision_capsule_body_l_shin" ).id
88
+ self ._right_shin_geom_id = self ._mj_model .geom ("collision_capsule_body_r_shin" ).id
89
+ self ._left_thigh_geom_id = self ._mj_model .geom ("collision_capsule_body_l_thigh" ).id
90
+ self ._right_thigh_geom_id = self ._mj_model .geom ("collision_capsule_body_r_thigh" ).id
101
91
102
92
# Site IDs.
103
93
self ._imu_site_id = self ._mj_model .site ("imu" ).id
104
94
self ._feet_site_id = np .array (
105
- [self ._mj_model .site (name ).id for name in consts .FEET_SITES ]
95
+ [self ._mj_model .site (name ).id for name in consts .FEET_SITES ]
106
96
)
107
97
108
98
# Sensor readings.
109
99
110
100
def get_gravity (self , data : mjx .Data ) -> jax .Array :
111
101
"""Return the gravity vector in the world frame."""
112
- return mjx_env .get_sensor_data (
113
- self .mj_model , data , f"{ consts .GRAVITY_SENSOR } "
114
- )
102
+ return mjx_env .get_sensor_data (self .mj_model , data , f"{ consts .GRAVITY_SENSOR } " )
115
103
116
104
def get_global_linvel (self , data : mjx .Data ) -> jax .Array :
117
105
"""Return the linear velocity of the robot in the world frame."""
118
106
return mjx_env .get_sensor_data (
119
- self .mj_model , data , f"{ consts .GLOBAL_LINVEL_SENSOR } "
107
+ self .mj_model , data , f"{ consts .GLOBAL_LINVEL_SENSOR } "
120
108
)
121
109
122
110
def get_global_angvel (self , data : mjx .Data ) -> jax .Array :
123
111
"""Return the angular velocity of the robot in the world frame."""
124
112
return mjx_env .get_sensor_data (
125
- self .mj_model , data , f"{ consts .GLOBAL_ANGVEL_SENSOR } "
113
+ self .mj_model , data , f"{ consts .GLOBAL_ANGVEL_SENSOR } "
126
114
)
127
115
128
116
def get_local_linvel (self , data : mjx .Data ) -> jax .Array :
129
117
"""Return the linear velocity of the robot in the local frame."""
130
- return mjx_env .get_sensor_data (
131
- self .mj_model , data , f"{ consts .LOCAL_LINVEL_SENSOR } "
132
- )
118
+ return mjx_env .get_sensor_data (self .mj_model , data , f"{ consts .LOCAL_LINVEL_SENSOR } " )
133
119
134
120
def get_accelerometer (self , data : mjx .Data ) -> jax .Array :
135
121
"""Return the accelerometer readings in the local frame."""
136
122
return mjx_env .get_sensor_data (
137
- self .mj_model , data , f"{ consts .ACCELEROMETER_SENSOR } "
123
+ self .mj_model , data , f"{ consts .ACCELEROMETER_SENSOR } "
138
124
)
139
125
140
126
def get_gyro (self , data : mjx .Data ) -> jax .Array :
@@ -143,14 +129,18 @@ def get_gyro(self, data: mjx.Data) -> jax.Array:
143
129
144
130
def get_feet_ground_contacts (self , data : mjx .Data ) -> jax .Array :
145
131
"""Return an array indicating whether each foot is in contact with the ground."""
146
- left_feet_contact = jp .array ([
132
+ left_feet_contact = jp .array (
133
+ [
147
134
geoms_colliding (data , geom_id , self ._floor_geom_id )
148
135
for geom_id in self ._left_feet_geom_id
149
- ])
150
- right_feet_contact = jp .array ([
136
+ ]
137
+ )
138
+ right_feet_contact = jp .array (
139
+ [
151
140
geoms_colliding (data , geom_id , self ._floor_geom_id )
152
141
for geom_id in self ._right_feet_geom_id
153
- ])
142
+ ]
143
+ )
154
144
return jp .hstack ([jp .any (left_feet_contact ), jp .any (right_feet_contact )])
155
145
156
146
# Accessors.
0 commit comments