Skip to content

Commit 9dde1b4

Browse files
authored
First implementation of RCCar (#5)
* First implementation of rccar * Add config files * Add environment making functions * Rccar loading * RCCar runs
1 parent 1a2a669 commit 9dde1b4

27 files changed

+958
-70
lines changed

ss2r/benchmark_suites/__init__.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import functools
2+
3+
import jax
4+
from brax import envs
5+
6+
from ss2r.benchmark_suites.brax import randomization_fns
7+
from ss2r.benchmark_suites.rccar import rccar
8+
from ss2r.benchmark_suites.utils import get_domain_name, get_task_config
9+
10+
11+
def make(cfg):
12+
domain_name = get_domain_name(cfg)
13+
if domain_name == "brax":
14+
return make_brax_envs(cfg)
15+
elif domain_name == "rccar":
16+
return make_rccar_envs(cfg)
17+
18+
19+
def make_rccar_envs(cfg):
20+
task_cfg = dict(get_task_config(cfg))
21+
train_car_params = task_cfg.pop("train_car_params")
22+
eval_car_params = task_cfg.pop("eval_car_params")
23+
train_env = rccar.RCCar(train_car_params, **task_cfg)
24+
train_env = envs.training.wrap(
25+
train_env,
26+
episode_length=cfg.training.episode_length,
27+
action_repeat=cfg.training.action_repeat,
28+
)
29+
eval_env = rccar.RCCar(eval_car_params, **task_cfg)
30+
eval_env = envs.training.wrap(
31+
eval_env,
32+
episode_length=cfg.training.episode_length,
33+
action_repeat=cfg.training.action_repeat,
34+
)
35+
return train_env, eval_env, None
36+
37+
38+
def make_brax_envs(cfg):
39+
task_cfg = get_task_config(cfg)
40+
train_env = envs.get_environment(
41+
task_cfg.task_name, backend=cfg.environment.backend
42+
)
43+
eval_env = envs.get_environment(task_cfg.task_name, backend=cfg.environment.backend)
44+
train_key, eval_key = jax.random.split(jax.random.PRNGKey(cfg.training.seed))
45+
46+
def prepare_randomization_fn(key, num_envs):
47+
randomize_fn = lambda sys, rng: randomization_fns[task_cfg.task_name](
48+
sys, rng, task_cfg
49+
)
50+
v_randomization_fn = functools.partial(
51+
randomize_fn, rng=jax.random.split(key, num_envs)
52+
)
53+
vf_randomization_fn = lambda sys: v_randomization_fn(sys)[:-1] # type: ignore
54+
params_fn = lambda sys: v_randomization_fn(sys)[-1]
55+
return vf_randomization_fn, params_fn
56+
57+
train_randomization_fn, params_fn = (
58+
prepare_randomization_fn(train_key, cfg.training.num_envs)
59+
if cfg.training.train_domain_randomization
60+
else (None, None)
61+
)
62+
train_env = envs.training.wrap(
63+
train_env,
64+
episode_length=cfg.training.episode_length,
65+
action_repeat=cfg.training.action_repeat,
66+
randomization_fn=train_randomization_fn,
67+
)
68+
eval_randomization_fn, _ = prepare_randomization_fn(
69+
eval_key, cfg.training.num_eval_envs
70+
)
71+
eval_env = envs.training.wrap(
72+
eval_env,
73+
episode_length=cfg.training.episode_length,
74+
action_repeat=cfg.training.action_repeat,
75+
randomization_fn=eval_randomization_fn
76+
if cfg.training.eval_domain_randomization
77+
else None,
78+
)
79+
if cfg.training.train_domain_randomization and cfg.training.privileged:
80+
domain_parameters = params_fn(train_env.sys)
81+
else:
82+
domain_parameters = None
83+
return train_env, eval_env, domain_parameters

ss2r/benchmark_suites/brax/cartpole/cartpole.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from brax.io import mjcf
1010

1111
from ss2r.algorithms.state_sampler import StateSampler
12-
from ss2r.benchmark_suites.brax import rewards
12+
from ss2r.benchmark_suites import rewards
1313

1414

1515
def domain_randomization(sys, rng, cfg):

ss2r/benchmark_suites/rccar/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)