Skip to content

Commit 30be900

Browse files
committed
It trains
1 parent 09a1a94 commit 30be900

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

train_brax.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ def get_ppo_train_fn():
8888
from brax.training.agents.ppo import networks as ppo_networks
8989
from brax.training.agents.ppo import train as ppo
9090

91-
from mujoco_playground.config import locomotion_params
91+
from mujoco_playground.config import dm_control_suite_params
9292

93-
ppo_params = locomotion_params.brax_ppo_config(env_name)
93+
ppo_params = dm_control_suite_params.brax_ppo_config(env_name)
9494
ppo_training_params = dict(ppo_params)
9595
network_factory = ppo_networks.make_ppo_networks
9696
if "network_factory" in ppo_params:
@@ -110,9 +110,9 @@ def get_sac_train_fn():
110110
from brax.training.agents.sac import networks as sac_networks
111111
from brax.training.agents.sac import train as sac
112112

113-
from mujoco_playground.config import locomotion_params
113+
from mujoco_playground.config import dm_control_suite_params
114114

115-
sac_params = locomotion_params.brax_sac_config(env_name)
115+
sac_params = dm_control_suite_params.brax_sac_config(env_name)
116116
sac_training_params = dict(sac_params)
117117
network_factory = sac_networks.make_sac_networks
118118
if "network_factory" in sac_params:

0 commit comments

Comments
 (0)