
Commit 5542aaf

Merge branch 'main' of github.com:yardenas/safe-sim2real
2 parents 8c5f6c1 + 3bde8d0

File tree

13 files changed: +93 -371 lines changed


ss2r/algorithms/sac/train.py

Lines changed: 2 additions & 1 deletion
@@ -35,6 +35,7 @@

 import ss2r.algorithms.sac.losses as sac_losses
 import ss2r.algorithms.sac.networks as sac_networks
+from ss2r.rl.evaluation import ConstraintsEvaluator

 Metrics: TypeAlias = types.Metrics
 Transition: TypeAlias = types.Transition

@@ -534,7 +535,7 @@ def training_epoch_with_timing(
         randomization_fn=vf_randomization_fn,
     )

-    evaluator = acting.Evaluator(
+    evaluator = ConstraintsEvaluator(
         eval_env,
         functools.partial(make_policy, deterministic=deterministic_eval),
         num_eval_envs=num_eval_envs,
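
The change above only swaps the evaluator class; the rest of the training loop is untouched. A minimal sketch of how the new evaluator is driven, assuming a brax-style SAC loop that already provides make_policy, eval_env, and the usual evaluation hyperparameters, and that brax's Evaluator.run_evaluation interface is unchanged; names like eval_key, policy_params, and training_metrics are placeholders, not from the diff:

import functools
import jax

# Hypothetical usage sketch: ConstraintsEvaluator is a drop-in replacement for
# brax.training.acting.Evaluator, so it is constructed with the same arguments
# and driven through the inherited run_evaluation().
eval_key = jax.random.PRNGKey(0)
evaluator = ConstraintsEvaluator(
    eval_env,
    functools.partial(make_policy, deterministic=deterministic_eval),
    num_eval_envs=num_eval_envs,
    episode_length=episode_length,
    action_repeat=action_repeat,
    key=eval_key,
)
# Runs the jitted unroll and aggregates per-episode metrics, which now also
# include the accumulated constraint cost alongside the episode reward.
metrics = evaluator.run_evaluation(policy_params, training_metrics={})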

ss2r/benchmark_suites/rccar/rccar.py

Lines changed: 9 additions & 13 deletions
@@ -159,7 +159,7 @@ def __init__(
         margin_factor: float = 10.0,
         max_throttle: float = 1.0,
         dt: float = 1 / 30.0,
-        obstacle: tuple[float, float, float] = (-0.75, -0.75, 0.2),
+        obstacle: tuple[float, float, float] = (0.75, -0.75, 0.2),
     ):
         """
         Race car simulator environment

@@ -173,21 +173,21 @@ def __init__(
            car_model_params: dictionary of car model parameters that overwrite the default values
            seed: random number generator seed
         """
-        self._goal = jnp.array([0.0, 0.0, 0.0])
+        self.goal = jnp.array([0.0, 0.0, 0.0])
         self.obstacle = tuple(obstacle)
         self._init_pose = jnp.array([1.42, -1.04, jnp.pi])
-        self._angle_idx = 2
+        self.angle_idx = 2
         self._obs_noise_stds = OBS_NOISE_STD_SIM_CAR
         self.dim_action = (2,)
-        self._dt = dt
+        self.dt = dt
         self.dim_state = (7,) if encode_angle else (6,)
         self.encode_angle = encode_angle
         self.max_throttle = jnp.clip(max_throttle, 0.0, 1.0)
-        self.dynamics_model = RaceCarDynamics(dt=self._dt)
+        self.dynamics_model = RaceCarDynamics(dt=self.dt)
         self.sys = CarParams(**car_model_params)
         self.use_obs_noise = use_obs_noise
         self.reward_model = RCCarEnvReward(
-            goal=self._goal,
+            goal=self.goal,
             ctrl_cost_weight=ctrl_cost_weight,
             encode_angle=self.encode_angle,
             margin_factor=margin_factor,

@@ -205,7 +205,7 @@ def _obs(self, state: jnp.array, rng: jax.random.PRNGKey) -> jnp.array:
         obs = state
         # encode angle to sin(theta) ant cos(theta) if desired
         if self.encode_angle:
-            obs = encode_angles(obs, self._angle_idx)
+            obs = encode_angles(obs, self.angle_idx)
         assert (obs.shape[-1] == 7 and self.encode_angle) or (
             obs.shape[-1] == 6 and not self.encode_angle
         )

@@ -241,7 +241,7 @@ def step(self, state: State, action: jax.Array) -> State:
         action = action.at[0].set(self.max_throttle * action[0])
         obs = state.obs
         if self.encode_angle:
-            dynamics_state = decode_angles(obs, self._angle_idx)
+            dynamics_state = decode_angles(obs, self.angle_idx)
         next_dynamics_state = self.dynamics_model.step(dynamics_state, action, self.sys)
         # FIXME (yarden): hard-coded key is bad here.
         next_obs = self._obs(next_dynamics_state, rng=jax.random.PRNGKey(0))

@@ -259,10 +259,6 @@ def step(self, state: State, action: jax.Array) -> State:
         )
         return next_state

-    @property
-    def dt(self):
-        return self._dt
-
     @property
     def observation_size(self) -> int:
         if self.encode_angle:

@@ -289,8 +285,8 @@ def render(env, policy, steps, rng):
     trajectory = jax.tree_map(lambda x: x[:, 0], trajectory.obs)
     if env.encode_angle:
         trajectory = decode_angles(trajectory, 2)
-
     obstacle_position, obstacle_radius = env.obstacle[:2], env.obstacle[2]
+    obstacle_position = jnp.array([obstacle_position[1], -obstacle_position[0]])

     def draw_scene(timestep):
         # Create a figure and axis
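
This diff flips the sign of the obstacle's x-coordinate and, in render(), rotates the obstacle position into the plotting frame; how the (x, y, radius) tuple enters the constraint signal is not shown in the commit. A minimal sketch of one plausible reading, assuming the cost is an indicator of the car entering the obstacle's radius (the actual cost computation lives elsewhere in rccar.py and may differ):

import jax.numpy as jnp

# Hypothetical helper, not from the diff: indicator cost for entering the obstacle.
def obstacle_cost(position_xy, obstacle=(0.75, -0.75, 0.2)):
    center = jnp.asarray(obstacle[:2])
    radius = obstacle[2]
    distance = jnp.linalg.norm(position_xy - center)
    # 1.0 when the car's (x, y) position lies inside the obstacle, else 0.0.
    return jnp.where(distance <= radius, 1.0, 0.0)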

ss2r/common/learner.py

Lines changed: 0 additions & 50 deletions
This file was deleted.
File renamed without changes.

ss2r/configs/environment/rccar.yaml

Lines changed: 1 addition & 1 deletion
@@ -11,4 +11,4 @@ ctrl_cost_weight: 0.005
 margin_factor: 20.0
 max_throttle: 1.0
 use_obs_noise: false
-obstacle: [-0.75, -0.75, 0.2] # x, y, radius
+obstacle: [0.75, -0.75, 0.2] # x, y, radius
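
The YAML keys mirror keyword arguments of the environment constructor changed above (margin_factor, max_throttle, use_obs_noise, obstacle, ctrl_cost_weight). A minimal sketch of how such a config could be applied, assuming the fields map one-to-one onto constructor kwargs; the class name RCCarEnv and the loading code are placeholders, since the actual environment factory is not part of this diff:

import yaml

# Hypothetical loading sketch: read the config and unpack it into the constructor.
with open("ss2r/configs/environment/rccar.yaml") as f:
    cfg = yaml.safe_load(f)
env = RCCarEnv(**cfg)  # placeholder class name; obstacle arrives as [x, y, radius]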

ss2r/rl/acting.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

ss2r/rl/epoch_summary.py

Lines changed: 0 additions & 60 deletions
This file was deleted.

ss2r/rl/evaluation.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+from typing import Callable
+
+import jax
+import jax.numpy as jnp
+from brax.envs.base import Env, State
+from brax.envs.wrappers.training import EvalMetrics, EvalWrapper
+from brax.training.acting import Evaluator, generate_unroll
+from brax.training.types import Policy, PolicyParams, PRNGKey
+
+
+class ConstraintEvalWrapper(EvalWrapper):
+    def reset(self, rng: jax.Array) -> State:
+        reset_state = self.env.reset(rng)
+        reset_state.metrics["reward"] = reset_state.reward
+        reset_state.metrics["cost"] = reset_state.info.get("cost", jnp.array(0.0))
+        eval_metrics = EvalMetrics(
+            episode_metrics=jax.tree_util.tree_map(jnp.zeros_like, reset_state.metrics),
+            active_episodes=jnp.ones_like(reset_state.reward),
+            episode_steps=jnp.zeros_like(reset_state.reward),
+        )
+        reset_state.info["eval_metrics"] = eval_metrics
+        return reset_state
+
+    def step(self, state: State, action: jax.Array) -> State:
+        state_metrics = state.info["eval_metrics"]
+        if not isinstance(state_metrics, EvalMetrics):
+            raise ValueError(f"Incorrect type for state_metrics: {type(state_metrics)}")
+        del state.info["eval_metrics"]
+        nstate = self.env.step(state, action)
+        nstate.metrics["reward"] = nstate.reward
+        nstate.metrics["cost"] = nstate.info.get("cost", jnp.array(0.0))
+        episode_steps = jnp.where(
+            state_metrics.active_episodes,
+            nstate.info["steps"],
+            state_metrics.episode_steps,
+        )
+        episode_metrics = jax.tree_util.tree_map(
+            lambda a, b: a + b * state_metrics.active_episodes,
+            state_metrics.episode_metrics,
+            nstate.metrics,
+        )
+        active_episodes = state_metrics.active_episodes * (1 - nstate.done)
+        eval_metrics = EvalMetrics(
+            episode_metrics=episode_metrics,
+            active_episodes=active_episodes,
+            episode_steps=episode_steps,
+        )
+        nstate.info["eval_metrics"] = eval_metrics
+        return nstate
+
+
+class ConstraintsEvaluator(Evaluator):
+    def __init__(
+        self,
+        eval_env: Env,
+        eval_policy_fn: Callable[[PolicyParams], Policy],
+        num_eval_envs: int,
+        episode_length: int,
+        action_repeat: int,
+        key: jax.Array,
+    ):
+        self._key = key
+        self._eval_walltime = 0.0
+        eval_env = ConstraintEvalWrapper(eval_env)
+
+        def generate_eval_unroll(policy_params: PolicyParams, key: PRNGKey) -> State:
+            reset_keys = jax.random.split(key, num_eval_envs)
+            eval_first_state = eval_env.reset(reset_keys)
+            return generate_unroll(
+                eval_env,
+                eval_first_state,
+                eval_policy_fn(policy_params),
+                key,
+                unroll_length=episode_length // action_repeat,
+            )[0]
+
+        self._generate_eval_unroll = jax.jit(generate_eval_unroll)
+        self._steps_per_unroll = episode_length * num_eval_envs
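
ConstraintEvalWrapper only reads state.info.get("cost", ...), so the contract it assumes from the wrapped environment is small. A minimal sketch of that contract, with a hypothetical step method and _simulate helper (the RC car environment's real cost logic is not shown in this commit):

import jax.numpy as jnp
from brax.envs.base import State

# Hypothetical environment step, illustrating the contract ConstraintEvalWrapper relies on:
# expose a scalar per-step constraint cost under state.info["cost"]. Environments that never
# set it fall back to the jnp.array(0.0) default used by the wrapper above.
def step(self, state: State, action: jnp.ndarray) -> State:
    next_obs, reward, done, cost = self._simulate(state, action)  # placeholder helper
    info = {**state.info, "cost": cost}
    return state.replace(obs=next_obs, reward=reward, done=done, info=info)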

0 commit comments
