Skip to content

Commit 229fb1f

Browse files
btabacopybara-github
authored andcommitted
Playground updates for manipulation environments using Warp.
PiperOrigin-RevId: 800484715 Change-Id: I64880e02d73abbf225119e23d862dd162bd41803
1 parent 4b7a9e4 commit 229fb1f

File tree

10 files changed

+59
-29
lines changed

10 files changed

+59
-29
lines changed

learning/train_jax_ppo.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,16 +166,16 @@
166166
def get_rl_config(env_name: str) -> config_dict.ConfigDict:
167167
if env_name in mujoco_playground.manipulation._envs:
168168
if _VISION.value:
169-
return manipulation_params.brax_vision_ppo_config(env_name)
170-
return manipulation_params.brax_ppo_config(env_name)
169+
return manipulation_params.brax_vision_ppo_config(env_name, _IMPL.value)
170+
return manipulation_params.brax_ppo_config(env_name, _IMPL.value)
171171
elif env_name in mujoco_playground.locomotion._envs:
172-
if _VISION.value:
173-
return locomotion_params.brax_vision_ppo_config(env_name)
174-
return locomotion_params.brax_ppo_config(env_name)
172+
return locomotion_params.brax_ppo_config(env_name, _IMPL.value)
175173
elif env_name in mujoco_playground.dm_control_suite._envs:
176174
if _VISION.value:
177-
return dm_control_suite_params.brax_vision_ppo_config(env_name)
178-
return dm_control_suite_params.brax_ppo_config(env_name)
175+
return dm_control_suite_params.brax_vision_ppo_config(
176+
env_name, _IMPL.value
177+
)
178+
return dm_control_suite_params.brax_ppo_config(env_name, _IMPL.value)
179179

180180
raise ValueError(f"Env {env_name} not found in {registry.ALL_ENVS}.")
181181

mujoco_playground/_src/locomotion/berkeley_humanoid/xmls/scene_mjx_feetonly_rough_terrain.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
<!-- https://polyhaven.com/a/rock_face -->
1616
<texture type="2d" name="groundplane" file="assets/rocky_texture.png"/>
1717
<material name="groundplane" texture="groundplane" texuniform="true" texrepeat="5 5" reflectance=".8"/>
18-
<hfield name="hfield" file="assets/hfield.png" size="10 10 .05 0.1"/>
18+
<hfield name="hfield" file="assets/hfield.png" size="10 10 .05 1.0"/>
1919
</asset>
2020

2121
<worldbody>

mujoco_playground/_src/locomotion/g1/joystick.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def default_config() -> config_dict.ConfigDict:
103103
ang_vel_yaw=[-1.0, 1.0],
104104
impl="jax",
105105
nconmax=8 * 8192,
106-
njmax=29 + 8 * 4,
106+
njmax=29 * 2 + 8 * 4,
107107
)
108108

109109

@@ -118,7 +118,7 @@ def __init__(
118118
):
119119
if task.startswith("rough"):
120120
config.nconmax = 100 * 8192
121-
config.njmax = 29 + 100 * 4
121+
config.njmax = 29 * 2 + 100 * 4
122122
super().__init__(
123123
xml_path=consts.task_to_xml(task).as_posix(),
124124
config=config,

mujoco_playground/_src/locomotion/g1/xmls/scene_mjx_feetonly_rough_terrain.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
<!-- https://polyhaven.com/a/rock_face -->
1717
<texture type="2d" name="groundplane" file="assets/rocky_texture.png"/>
1818
<material name="groundplane" texture="groundplane" texuniform="true" texrepeat="5 5" reflectance=".8"/>
19-
<hfield name="hfield" file="assets/hfield.png" size="10 10 .05 0.1"/>
19+
<hfield name="hfield" file="assets/hfield.png" size="10 10 .05 1.0"/>
2020
</asset>
2121

2222
<worldbody>

mujoco_playground/_src/locomotion/go1/xmls/scene_mjx_feetonly_rough_terrain.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
<!-- https://polyhaven.com/a/rock_face -->
1616
<texture type="2d" name="groundplane" file="assets/rocky_texture.png"/>
1717
<material name="groundplane" texture="groundplane" texuniform="true" texrepeat="5 5" reflectance=".8"/>
18-
<hfield name="hfield" file="assets/hfield.png" size="10 10 .05 0.1"/>
18+
<hfield name="hfield" file="assets/hfield.png" size="10 10 .05 1.0"/>
1919
</asset>
2020

2121
<worldbody>

mujoco_playground/_src/locomotion/t1/xmls/scene_mjx_feetonly_rough_terrain.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
<!-- https://polyhaven.com/a/rock_face -->
1616
<texture type="2d" name="groundplane" file="assets/rocky_texture.png"/>
1717
<material name="groundplane" texture="groundplane" texuniform="true" texrepeat="5 5" reflectance=".8"/>
18-
<hfield name="hfield" file="assets/hfield.png" size="10 10 .05 0.1"/>
18+
<hfield name="hfield" file="assets/hfield.png" size="10 10 .05 1.0"/>
1919
</asset>
2020

2121
<worldbody>

mujoco_playground/_src/manipulation/aloha/single_peg_insertion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def default_config() -> config_dict.ConfigDict:
4848
peg_insertion_reward=8,
4949
)
5050
),
51-
impl='jax',
51+
impl="jax",
5252
nconmax=24 * 8192,
5353
njmax=256,
5454
)

mujoco_playground/config/dm_control_suite_params.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
# ==============================================================================
1515
"""RL config for DM Control Suite."""
1616

17+
from typing import Optional
1718
from ml_collections import config_dict
18-
1919
from mujoco_playground._src import dm_control_suite
2020

2121

22-
def brax_ppo_config(env_name: str) -> config_dict.ConfigDict:
22+
def brax_ppo_config(
23+
env_name: str, impl: Optional[str] = None
24+
) -> config_dict.ConfigDict:
2325
"""Returns tuned Brax PPO config for the given environment."""
2426
env_config = dm_control_suite.get_default_config(env_name)
2527

@@ -38,6 +40,7 @@ def brax_ppo_config(env_name: str) -> config_dict.ConfigDict:
3840
entropy_cost=1e-2,
3941
num_envs=2048,
4042
batch_size=1024,
43+
num_resets_per_eval=10,
4144
)
4245

4346
if env_name.startswith("AcrobotSwingup"):
@@ -57,7 +60,9 @@ def brax_ppo_config(env_name: str) -> config_dict.ConfigDict:
5760
return rl_config
5861

5962

60-
def brax_vision_ppo_config(env_name: str) -> config_dict.ConfigDict:
63+
def brax_vision_ppo_config(
64+
env_name: str, unused_impl: Optional[str] = None
65+
) -> config_dict.ConfigDict:
6166
"""Returns tuned Brax Vision PPO config for the given environment."""
6267
env_config = dm_control_suite.get_default_config(env_name)
6368

@@ -80,6 +85,7 @@ def brax_vision_ppo_config(env_name: str) -> config_dict.ConfigDict:
8085
num_eval_envs=1024,
8186
batch_size=256,
8287
max_grad_norm=1.0,
88+
num_resets_per_eval=10,
8389
)
8490

8591
if env_name != "CartpoleBalance":
@@ -88,7 +94,9 @@ def brax_vision_ppo_config(env_name: str) -> config_dict.ConfigDict:
8894
return rl_config
8995

9096

91-
def brax_sac_config(env_name: str) -> config_dict.ConfigDict:
97+
def brax_sac_config(
98+
env_name: str, unused_impl: Optional[str] = None
99+
) -> config_dict.ConfigDict:
92100
"""Returns tuned Brax SAC config for the given environment."""
93101
env_config = dm_control_suite.get_default_config(env_name)
94102

@@ -109,6 +117,7 @@ def brax_sac_config(env_name: str) -> config_dict.ConfigDict:
109117
network_factory=config_dict.create(
110118
q_network_layer_norm=True,
111119
),
120+
num_resets_per_eval=10,
112121
)
113122

114123
if env_name == "PendulumSwingUp":

mujoco_playground/config/locomotion_params.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
# ==============================================================================
1515
"""RL config for Locomotion envs."""
1616

17+
from typing import Optional
1718
from ml_collections import config_dict
18-
1919
from mujoco_playground._src import locomotion
2020

2121

22-
def brax_ppo_config(env_name: str) -> config_dict.ConfigDict:
22+
def brax_ppo_config(
23+
env_name: str, impl: Optional[str] = None
24+
) -> config_dict.ConfigDict:
2325
"""Returns tuned Brax PPO config for the given environment."""
2426
env_config = locomotion.get_default_config(env_name)
2527

@@ -45,12 +47,12 @@ def brax_ppo_config(env_name: str) -> config_dict.ConfigDict:
4547
policy_obs_key="state",
4648
value_obs_key="state",
4749
),
50+
num_resets_per_eval=10,
4851
)
4952

5053
if env_name in ("Go1JoystickFlatTerrain", "Go1JoystickRoughTerrain"):
5154
rl_config.num_timesteps = 200_000_000
5255
rl_config.num_evals = 10
53-
rl_config.num_resets_per_eval = 1
5456
rl_config.network_factory = config_dict.create(
5557
policy_hidden_layer_sizes=(512, 256, 128),
5658
value_hidden_layer_sizes=(512, 256, 128),
@@ -109,7 +111,6 @@ def brax_ppo_config(env_name: str) -> config_dict.ConfigDict:
109111
rl_config.num_timesteps = 150_000_000
110112
rl_config.num_evals = 15
111113
rl_config.clipping_epsilon = 0.2
112-
rl_config.num_resets_per_eval = 1
113114
rl_config.entropy_cost = 0.005
114115
rl_config.network_factory = config_dict.create(
115116
policy_hidden_layer_sizes=(512, 256, 128),
@@ -163,7 +164,9 @@ def brax_ppo_config(env_name: str) -> config_dict.ConfigDict:
163164
return rl_config
164165

165166

166-
def rsl_rl_config(env_name: str) -> config_dict.ConfigDict:
167+
def rsl_rl_config(
168+
env_name: str, unused_impl: Optional[str] = None
169+
) -> config_dict.ConfigDict:
167170
"""Returns tuned RSL-RL PPO config for the given environment."""
168171

169172
rl_config = config_dict.create(

mujoco_playground/config/manipulation_params.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
# ==============================================================================
1515
"""RL config for Manipulation envs."""
1616

17+
from typing import Optional
1718
from ml_collections import config_dict
18-
1919
from mujoco_playground._src import manipulation
2020

2121

22-
def brax_ppo_config(env_name: str) -> config_dict.ConfigDict:
22+
def brax_ppo_config(
23+
env_name: str, impl: Optional[str] = None
24+
) -> config_dict.ConfigDict:
2325
"""Returns tuned Brax PPO config for the given environment."""
2426
env_config = manipulation.get_default_config(env_name)
2527

@@ -34,10 +36,11 @@ def brax_ppo_config(env_name: str) -> config_dict.ConfigDict:
3436
policy_obs_key="state",
3537
value_obs_key="state",
3638
),
39+
num_resets_per_eval=10,
3740
)
3841
if env_name == "AlohaHandOver":
3942
rl_config.num_timesteps = 100_000_000
40-
rl_config.num_evals = int(rl_config.num_timesteps / 4_000_000)
43+
rl_config.num_evals = 25
4144
rl_config.unroll_length = 15
4245
rl_config.num_minibatches = 32
4346
rl_config.num_updates_per_batch = 8
@@ -61,6 +64,9 @@ def brax_ppo_config(env_name: str) -> config_dict.ConfigDict:
6164
rl_config.num_envs = 1024
6265
rl_config.batch_size = 512
6366
rl_config.network_factory.policy_hidden_layer_sizes = (256, 256, 256, 256)
67+
if impl == "warp":
68+
rl_config.num_timesteps *= 3
69+
rl_config.num_evals *= 3
6470
elif env_name == "PandaOpenCabinet":
6571
rl_config.num_timesteps = 40_000_000
6672
rl_config.num_evals = 4
@@ -73,7 +79,6 @@ def brax_ppo_config(env_name: str) -> config_dict.ConfigDict:
7379
rl_config.num_envs = 2048
7480
rl_config.batch_size = 512
7581
rl_config.network_factory.policy_hidden_layer_sizes = (32, 32, 32, 32)
76-
rl_config.num_resets_per_eval = 1
7782
elif env_name == "PandaPickCubeCartesian":
7883
rl_config.num_timesteps = 5_000_000
7984
rl_config.num_evals = 5
@@ -89,6 +94,9 @@ def brax_ppo_config(env_name: str) -> config_dict.ConfigDict:
8994
rl_config.network_factory.policy_hidden_layer_sizes = (256, 256)
9095
rl_config.num_resets_per_eval = 1
9196
rl_config.max_grad_norm = 1.0
97+
if impl == "warp":
98+
rl_config.num_timesteps *= 4
99+
rl_config.num_evals *= 4
92100
elif env_name.startswith("PandaPickCube"):
93101
rl_config.num_timesteps = 20_000_000
94102
rl_config.num_evals = 4
@@ -101,6 +109,9 @@ def brax_ppo_config(env_name: str) -> config_dict.ConfigDict:
101109
rl_config.num_envs = 2048
102110
rl_config.batch_size = 512
103111
rl_config.network_factory.policy_hidden_layer_sizes = (32, 32, 32, 32)
112+
if impl == "warp":
113+
rl_config.num_timesteps *= 4
114+
rl_config.num_evals *= 4
104115
elif env_name == "PandaRobotiqPushCube":
105116
rl_config.num_timesteps = 1_800_000_000
106117
rl_config.num_evals = 10
@@ -115,6 +126,10 @@ def brax_ppo_config(env_name: str) -> config_dict.ConfigDict:
115126
rl_config.num_resets_per_eval = 1
116127
rl_config.num_eval_envs = 32
117128
rl_config.network_factory.policy_hidden_layer_sizes = (64, 64, 64, 64)
129+
if impl == "warp":
130+
rl_config.num_resets_per_eval = 10
131+
rl_config.num_timesteps = int(rl_config.num_timesteps * 1.5)
132+
rl_config.num_evals = int(rl_config.num_evals * 1.5)
118133
elif env_name == "LeapCubeRotateZAxis":
119134
rl_config.num_timesteps = 100_000_000
120135
rl_config.num_evals = 10
@@ -157,7 +172,9 @@ def brax_ppo_config(env_name: str) -> config_dict.ConfigDict:
157172
return rl_config
158173

159174

160-
def brax_vision_ppo_config(env_name: str) -> config_dict.ConfigDict:
175+
def brax_vision_ppo_config(
176+
env_name: str, unused_impl: Optional[str] = None
177+
) -> config_dict.ConfigDict:
161178
"""Returns tuned Brax Vision PPO config for the given environment."""
162179
env_config = manipulation.get_default_config(env_name)
163180

@@ -171,6 +188,7 @@ def brax_vision_ppo_config(env_name: str) -> config_dict.ConfigDict:
171188
network_factory=config_dict.create(
172189
policy_hidden_layer_sizes=(32, 32, 32, 32)
173190
),
191+
num_resets_per_eval=10,
174192
)
175193

176194
if env_name == "PandaPickCubeCartesian":
@@ -192,7 +210,7 @@ def brax_vision_ppo_config(env_name: str) -> config_dict.ConfigDict:
192210
return rl_config
193211

194212

195-
def rsl_rl_config(env_name: str) -> config_dict.ConfigDict: # pylint: disable=unused-argument
213+
def rsl_rl_config(env_name: str, unused_impl: Optional[str] = None) -> config_dict.ConfigDict: # pylint: disable=unused-argument
196214
"""Returns tuned RSL-RL PPO config for the given environment."""
197215

198216
rl_config = config_dict.create(

0 commit comments

Comments
 (0)