Skip to content

Commit 20786a5

Browse files
committed
fix bugs
1 parent e2d016f commit 20786a5

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

train_brax.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def get_state_path() -> str:
7777
return log_path
7878

7979

80-
def get_ppo_train_fn():
80+
def get_ppo_train_fn(env_name):
8181
from brax.training.agents.ppo import networks as ppo_networks
8282
from brax.training.agents.ppo import train as ppo
8383

@@ -99,7 +99,7 @@ def get_ppo_train_fn():
9999
return train_fn
100100

101101

102-
def get_sac_train_fn():
102+
def get_sac_train_fn(env_name):
103103
from brax.training.agents.sac import networks as sac_networks
104104
from brax.training.agents.sac import train as sac
105105

@@ -200,17 +200,17 @@ def main(cfg):
200200
f"\n{OmegaConf.to_yaml(cfg)}"
201201
)
202202
logger = WeightAndBiasesWriter(cfg)
203-
if cfg.agent_name == "SAC":
204-
train_fn = get_sac_train_fn()
205-
elif cfg.agent_name == "PPO":
206-
train_fn = get_ppo_train_fn()
203+
if cfg.training.agent_name == "SAC":
204+
train_fn = get_sac_train_fn(cfg.training.task_name)
205+
elif cfg.training.agent_name == "PPO":
206+
train_fn = get_ppo_train_fn(cfg.training.task_name)
207207
else:
208208
raise NotImplementedError
209209
rng = jax.random.PRNGKey(cfg.training.seed)
210210
steps = Counter()
211-
env = registry.load(cfg.task_name)
212-
env_cfg = registry.get_default_config(cfg.task_name)
213-
eval_env = registry.load(cfg.task_name, config=env_cfg)
211+
env = registry.load(cfg.training.task_name)
212+
env_cfg = registry.get_default_config(cfg.training.task_name)
213+
eval_env = registry.load(cfg.training.task_name, config=env_cfg)
214214
with jax.disable_jit(not cfg.jit):
215215
make_policy, params, _ = train_fn(
216216
environment=env,

0 commit comments

Comments
 (0)