@@ -77,13 +77,6 @@ def get_state_path() -> str:
77
77
return log_path
78
78
79
79
80
- env_name = "QuadrupedRun"
81
- env = registry .load (env_name )
82
- env_cfg = registry .get_default_config (env_name )
83
- eval_env = registry .load (env_name , config = env_cfg )
84
- agent_name = "PPO"
85
-
86
-
87
80
def get_ppo_train_fn ():
88
81
from brax .training .agents .ppo import networks as ppo_networks
89
82
from brax .training .agents .ppo import train as ppo
@@ -207,14 +200,17 @@ def main(cfg):
207
200
f"\n { OmegaConf .to_yaml (cfg )} "
208
201
)
209
202
logger = WeightAndBiasesWriter (cfg )
210
- if agent_name == "SAC" :
203
+ if cfg . agent_name == "SAC" :
211
204
train_fn = get_sac_train_fn ()
212
- elif agent_name == "PPO" :
205
+ elif cfg . agent_name == "PPO" :
213
206
train_fn = get_ppo_train_fn ()
214
207
else :
215
208
raise NotImplementedError
216
209
rng = jax .random .PRNGKey (cfg .training .seed )
217
210
steps = Counter ()
211
+ env = registry .load (cfg .task_name )
212
+ env_cfg = registry .get_default_config (cfg .task_name )
213
+ eval_env = registry .load (cfg .task_name , config = env_cfg )
218
214
with jax .disable_jit (not cfg .jit ):
219
215
make_policy , params , _ = train_fn (
220
216
environment = env ,
0 commit comments