Skip to content

Commit de59375

Browse files
authored
Domain randomization rccar (#8)
* Add nominal model
* Add bounds to car models
* Add randomization functions
* Sampling seems to work
* Add domain randomization test
* Remove unused function
1 parent c0a39f9 commit de59375

File tree

20 files changed

+619
-263
lines changed

20 files changed

+619
-263
lines changed

ss2r/benchmark_suites/__init__.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,54 @@ def make(cfg):
1616
return make_rccar_envs(cfg)
1717

1818

19+
def prepare_randomization_fn(key, num_envs, cfg, task_name):
    """Build the randomization and parameter callables for one task.

    Args:
        key: PRNG key that is split into ``num_envs`` per-environment keys.
        num_envs: Number of keys (one per environment) to derive from ``key``.
        cfg: Configuration forwarded unchanged to the task's randomization
            function.
        task_name: Key used to look up the task's function in
            ``randomization_fns``.

    Returns:
        A pair ``(vf_randomization_fn, params_fn)``. Both take ``sys`` and
        invoke the task's randomization function with the pre-split keys;
        the first returns every element of its result except the last, the
        second returns only the last element.
    """
    # The keys are derived once here, so every later call reuses the same keys.
    env_keys = jax.random.split(key, num_envs)

    def _randomize(sys):
        # The registry lookup happens at call time, not at preparation time.
        return randomization_fns[task_name](sys, env_keys, cfg)

    def vf_randomization_fn(sys):
        # Drop the trailing element of the result.
        return _randomize(sys)[:-1]  # type: ignore

    def params_fn(sys):
        # Keep only the trailing element of the result.
        return _randomize(sys)[-1]

    return vf_randomization_fn, params_fn
27+
28+
1929
def make_rccar_envs(cfg):
2030
task_cfg = dict(get_task_config(cfg))
2131
task_cfg.pop("domain_name")
32+
task_cfg.pop("task_name")
2233
train_car_params = task_cfg.pop("train_car_params")
2334
eval_car_params = task_cfg.pop("eval_car_params")
24-
train_env = rccar.RCCar(train_car_params, **task_cfg)
35+
train_key, eval_key = jax.random.split(jax.random.PRNGKey(cfg.training.seed))
36+
train_env = rccar.RCCar(train_car_params["nominal"], **task_cfg)
37+
train_randomization_fn, params_fn = (
38+
prepare_randomization_fn(
39+
train_key,
40+
cfg.training.num_envs,
41+
train_car_params["bounds"],
42+
cfg.environment.task_name,
43+
)
44+
if cfg.training.train_domain_randomization
45+
else (None, None)
46+
)
2547
train_env = envs.training.wrap(
2648
train_env,
2749
episode_length=cfg.training.episode_length,
2850
action_repeat=cfg.training.action_repeat,
51+
randomization_fn=train_randomization_fn,
52+
)
53+
eval_env = rccar.RCCar(eval_car_params["nominal"], **task_cfg)
54+
eval_randomization_fn, _ = prepare_randomization_fn(
55+
eval_key,
56+
cfg.training.num_eval_envs,
57+
eval_car_params["bounds"],
58+
cfg.environment.task_name,
2959
)
30-
eval_env = rccar.RCCar(eval_car_params, **task_cfg)
3160
eval_env = envs.training.wrap(
3261
eval_env,
3362
episode_length=cfg.training.episode_length,
3463
action_repeat=cfg.training.action_repeat,
64+
randomization_fn=eval_randomization_fn,
3565
)
36-
return train_env, eval_env, None
66+
return train_env, eval_env, params_fn
3767

3868

3969
def make_brax_envs(cfg):
@@ -43,20 +73,10 @@ def make_brax_envs(cfg):
4373
)
4474
eval_env = envs.get_environment(task_cfg.task_name, backend=cfg.environment.backend)
4575
train_key, eval_key = jax.random.split(jax.random.PRNGKey(cfg.training.seed))
46-
47-
def prepare_randomization_fn(key, num_envs):
48-
randomize_fn = lambda sys, rng: randomization_fns[task_cfg.task_name](
49-
sys, rng, task_cfg
50-
)
51-
v_randomization_fn = functools.partial(
52-
randomize_fn, rng=jax.random.split(key, num_envs)
53-
)
54-
vf_randomization_fn = lambda sys: v_randomization_fn(sys)[:-1] # type: ignore
55-
params_fn = lambda sys: v_randomization_fn(sys)[-1]
56-
return vf_randomization_fn, params_fn
57-
5876
train_randomization_fn, params_fn = (
59-
prepare_randomization_fn(train_key, cfg.training.num_envs)
77+
prepare_randomization_fn(
78+
train_key, cfg.training.num_envs, task_cfg, task_cfg.task_name
79+
)
6080
if cfg.training.train_domain_randomization
6181
else (None, None)
6282
)
@@ -67,7 +87,7 @@ def prepare_randomization_fn(key, num_envs):
6787
randomization_fn=train_randomization_fn,
6888
)
6989
eval_randomization_fn, _ = prepare_randomization_fn(
70-
eval_key, cfg.training.num_eval_envs
90+
eval_key, cfg.training.num_eval_envs, task_cfg, task_cfg.task_name
7191
)
7292
eval_env = envs.training.wrap(
7393
eval_env,
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from ss2r.benchmark_suites.brax.cartpole import cartpole
2+
from ss2r.benchmark_suites.rccar import rccar
23

34
randomization_fns = {
45
"cartpole_swingup": cartpole.domain_randomization,
56
"cartpole_swingup_sparse": cartpole.domain_randomization,
67
"cartpole_balance": cartpole.domain_randomization,
78
"inverted_pendulum": cartpole.domain_randomization,
9+
"rccar": rccar.domain_randomization,
810
}

ss2r/benchmark_suites/rccar/model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ class CarParams:
1919
c_m_1: max current of motor: [0.2 - 0.5] c_m_2: motor resistance due to shaft: [0.01 - 0.15]
2020
"""
2121

22-
car_id: int = 2
2322
m: jax.Array = jnp.array(1.65) # [0.04, 0.08]
2423
i_com: jax.Array = jnp.array(2.78e-05) # [1e-6, 5e-6]
2524
l_f: jax.Array = jnp.array(0.13) # [0.025, 0.05]

ss2r/benchmark_suites/rccar/rccar.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
from typing import Tuple
33

44
import jax
5+
import jax.flatten_util
56
import jax.numpy as jnp
7+
import jax.tree_util as jtu
68
from brax.envs.base import Env, State
9+
from omegaconf import OmegaConf
710

811
from ss2r.benchmark_suites import rewards
912
from ss2r.benchmark_suites.rccar.model import CarParams, RaceCarDynamics
@@ -13,6 +16,38 @@
1316
)
1417

1518

19+
def domain_randomization(sys, rng, cfg):
    """Sample car parameters from configured bounds, once per PRNG key.

    Args:
        sys: Input system. It is only used to build ``in_axes``, a tree with
            the same structure as ``sys`` in which every leaf is ``0``.
        rng: Batch of PRNG keys; sampling is vmapped over its leading axis.
        cfg: OmegaConf config. It is converted to a plain container and
            expanded into ``CarParams(**cfg)``. Entries that are two-element
            lists are read as ``[lower, upper]`` bounds.

    Returns:
        A tuple ``(sys, in_axes, params)``: the sampled ``CarParams`` (one
        sample per key), the ``in_axes`` tree described above, and the
        sampled values flattened with ``jax.flatten_util.ravel_pytree``.
    """

    def is_bounds_leaf(node):
        # Lists must be treated as leaves; otherwise the tree utilities would
        # descend into them and split [lower, upper] into two scalar leaves.
        return isinstance(node, list)

    def sample_from_bounds(value, key):
        """Draw uniformly from ``[lower, upper]`` when ``value`` is a
        two-element list; return any other value unchanged."""
        if isinstance(value, list) and len(value) == 2:
            lower, upper = value
            return jax.random.uniform(key, shape=(), minval=lower, maxval=upper)
        return value

    @jax.vmap
    def randomize(rng):
        bounds = CarParams(**cfg)
        treedef = jtu.tree_structure(bounds, is_leaf=is_bounds_leaf)
        # One key is split off for every leaf of the bounds tree; leaves that
        # are not [lower, upper] pairs simply leave their key unused.
        keys = jax.random.split(rng, num=treedef.num_leaves)
        # Arrange the keys in the same tree structure as the bounds so that
        # tree_map can pair each leaf with its own key.
        keys = jtu.tree_unflatten(treedef, keys)
        sampled = jtu.tree_map(
            sample_from_bounds, bounds, keys, is_leaf=is_bounds_leaf
        )
        return sampled, jax.flatten_util.ravel_pytree(sampled)[0]

    cfg = OmegaConf.to_container(cfg)
    # jtu.tree_map replaces the deprecated top-level jax.tree_map alias, which
    # newer JAX releases no longer provide.
    in_axes = jtu.tree_map(lambda _: 0, sys)
    sys, params = randomize(rng)
    return sys, in_axes, params
49+
50+
1651
def rotate_coordinates(state: jnp.array, encode_angle: bool = False) -> jnp.array:
1752
x_pos, x_vel = (
1853
state[..., 0:1],
Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,39 @@
1-
use_blend: 0.0
2-
m: 1.65
3-
l_f: 0.13
4-
l_r: 0.17
5-
angle_offset: 0.0156
6-
b_f: 2.58
7-
b_r: 3.39
8-
blend_ratio_lb: 0.01
9-
blend_ratio_ub: 0.01
10-
c_d: 0.41464928
11-
c_f: 1.2
12-
c_m_1: 10.701814
13-
c_m_2: 1.4208151
14-
c_r: 1.27
15-
d_f: 0.02
16-
d_r: 0.017
17-
i_com: 0.01
18-
steering_limit: 0.3543
1+
nominal:
2+
use_blend: 0.0
3+
m: 1.65
4+
l_f: 0.13
5+
l_r: 0.17
6+
angle_offset: 0.0156
7+
b_f: 2.58
8+
b_r: 3.39
9+
blend_ratio_lb: 0.01
10+
blend_ratio_ub: 0.01
11+
c_d: 0.41464928
12+
c_f: 1.2
13+
c_m_1: 10.701814
14+
c_m_2: 1.4208151
15+
c_r: 1.27
16+
d_f: 0.02
17+
d_r: 0.017
18+
i_com: 0.01
19+
steering_limit: 0.3543
20+
21+
bounds:
22+
use_blend: [0.0, 0.0]
23+
m: [1.6, 1.7]
24+
l_f: [0.11, 0.15]
25+
l_r: [0.15, 0.19]
26+
angle_offset: [0.001, 0.03]
27+
b_f: [2.2, 2.8]
28+
b_r: [2.0, 6.0]
29+
blend_ratio_lb: [0.4, 0.4]
30+
blend_ratio_ub: [0.5, 0.5]
31+
c_d: [0.3, 0.5]
32+
c_f: [1.2, 1.2]
33+
c_m_1: [8.0, 13.0]
34+
c_m_2: [1.1, 1.7]
35+
c_r: [1.27, 1.27]
36+
d_f: [0.02, 0.02]
37+
d_r: [0.017, 0.017]
38+
i_com: [0.01, 0.1]
39+
steering_limit: [0.20, 0.5]
Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,39 @@
1-
use_blend: 1.0
2-
m: 1.65
3-
l_f: 0.13
4-
l_r: 0.17
5-
angle_offset: -0.0213
6-
b_f: 1.8966477
7-
b_r: 6.2884626
8-
blend_ratio_lb: 0.06637411
9-
blend_ratio_ub: 0.00554
10-
c_d: 0.0
11-
c_f: 1.5381637
12-
c_m_1: 11.102413
13-
c_m_2: 1.3169205
14-
c_r: 1.186591
15-
d_f: 0.5968191
16-
d_r: 0.42716035
17-
i_com: 0.0685434
18-
steering_limit: 0.6337473
1+
nominal:
2+
use_blend: 1.0
3+
m: 1.65
4+
l_f: 0.13
5+
l_r: 0.17
6+
angle_offset: -0.0213
7+
b_f: 1.8966477
8+
b_r: 6.2884626
9+
blend_ratio_lb: 0.06637411
10+
blend_ratio_ub: 0.00554
11+
c_d: 0.0
12+
c_f: 1.5381637
13+
c_m_1: 11.102413
14+
c_m_2: 1.3169205
15+
c_r: 1.186591
16+
d_f: 0.5968191
17+
d_r: 0.42716035
18+
i_com: 0.0685434
19+
steering_limit: 0.6337473
20+
21+
bounds:
22+
use_blend: [1.0, 1.0]
23+
m: [1.6, 1.7]
24+
l_f: [0.125, 0.135]
25+
l_r: [0.165, 0.175]
26+
angle_offset: [-0.025, 0.025]
27+
b_f: [1.3, 3.0]
28+
b_r: [4.0, 10.0]
29+
blend_ratio_lb: [0.01, 0.1]
30+
blend_ratio_ub: [0.0, 0.2]
31+
c_d: [0.0, 0.0]
32+
c_f: [1.2, 1.8]
33+
c_m_1: [10.0, 12.0]
34+
c_m_2: [1.1, 1.5]
35+
c_r: [0.9, 1.5]
36+
d_f: [0.35, 0.65]
37+
d_r: [0.3, 0.6]
38+
i_com: [0.05, 0.09]
39+
steering_limit: [0.5, 0.9]
Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,39 @@
1-
use_blend: 0.0
2-
m: 1.65
3-
l_f: 0.13
4-
l_r: 0.17
5-
angle_offset: 0.0
6-
b_f: 2.58
7-
b_r: 5.0
8-
blend_ratio_lb: 0.01
9-
blend_ratio_ub: 0.01
10-
c_d: 0.0
11-
c_f: 1.2
12-
c_m_1: 8.0
13-
c_m_2: 1.5
14-
c_r: 1.27
15-
d_f: 0.02
16-
d_r: 0.017
17-
i_com: 0.01
18-
steering_limit: 0.3
1+
nominal:
2+
use_blend: 0.0
3+
m: 1.65
4+
l_f: 0.13
5+
l_r: 0.17
6+
angle_offset: 0.0
7+
b_f: 2.58
8+
b_r: 5.0
9+
blend_ratio_lb: 0.01
10+
blend_ratio_ub: 0.01
11+
c_d: 0.0
12+
c_f: 1.2
13+
c_m_1: 8.0
14+
c_m_2: 1.5
15+
c_r: 1.27
16+
d_f: 0.02
17+
d_r: 0.017
18+
i_com: 0.01
19+
steering_limit: 0.3
20+
21+
bounds:
22+
use_blend: [0.0, 0.0]
23+
m: [1.6, 1.7]
24+
l_f: [0.11, 0.15]
25+
l_r: [0.15, 0.19]
26+
angle_offset: [-0.15, 0.15]
27+
b_f: [2.4, 2.6]
28+
b_r: [2.0, 8.0]
29+
blend_ratio_lb: [0.4, 0.4]
30+
blend_ratio_ub: [0.5, 0.5]
31+
c_d: [0.01, 0.01]
32+
c_f: [1.2, 1.2]
33+
c_m_1: [6.0, 10.0]
34+
c_m_2: [1.0, 1.8]
35+
c_r: [1.27, 1.27]
36+
d_f: [0.02, 0.02]
37+
d_r: [0.017, 0.017]
38+
i_com: [0.01, 0.1]
39+
steering_limit: [0.15, 0.4]
Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,39 @@
1-
use_blend: 1.0
2-
m: 1.65
3-
l_f: 0.13
4-
l_r: 0.17
5-
angle_offset: 0.0
6-
b_f: 2.75
7-
b_r: 5.0
8-
blend_ratio_lb: 0.001
9-
blend_ratio_ub: 0.017
10-
c_d: 0.0
11-
c_f: 1.45
12-
c_m_1: 8.2
13-
c_m_2: 1.25
14-
c_r: 1.3
15-
d_f: 0.4
16-
d_r: 0.3
17-
i_com: 0.06
18-
steering_limit: 0.6
1+
nominal:
2+
use_blend: 1.0
3+
m: 1.65
4+
l_f: 0.13
5+
l_r: 0.17
6+
angle_offset: 0.0
7+
b_f: 2.75
8+
b_r: 5.0
9+
blend_ratio_lb: 0.001
10+
blend_ratio_ub: 0.017
11+
c_d: 0.0
12+
c_f: 1.45
13+
c_m_1: 8.2
14+
c_m_2: 1.25
15+
c_r: 1.3
16+
d_f: 0.4
17+
d_r: 0.3
18+
i_com: 0.06
19+
steering_limit: 0.6
20+
21+
bounds:
22+
use_blend: [1.0, 1.0]
23+
m: [1.6, 1.7]
24+
l_f: [0.125, 0.135]
25+
l_r: [0.165, 0.175]
26+
angle_offset: [-0.15, 0.15]
27+
b_f: [2.0, 4.0]
28+
b_r: [3.0, 10.0]
29+
blend_ratio_lb: [0.0001, 0.1]
30+
blend_ratio_ub: [0.0001, 0.2]
31+
c_d: [0.0, 0.0]
32+
c_f: [1.1, 2.0]
33+
c_m_1: [6.5, 10.0]
34+
c_m_2: [1.0, 1.5]
35+
c_r: [0.4, 2.0]
36+
d_f: [0.25, 0.6]
37+
d_r: [0.15, 0.45]
38+
i_com: [0.03, 0.18]
39+
steering_limit: [0.4, 0.75]

0 commit comments

Comments
 (0)