|
| 1 | +import functools |
| 2 | + |
| 3 | +import jax |
| 4 | +from brax import envs |
| 5 | + |
| 6 | +from ss2r.benchmark_suites.brax import randomization_fns |
| 7 | +from ss2r.benchmark_suites.rccar import rccar |
| 8 | +from ss2r.benchmark_suites.utils import get_domain_name, get_task_config |
| 9 | + |
| 10 | + |
| 11 | +def make(cfg): |
| 12 | + domain_name = get_domain_name(cfg) |
| 13 | + if domain_name == "brax": |
| 14 | + return make_brax_envs(cfg) |
| 15 | + elif domain_name == "rccar": |
| 16 | + return make_rccar_envs(cfg) |
| 17 | + |
| 18 | + |
| 19 | +def make_rccar_envs(cfg): |
| 20 | + task_cfg = dict(get_task_config(cfg)) |
| 21 | + train_car_params = task_cfg.pop("train_car_params") |
| 22 | + eval_car_params = task_cfg.pop("eval_car_params") |
| 23 | + train_env = rccar.RCCar(train_car_params, **task_cfg) |
| 24 | + train_env = envs.training.wrap( |
| 25 | + train_env, |
| 26 | + episode_length=cfg.training.episode_length, |
| 27 | + action_repeat=cfg.training.action_repeat, |
| 28 | + ) |
| 29 | + eval_env = rccar.RCCar(eval_car_params, **task_cfg) |
| 30 | + eval_env = envs.training.wrap( |
| 31 | + eval_env, |
| 32 | + episode_length=cfg.training.episode_length, |
| 33 | + action_repeat=cfg.training.action_repeat, |
| 34 | + ) |
| 35 | + return train_env, eval_env, None |
| 36 | + |
| 37 | + |
| 38 | +def make_brax_envs(cfg): |
| 39 | + task_cfg = get_task_config(cfg) |
| 40 | + train_env = envs.get_environment( |
| 41 | + task_cfg.task_name, backend=cfg.environment.backend |
| 42 | + ) |
| 43 | + eval_env = envs.get_environment(task_cfg.task_name, backend=cfg.environment.backend) |
| 44 | + train_key, eval_key = jax.random.split(jax.random.PRNGKey(cfg.training.seed)) |
| 45 | + |
| 46 | + def prepare_randomization_fn(key, num_envs): |
| 47 | + randomize_fn = lambda sys, rng: randomization_fns[task_cfg.task_name]( |
| 48 | + sys, rng, task_cfg |
| 49 | + ) |
| 50 | + v_randomization_fn = functools.partial( |
| 51 | + randomize_fn, rng=jax.random.split(key, num_envs) |
| 52 | + ) |
| 53 | + vf_randomization_fn = lambda sys: v_randomization_fn(sys)[:-1] # type: ignore |
| 54 | + params_fn = lambda sys: v_randomization_fn(sys)[-1] |
| 55 | + return vf_randomization_fn, params_fn |
| 56 | + |
| 57 | + train_randomization_fn, params_fn = ( |
| 58 | + prepare_randomization_fn(train_key, cfg.training.num_envs) |
| 59 | + if cfg.training.train_domain_randomization |
| 60 | + else (None, None) |
| 61 | + ) |
| 62 | + train_env = envs.training.wrap( |
| 63 | + train_env, |
| 64 | + episode_length=cfg.training.episode_length, |
| 65 | + action_repeat=cfg.training.action_repeat, |
| 66 | + randomization_fn=train_randomization_fn, |
| 67 | + ) |
| 68 | + eval_randomization_fn, _ = prepare_randomization_fn( |
| 69 | + eval_key, cfg.training.num_eval_envs |
| 70 | + ) |
| 71 | + eval_env = envs.training.wrap( |
| 72 | + eval_env, |
| 73 | + episode_length=cfg.training.episode_length, |
| 74 | + action_repeat=cfg.training.action_repeat, |
| 75 | + randomization_fn=eval_randomization_fn |
| 76 | + if cfg.training.eval_domain_randomization |
| 77 | + else None, |
| 78 | + ) |
| 79 | + if cfg.training.train_domain_randomization and cfg.training.privileged: |
| 80 | + domain_parameters = params_fn(train_env.sys) |
| 81 | + else: |
| 82 | + domain_parameters = None |
| 83 | + return train_env, eval_env, domain_parameters |
0 commit comments