1
- from typing import NamedTuple
2
-
3
1
import jax
4
2
import jax .numpy as jnp
3
+ from flax .struct import dataclass
5
4
6
5
7
- class CarParams (NamedTuple ):
6
+ @dataclass
7
+ class CarParams :
8
8
"""
9
9
d_f, d_r : Represent grip of the car. Range: [0.015, 0.025]
10
10
b_f, b_r: Slope of the pacejka. Range: [2.0 - 4.0].
@@ -35,9 +35,7 @@ class CarParams(NamedTuple):
35
35
c_m_2 : jax .Array = jnp .array (1.5003588 ) # [0.00, 0.007]
36
36
c_d : jax .Array = jnp .array (0.0 ) # [0.01, 0.1]
37
37
steering_limit : jax .Array = jnp .array (0.19989373 )
38
- use_blend : jax .Array = jnp .array (
39
- 0.0
40
- ) # 0.0 -> (only kinematics), 1.0 -> (kinematics + dynamics)
38
+ use_blend : jax .Array = jnp .array (0.0 )
41
39
# parameters used to compute the blend ratio characteristics
42
40
blend_ratio_ub : jax .Array = jnp .array ([0.5477225575 ])
43
41
blend_ratio_lb : jax .Array = jnp .array ([0.4472135955 ])
@@ -92,7 +90,7 @@ def compute_accelerations(x, u, params: CarParams):
92
90
return acceleration
93
91
94
92
95
- class RaceCar :
93
+ class RaceCarDynamics :
96
94
"""
97
95
local_coordinates: bool
98
96
Used to indicate if local or global coordinates shall be used.
@@ -108,19 +106,17 @@ class RaceCar:
108
106
def __init__ (
109
107
self ,
110
108
dt ,
111
- encode_angle : bool = True ,
112
109
local_coordinates : bool = False ,
113
110
rk_integrator : bool = True ,
114
111
):
115
- self .encode_angle = encode_angle
116
112
if dt <= 1 / 100 :
117
113
integration_dt = dt
118
114
else :
119
115
integration_dt = 1 / 100
120
116
self .local_coordinates = local_coordinates
121
117
self .angle_idx = 2
122
- self .velocity_start_idx = 4 if self . encode_angle else 3
123
- self .velocity_end_idx = 5 if self . encode_angle else 4
118
+ self .velocity_start_idx = 3
119
+ self .velocity_end_idx = 4
124
120
self .rk_integrator = rk_integrator
125
121
self ._num_steps_integrate = int (dt / integration_dt )
126
122
self .dt_integration = integration_dt
@@ -133,12 +129,11 @@ def body(carry, _):
133
129
return q , None
134
130
135
131
next_state , _ = jax .lax .scan (body , x , xs = None , length = self ._num_steps_integrate )
136
- if self .angle_idx is not None :
137
- theta = next_state [self .angle_idx ]
138
- sin_theta , cos_theta = jnp .sin (theta ), jnp .cos (theta )
139
- next_state = next_state .at [self .angle_idx ].set (
140
- jnp .arctan2 (sin_theta , cos_theta )
141
- )
132
+ theta = next_state [self .angle_idx ]
133
+ sin_theta , cos_theta = jnp .sin (theta ), jnp .cos (theta )
134
+ next_state = next_state .at [self .angle_idx ].set (
135
+ jnp .arctan2 (sin_theta , cos_theta )
136
+ )
142
137
return next_state
143
138
144
139
def rk_integration (
@@ -183,20 +178,16 @@ def rk_integrate(carry, ins):
183
178
return q , None
184
179
185
180
next_state , _ = jax .lax .scan (body , x , xs = None , length = self ._num_steps_integrate )
186
- if self .angle_idx is not None :
187
- theta = next_state [self .angle_idx ]
188
- sin_theta , cos_theta = jnp .sin (theta ), jnp .cos (theta )
189
- next_state = next_state .at [self .angle_idx ].set (
190
- jnp .arctan2 (sin_theta , cos_theta )
191
- )
181
+ theta = next_state [self .angle_idx ]
182
+ sin_theta , cos_theta = jnp .sin (theta ), jnp .cos (theta )
183
+ next_state = next_state .at [self .angle_idx ].set (
184
+ jnp .arctan2 (sin_theta , cos_theta )
185
+ )
192
186
return next_state
193
187
194
188
def step (self , x : jnp .array , u : jnp .array , params : CarParams ) -> jnp .array :
195
- theta_x = (
196
- jnp .arctan2 (x [..., self .angle_idx ], x [..., self .angle_idx + 1 ])
197
- if self .encode_angle
198
- else x [..., self .angle_idx ]
199
- )
189
+ assert x .shape [- 1 ] == 6
190
+ theta_x = x [..., self .angle_idx ]
200
191
offset = jnp .clip (params .angle_offset , - jnp .pi , jnp .pi )
201
192
theta_x = theta_x + offset
202
193
if not self .local_coordinates :
@@ -208,41 +199,18 @@ def step(self, x: jnp.array, u: jnp.array, params: CarParams) -> jnp.array:
208
199
x = x .at [..., self .velocity_start_idx : self .velocity_end_idx + 1 ].set (
209
200
rotated_vel
210
201
)
211
- if self .encode_angle :
212
- x_reduced = self .reduce_x (x )
213
- if self .rk_integrator :
214
- x_reduced = self .rk_integration (x_reduced , u , params )
215
- else :
216
- x_reduced = self ._compute_one_dt (x_reduced , u , params )
217
- next_theta = jnp .atleast_1d (x_reduced [..., self .angle_idx ])
218
- next_x = jnp .concatenate (
219
- [
220
- x_reduced [..., 0 : self .angle_idx ],
221
- jnp .sin (next_theta ),
222
- jnp .cos (next_theta ),
223
- x_reduced [..., self .angle_idx + 1 :],
224
- ],
225
- axis = - 1 ,
226
- )
202
+ if self .rk_integrator :
203
+ next_x = self .rk_integration (x , u , params )
227
204
else :
228
- if self .rk_integrator :
229
- next_x = self .rk_integration (x , u , params )
230
- else :
231
- next_x = self ._compute_one_dt (x , u , params )
205
+ next_x = self ._compute_one_dt (x , u , params )
232
206
if self .local_coordinates :
233
207
# convert position to local frame
234
208
pos = next_x [..., 0 : self .angle_idx ] - x [..., 0 : self .angle_idx ]
235
209
rotated_pos = rotate_vector (pos , - theta_x )
236
210
next_x = next_x .at [..., 0 : self .angle_idx ].set (rotated_pos )
237
211
else :
238
212
# convert velocity to global frame
239
- new_theta_x = (
240
- jnp .arctan2 (
241
- next_x [..., self .angle_idx ], next_x [..., self .angle_idx + 1 ]
242
- )
243
- if self .encode_angle
244
- else next_x [..., self .angle_idx ]
245
- )
213
+ new_theta_x = next_x [..., self .angle_idx ]
246
214
new_theta_x = new_theta_x + offset
247
215
velocity = next_x [..., self .velocity_start_idx : self .velocity_end_idx + 1 ]
248
216
rotated_vel = rotate_vector (velocity , new_theta_x )
@@ -251,19 +219,6 @@ def step(self, x: jnp.array, u: jnp.array, params: CarParams) -> jnp.array:
251
219
].set (rotated_vel )
252
220
return next_x
253
221
254
- def reduce_x (self , x ):
255
- theta = jnp .arctan2 (x [..., self .angle_idx ], x [..., self .angle_idx + 1 ])
256
-
257
- x_reduced = jnp .concatenate (
258
- [
259
- x [..., 0 : self .angle_idx ],
260
- jnp .atleast_1d (theta ),
261
- x [..., self .velocity_start_idx :],
262
- ],
263
- axis = - 1 ,
264
- )
265
- return x_reduced
266
-
267
222
def _ode_dyn (self , x , u , params : CarParams ):
268
223
"""Compute derivative using dynamic model.
269
224
Inputs
0 commit comments