Skip to content

Commit e2d016f

Browse files
committed
Add parameters
1 parent 30be900 commit e2d016f

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

config/train_brax.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,6 @@ jit: true
2727

2828
training:
2929
seed: 0
30-
render: true
30+
render: true
31+
task_name: QuadrupedRun
32+
agent_name: PPO

train_brax.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,6 @@ def get_state_path() -> str:
7777
return log_path
7878

7979

80-
env_name = "QuadrupedRun"
81-
env = registry.load(env_name)
82-
env_cfg = registry.get_default_config(env_name)
83-
eval_env = registry.load(env_name, config=env_cfg)
84-
agent_name = "PPO"
85-
86-
8780
def get_ppo_train_fn():
8881
from brax.training.agents.ppo import networks as ppo_networks
8982
from brax.training.agents.ppo import train as ppo
@@ -207,14 +200,17 @@ def main(cfg):
207200
f"\n{OmegaConf.to_yaml(cfg)}"
208201
)
209202
logger = WeightAndBiasesWriter(cfg)
210-
if agent_name == "SAC":
203+
if cfg.agent_name == "SAC":
211204
train_fn = get_sac_train_fn()
212-
elif agent_name == "PPO":
205+
elif cfg.agent_name == "PPO":
213206
train_fn = get_ppo_train_fn()
214207
else:
215208
raise NotImplementedError
216209
rng = jax.random.PRNGKey(cfg.training.seed)
217210
steps = Counter()
211+
env = registry.load(cfg.task_name)
212+
env_cfg = registry.get_default_config(cfg.task_name)
213+
eval_env = registry.load(cfg.task_name, config=env_cfg)
218214
with jax.disable_jit(not cfg.jit):
219215
make_policy, params, _ = train_fn(
220216
environment=env,

0 commit comments

Comments
 (0)