Skip to content

Commit e5f6b50

Browse files
committed
Process impl flag for training
1 parent 229fb1f commit e5f6b50

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

learning/train_jax_ppo.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def main(argv):
205205

206206
# Load environment configuration
207207
env_cfg = registry.get_default_config(_ENV_NAME.value)
208+
env_cfg["impl"] = _IMPL.value
208209

209210
ppo_params = get_rl_config(_ENV_NAME.value)
210211

@@ -396,12 +397,9 @@ def progress(num_steps, metrics):
396397
)
397398

398399
# Load evaluation environment.
399-
config_overrides = {"impl": _IMPL.value}
400400
eval_env = None
401401
if not _VISION.value:
402-
eval_env = registry.load(
403-
_ENV_NAME.value, config=env_cfg, config_overrides=config_overrides
404-
)
402+
eval_env = registry.load(_ENV_NAME.value, config=env_cfg)
405403
num_envs = 1
406404
if _VISION.value:
407405
num_envs = env_cfg.vision_config.render_batch_size
@@ -412,9 +410,7 @@ def progress(num_steps, metrics):
412410
from rscope import brax as rscope_utils
413411

414412
if not _VISION.value:
415-
rscope_env = registry.load(
416-
_ENV_NAME.value, config=env_cfg, config_overrides=config_overrides
417-
)
413+
rscope_env = registry.load(_ENV_NAME.value, config=env_cfg)
418414
rscope_env = wrapper.wrap_for_brax_training(
419415
rscope_env,
420416
episode_length=ppo_params.episode_length,

0 commit comments

Comments
 (0)