@@ -77,7 +77,7 @@ def get_state_path() -> str:
77 77     return log_path
78 78
79 79
80    - def get_ppo_train_fn():
   80 + def get_ppo_train_fn(env_name):
81 81     from brax.training.agents.ppo import networks as ppo_networks
82 82     from brax.training.agents.ppo import train as ppo
83 83
@@ -99,7 +99,7 @@ def get_ppo_train_fn():
9999 return train_fn
100100
101101
102     - def get_sac_train_fn():
    102 + def get_sac_train_fn(env_name):
103 103     from brax.training.agents.sac import networks as sac_networks
104 104     from brax.training.agents.sac import train as sac
105 105
@@ -200,17 +200,17 @@ def main(cfg):
200 200         f"\n{OmegaConf.to_yaml(cfg)}"
201 201     )
202 202     logger = WeightAndBiasesWriter(cfg)
203     -     if cfg.agent_name == "SAC":
204     -         train_fn = get_sac_train_fn()
205     -     elif cfg.agent_name == "PPO":
206     -         train_fn = get_ppo_train_fn()
    203 +     if cfg.training.agent_name == "SAC":
    204 +         train_fn = get_sac_train_fn(cfg.training.task_name)
    205 +     elif cfg.training.agent_name == "PPO":
    206 +         train_fn = get_ppo_train_fn(cfg.training.task_name)
207 207     else:
208 208         raise NotImplementedError
209 209     rng = jax.random.PRNGKey(cfg.training.seed)
210 210     steps = Counter()
211     -     env = registry.load(cfg.task_name)
212     -     env_cfg = registry.get_default_config(cfg.task_name)
213     -     eval_env = registry.load(cfg.task_name, config=env_cfg)
    211 +     env = registry.load(cfg.training.task_name)
    212 +     env_cfg = registry.get_default_config(cfg.training.task_name)
    213 +     eval_env = registry.load(cfg.training.task_name, config=env_cfg)
214 214     with jax.disable_jit(not cfg.jit):
215 215         make_policy, params, _ = train_fn(
216 216             environment=env,
0 commit comments