@@ -77,7 +77,7 @@ def get_state_path() -> str:
     return log_path
 
 
-def get_ppo_train_fn():
+def get_ppo_train_fn(env_name):
     from brax.training.agents.ppo import networks as ppo_networks
     from brax.training.agents.ppo import train as ppo
 
@@ -99,7 +99,7 @@ def get_ppo_train_fn():
     return train_fn
 
 
-def get_sac_train_fn():
+def get_sac_train_fn(env_name):
     from brax.training.agents.sac import networks as sac_networks
     from brax.training.agents.sac import train as sac
 
@@ -200,17 +200,17 @@ def main(cfg):
         f"\n{OmegaConf.to_yaml(cfg)}"
     )
     logger = WeightAndBiasesWriter(cfg)
-    if cfg.agent_name == "SAC":
-        train_fn = get_sac_train_fn()
-    elif cfg.agent_name == "PPO":
-        train_fn = get_ppo_train_fn()
+    if cfg.training.agent_name == "SAC":
+        train_fn = get_sac_train_fn(cfg.training.task_name)
+    elif cfg.training.agent_name == "PPO":
+        train_fn = get_ppo_train_fn(cfg.training.task_name)
     else:
         raise NotImplementedError
     rng = jax.random.PRNGKey(cfg.training.seed)
     steps = Counter()
-    env = registry.load(cfg.task_name)
-    env_cfg = registry.get_default_config(cfg.task_name)
-    eval_env = registry.load(cfg.task_name, config=env_cfg)
+    env = registry.load(cfg.training.task_name)
+    env_cfg = registry.get_default_config(cfg.training.task_name)
+    eval_env = registry.load(cfg.training.task_name, config=env_cfg)
     with jax.disable_jit(not cfg.jit):
         make_policy, params, _ = train_fn(
             environment=env,
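
Not part of the diff, for reviewers: the change nests agent/task selection under `cfg.training` and threads the task name into the train-fn factories. Below is a minimal sketch of the config shape this implies and one plausible use of the new `env_name` parameter, assuming an OmegaConf config; the task name and the per-task timestep override are illustrative assumptions, not taken from this PR:

```python
import functools

from omegaconf import OmegaConf


def get_ppo_train_fn(env_name):
    from brax.training.agents.ppo import train as ppo

    # Hypothetical per-task override; the real factory may tune differently.
    num_timesteps = 5_000_000 if env_name == "CartpoleBalance" else 10_000_000
    return functools.partial(ppo.train, num_timesteps=num_timesteps)


# Config nesting matching the new cfg.training.* accesses in main().
cfg = OmegaConf.create(
    {
        "jit": True,
        "training": {
            "agent_name": "PPO",  # dispatches to get_ppo_train_fn
            "task_name": "CartpoleBalance",  # forwarded to registry.load
            "seed": 0,
        },
    }
)
train_fn = get_ppo_train_fn(cfg.training.task_name)  # matches the new call site
```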