Skip to content

Commit c0a39f9

Browse files
authored
Add car domain randomization (#6)
* Convert CarParams to a dataclass * Remove domain name * Set margin factor to 20 * Add Lenart * Using deterministic eval and longer episodes does the trick
1 parent 9dde1b4 commit c0a39f9

File tree

14 files changed

+321
-142
lines changed

14 files changed

+321
-142
lines changed

poetry.lock

Lines changed: 128 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ jax = "0.4.25"
2727
brax = "^0.10.5"
2828
equinox = "^0.11.4"
2929
pyqt6 = "^6.7.0"
30+
mbpo = {git = "https://github.com/lasgroup/Model-based-policy-optimizers.git"}
3031

3132

3233
[[tool.poetry.source]]

ss2r/benchmark_suites/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def make(cfg):
1818

1919
def make_rccar_envs(cfg):
2020
task_cfg = dict(get_task_config(cfg))
21+
task_cfg.pop("domain_name")
2122
train_car_params = task_cfg.pop("train_car_params")
2223
eval_car_params = task_cfg.pop("eval_car_params")
2324
train_env = rccar.RCCar(train_car_params, **task_cfg)

ss2r/benchmark_suites/rccar/model.py

Lines changed: 23 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from typing import NamedTuple
2-
31
import jax
42
import jax.numpy as jnp
3+
from flax.struct import dataclass
54

65

7-
class CarParams(NamedTuple):
6+
@dataclass
7+
class CarParams:
88
"""
99
d_f, d_r : Represent grip of the car. Range: [0.015, 0.025]
1010
b_f, b_r: Slope of the pacejka. Range: [2.0 - 4.0].
@@ -35,9 +35,7 @@ class CarParams(NamedTuple):
3535
c_m_2: jax.Array = jnp.array(1.5003588) # [0.00, 0.007]
3636
c_d: jax.Array = jnp.array(0.0) # [0.01, 0.1]
3737
steering_limit: jax.Array = jnp.array(0.19989373)
38-
use_blend: jax.Array = jnp.array(
39-
0.0
40-
) # 0.0 -> (only kinematics), 1.0 -> (kinematics + dynamics)
38+
use_blend: jax.Array = jnp.array(0.0)
4139
# parameters used to compute the blend ratio characteristics
4240
blend_ratio_ub: jax.Array = jnp.array([0.5477225575])
4341
blend_ratio_lb: jax.Array = jnp.array([0.4472135955])
@@ -92,7 +90,7 @@ def compute_accelerations(x, u, params: CarParams):
9290
return acceleration
9391

9492

95-
class RaceCar:
93+
class RaceCarDynamics:
9694
"""
9795
local_coordinates: bool
9896
Used to indicate if local or global coordinates shall be used.
@@ -108,19 +106,17 @@ class RaceCar:
108106
def __init__(
109107
self,
110108
dt,
111-
encode_angle: bool = True,
112109
local_coordinates: bool = False,
113110
rk_integrator: bool = True,
114111
):
115-
self.encode_angle = encode_angle
116112
if dt <= 1 / 100:
117113
integration_dt = dt
118114
else:
119115
integration_dt = 1 / 100
120116
self.local_coordinates = local_coordinates
121117
self.angle_idx = 2
122-
self.velocity_start_idx = 4 if self.encode_angle else 3
123-
self.velocity_end_idx = 5 if self.encode_angle else 4
118+
self.velocity_start_idx = 3
119+
self.velocity_end_idx = 4
124120
self.rk_integrator = rk_integrator
125121
self._num_steps_integrate = int(dt / integration_dt)
126122
self.dt_integration = integration_dt
@@ -133,12 +129,11 @@ def body(carry, _):
133129
return q, None
134130

135131
next_state, _ = jax.lax.scan(body, x, xs=None, length=self._num_steps_integrate)
136-
if self.angle_idx is not None:
137-
theta = next_state[self.angle_idx]
138-
sin_theta, cos_theta = jnp.sin(theta), jnp.cos(theta)
139-
next_state = next_state.at[self.angle_idx].set(
140-
jnp.arctan2(sin_theta, cos_theta)
141-
)
132+
theta = next_state[self.angle_idx]
133+
sin_theta, cos_theta = jnp.sin(theta), jnp.cos(theta)
134+
next_state = next_state.at[self.angle_idx].set(
135+
jnp.arctan2(sin_theta, cos_theta)
136+
)
142137
return next_state
143138

144139
def rk_integration(
@@ -183,20 +178,16 @@ def rk_integrate(carry, ins):
183178
return q, None
184179

185180
next_state, _ = jax.lax.scan(body, x, xs=None, length=self._num_steps_integrate)
186-
if self.angle_idx is not None:
187-
theta = next_state[self.angle_idx]
188-
sin_theta, cos_theta = jnp.sin(theta), jnp.cos(theta)
189-
next_state = next_state.at[self.angle_idx].set(
190-
jnp.arctan2(sin_theta, cos_theta)
191-
)
181+
theta = next_state[self.angle_idx]
182+
sin_theta, cos_theta = jnp.sin(theta), jnp.cos(theta)
183+
next_state = next_state.at[self.angle_idx].set(
184+
jnp.arctan2(sin_theta, cos_theta)
185+
)
192186
return next_state
193187

194188
def step(self, x: jnp.array, u: jnp.array, params: CarParams) -> jnp.array:
195-
theta_x = (
196-
jnp.arctan2(x[..., self.angle_idx], x[..., self.angle_idx + 1])
197-
if self.encode_angle
198-
else x[..., self.angle_idx]
199-
)
189+
assert x.shape[-1] == 6
190+
theta_x = x[..., self.angle_idx]
200191
offset = jnp.clip(params.angle_offset, -jnp.pi, jnp.pi)
201192
theta_x = theta_x + offset
202193
if not self.local_coordinates:
@@ -208,41 +199,18 @@ def step(self, x: jnp.array, u: jnp.array, params: CarParams) -> jnp.array:
208199
x = x.at[..., self.velocity_start_idx : self.velocity_end_idx + 1].set(
209200
rotated_vel
210201
)
211-
if self.encode_angle:
212-
x_reduced = self.reduce_x(x)
213-
if self.rk_integrator:
214-
x_reduced = self.rk_integration(x_reduced, u, params)
215-
else:
216-
x_reduced = self._compute_one_dt(x_reduced, u, params)
217-
next_theta = jnp.atleast_1d(x_reduced[..., self.angle_idx])
218-
next_x = jnp.concatenate(
219-
[
220-
x_reduced[..., 0 : self.angle_idx],
221-
jnp.sin(next_theta),
222-
jnp.cos(next_theta),
223-
x_reduced[..., self.angle_idx + 1 :],
224-
],
225-
axis=-1,
226-
)
202+
if self.rk_integrator:
203+
next_x = self.rk_integration(x, u, params)
227204
else:
228-
if self.rk_integrator:
229-
next_x = self.rk_integration(x, u, params)
230-
else:
231-
next_x = self._compute_one_dt(x, u, params)
205+
next_x = self._compute_one_dt(x, u, params)
232206
if self.local_coordinates:
233207
# convert position to local frame
234208
pos = next_x[..., 0 : self.angle_idx] - x[..., 0 : self.angle_idx]
235209
rotated_pos = rotate_vector(pos, -theta_x)
236210
next_x = next_x.at[..., 0 : self.angle_idx].set(rotated_pos)
237211
else:
238212
# convert velocity to global frame
239-
new_theta_x = (
240-
jnp.arctan2(
241-
next_x[..., self.angle_idx], next_x[..., self.angle_idx + 1]
242-
)
243-
if self.encode_angle
244-
else next_x[..., self.angle_idx]
245-
)
213+
new_theta_x = next_x[..., self.angle_idx]
246214
new_theta_x = new_theta_x + offset
247215
velocity = next_x[..., self.velocity_start_idx : self.velocity_end_idx + 1]
248216
rotated_vel = rotate_vector(velocity, new_theta_x)
@@ -251,19 +219,6 @@ def step(self, x: jnp.array, u: jnp.array, params: CarParams) -> jnp.array:
251219
].set(rotated_vel)
252220
return next_x
253221

254-
def reduce_x(self, x):
255-
theta = jnp.arctan2(x[..., self.angle_idx], x[..., self.angle_idx + 1])
256-
257-
x_reduced = jnp.concatenate(
258-
[
259-
x[..., 0 : self.angle_idx],
260-
jnp.atleast_1d(theta),
261-
x[..., self.velocity_start_idx :],
262-
],
263-
axis=-1,
264-
)
265-
return x_reduced
266-
267222
def _ode_dyn(self, x, u, params: CarParams):
268223
"""Compute derivative using dynamic model.
269224
Inputs

0 commit comments

Comments
 (0)