11import functools
2- from typing import Tuple
32
43import jax
54import jax .flatten_util
@@ -144,13 +143,13 @@ def state_reward(self, obs: jax.Array, next_obs: jax.Array) -> jax.Array:
144143 return reward
145144
146145
147- class RCCar (Env ):
148- dim_action : Tuple [int ] = (2 ,)
149- _goal : jax .Array = jnp .array ([0.0 , 0.0 , 0.0 ])
150- _init_pose : jax .Array = jnp .array ([1.42 , - 1.04 , jnp .pi ])
151- _angle_idx : int = 2
152- _obs_noise_stds : jax .Array = OBS_NOISE_STD_SIM_CAR
146+ def cost_fn (state : jax .Array , obstacle_position , obstacle_radius ) -> jax .Array :
147+ xy = state [..., :2 ]
148+ distance = jnp .linalg .norm (xy - obstacle_position )
149+ return jnp .where (distance >= obstacle_radius , 0.0 , 1.0 )
150+
153151
152+ class RCCar (Env ):
154153 def __init__ (
155154 self ,
156155 car_model_params : dict ,
@@ -160,6 +159,7 @@ def __init__(
160159 margin_factor : float = 10.0 ,
161160 max_throttle : float = 1.0 ,
162161 dt : float = 1 / 30.0 ,
162+ obstacle : tuple [float , float , float ] = (- 0.75 , - 0.75 , 0.2 ),
163163 ):
164164 """
165165 Race car simulator environment
@@ -173,6 +173,12 @@ def __init__(
173173 car_model_params: dictionary of car model parameters that overwrite the default values
174174 seed: random number generator seed
175175 """
176+ self ._goal = jnp .array ([0.0 , 0.0 , 0.0 ])
177+ self .obstacle = tuple (obstacle )
178+ self ._init_pose = jnp .array ([1.42 , - 1.04 , jnp .pi ])
179+ self ._angle_idx = 2
180+ self ._obs_noise_stds = OBS_NOISE_STD_SIM_CAR
181+ self .dim_action = (2 ,)
176182 self ._dt = dt
177183 self .dim_state = (7 ,) if encode_angle else (6 ,)
178184 self .encode_angle = encode_angle
@@ -226,6 +232,7 @@ def reset(self, rng: jax.Array) -> State:
226232 obs = init_state ,
227233 reward = jnp .array (0.0 ),
228234 done = jnp .array (0.0 ),
235+ info = {"cost" : jnp .array (0.0 )},
229236 )
230237
231238 def step (self , state : State , action : jax .Array ) -> State :
@@ -239,14 +246,16 @@ def step(self, state: State, action: jax.Array) -> State:
239246 # FIXME (yarden): hard-coded key is bad here.
240247 next_obs = self ._obs (next_dynamics_state , rng = jax .random .PRNGKey (0 ))
241248 reward = self .reward_model .forward (obs = None , action = action , next_obs = next_obs )
249+ cost = cost_fn (obs , jnp .asarray (self .obstacle [:2 ]), self .obstacle [2 ])
242250 done = jnp .asarray (0.0 )
251+ info = {** state .info , "cost" : cost }
243252 next_state = State (
244253 pipeline_state = state .pipeline_state ,
245254 obs = next_obs ,
246255 reward = reward ,
247256 done = done ,
248257 metrics = state .metrics ,
249- info = state . info ,
258+ info = info ,
250259 )
251260 return next_state
252261
@@ -281,6 +290,8 @@ def render(env, policy, steps, rng):
281290 if env .encode_angle :
282291 trajectory = decode_angles (trajectory , 2 )
283292
293+ obstacle_position , obstacle_radius = env .obstacle [:2 ], env .obstacle [2 ]
294+
284295 def draw_scene (timestep ):
285296 # Create a figure and axis
286297 fig = Figure (figsize = (2.5 , 2.5 ), dpi = 300 )
@@ -301,7 +312,7 @@ def draw_scene(timestep):
301312 # Plot the car's position and velocity at the specified timestep
302313 x , y = rotated_trajectory [timestep , 0 ], rotated_trajectory [timestep , 1 ]
303314 vx , vy = rotated_trajectory [timestep , 3 ], rotated_trajectory [timestep , 4 ]
304- car_width , car_length = 0.3 , 0.6
315+ car_width , car_length = 0.07 , 0.2
305316 car = Rectangle (
306317 (x - car_length / 2 , y - car_width / 2 ),
307318 car_length ,
@@ -313,17 +324,15 @@ def draw_scene(timestep):
313324 rotation_point = "center" ,
314325 )
315326 ax .add_patch (car )
316- # Add an arrow to indicate the car's orientation
317- ax .arrow (
318- x ,
319- y ,
320- vx * 0.5 ,
321- vy * 0.5 ,
322- head_width = 0.2 ,
323- head_length = 0.2 ,
324- fc = "black" ,
327+ obstacle = Circle (
328+ obstacle_position ,
329+ obstacle_radius ,
330+ color = "gray" ,
331+ alpha = 0.5 ,
325332 ec = "black" ,
333+ lw = 1.5 ,
326334 )
335+ ax .add_patch (obstacle )
327336 ax .quiver (
328337 x ,
329338 y ,
@@ -336,6 +345,7 @@ def draw_scene(timestep):
336345 headwidth = 3 ,
337346 linewidth = 0.5 ,
338347 )
348+ ax .grid (True , linewidth = 0.5 , c = "gainsboro" , zorder = 0 )
339349 # Render figure to canvas and retrieve RGB array
340350 canvas .draw ()
341351 image = np .frombuffer (canvas .tostring_rgb (), dtype = "uint8" ).copy ()
0 commit comments