Skip to content

Commit b6b8867

Browse files
authored
Rccar constraint (#10)
* Cleanup * Add simple cost function * Add safety constraints * Add obstacle as parameter
1 parent 9c58780 commit b6b8867

File tree

2 files changed

+29
-18
lines changed

2 files changed

+29
-18
lines changed

ss2r/benchmark_suites/rccar/rccar.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import functools
2-
from typing import Tuple
32

43
import jax
54
import jax.flatten_util
@@ -144,13 +143,13 @@ def state_reward(self, obs: jax.Array, next_obs: jax.Array) -> jax.Array:
144143
return reward
145144

146145

147-
class RCCar(Env):
148-
dim_action: Tuple[int] = (2,)
149-
_goal: jax.Array = jnp.array([0.0, 0.0, 0.0])
150-
_init_pose: jax.Array = jnp.array([1.42, -1.04, jnp.pi])
151-
_angle_idx: int = 2
152-
_obs_noise_stds: jax.Array = OBS_NOISE_STD_SIM_CAR
146+
def cost_fn(state: jax.Array, obstacle_position, obstacle_radius) -> jax.Array:
147+
xy = state[..., :2]
148+
distance = jnp.linalg.norm(xy - obstacle_position)
149+
return jnp.where(distance >= obstacle_radius, 0.0, 1.0)
150+
153151

152+
class RCCar(Env):
154153
def __init__(
155154
self,
156155
car_model_params: dict,
@@ -160,6 +159,7 @@ def __init__(
160159
margin_factor: float = 10.0,
161160
max_throttle: float = 1.0,
162161
dt: float = 1 / 30.0,
162+
obstacle: tuple[float, float, float] = (-0.75, -0.75, 0.2),
163163
):
164164
"""
165165
Race car simulator environment
@@ -173,6 +173,12 @@ def __init__(
173173
car_model_params: dictionary of car model parameters that overwrite the default values
174174
seed: random number generator seed
175175
"""
176+
self._goal = jnp.array([0.0, 0.0, 0.0])
177+
self.obstacle = tuple(obstacle)
178+
self._init_pose = jnp.array([1.42, -1.04, jnp.pi])
179+
self._angle_idx = 2
180+
self._obs_noise_stds = OBS_NOISE_STD_SIM_CAR
181+
self.dim_action = (2,)
176182
self._dt = dt
177183
self.dim_state = (7,) if encode_angle else (6,)
178184
self.encode_angle = encode_angle
@@ -226,6 +232,7 @@ def reset(self, rng: jax.Array) -> State:
226232
obs=init_state,
227233
reward=jnp.array(0.0),
228234
done=jnp.array(0.0),
235+
info={"cost": jnp.array(0.0)},
229236
)
230237

231238
def step(self, state: State, action: jax.Array) -> State:
@@ -239,14 +246,16 @@ def step(self, state: State, action: jax.Array) -> State:
239246
# FIXME (yarden): hard-coded key is bad here.
240247
next_obs = self._obs(next_dynamics_state, rng=jax.random.PRNGKey(0))
241248
reward = self.reward_model.forward(obs=None, action=action, next_obs=next_obs)
249+
cost = cost_fn(obs, jnp.asarray(self.obstacle[:2]), self.obstacle[2])
242250
done = jnp.asarray(0.0)
251+
info = {**state.info, "cost": cost}
243252
next_state = State(
244253
pipeline_state=state.pipeline_state,
245254
obs=next_obs,
246255
reward=reward,
247256
done=done,
248257
metrics=state.metrics,
249-
info=state.info,
258+
info=info,
250259
)
251260
return next_state
252261

@@ -281,6 +290,8 @@ def render(env, policy, steps, rng):
281290
if env.encode_angle:
282291
trajectory = decode_angles(trajectory, 2)
283292

293+
obstacle_position, obstacle_radius = env.obstacle[:2], env.obstacle[2]
294+
284295
def draw_scene(timestep):
285296
# Create a figure and axis
286297
fig = Figure(figsize=(2.5, 2.5), dpi=300)
@@ -301,7 +312,7 @@ def draw_scene(timestep):
301312
# Plot the car's position and velocity at the specified timestep
302313
x, y = rotated_trajectory[timestep, 0], rotated_trajectory[timestep, 1]
303314
vx, vy = rotated_trajectory[timestep, 3], rotated_trajectory[timestep, 4]
304-
car_width, car_length = 0.3, 0.6
315+
car_width, car_length = 0.07, 0.2
305316
car = Rectangle(
306317
(x - car_length / 2, y - car_width / 2),
307318
car_length,
@@ -313,17 +324,15 @@ def draw_scene(timestep):
313324
rotation_point="center",
314325
)
315326
ax.add_patch(car)
316-
# Add an arrow to indicate the car's orientation
317-
ax.arrow(
318-
x,
319-
y,
320-
vx * 0.5,
321-
vy * 0.5,
322-
head_width=0.2,
323-
head_length=0.2,
324-
fc="black",
327+
obstacle = Circle(
328+
obstacle_position,
329+
obstacle_radius,
330+
color="gray",
331+
alpha=0.5,
325332
ec="black",
333+
lw=1.5,
326334
)
335+
ax.add_patch(obstacle)
327336
ax.quiver(
328337
x,
329338
y,
@@ -336,6 +345,7 @@ def draw_scene(timestep):
336345
headwidth=3,
337346
linewidth=0.5,
338347
)
348+
ax.grid(True, linewidth=0.5, c="gainsboro", zorder=0)
339349
# Render figure to canvas and retrieve RGB array
340350
canvas.draw()
341351
image = np.frombuffer(canvas.tostring_rgb(), dtype="uint8").copy()

ss2r/configs/environment/rccar.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ ctrl_cost_weight: 0.005
1111
margin_factor: 20.0
1212
max_throttle: 1.0
1313
use_obs_noise: false
14+
obstacle: [-0.75, -0.75, 0.2] # x, y, radius

0 commit comments

Comments
 (0)