
Commit e960e3a

Robust (#16)
* New interface for Q transformations
* Implementation of CVaR
* Add costs
1 parent 2c6de02 commit e960e3a

6 files changed: +197 −77 lines changed


ss2r/algorithms/sac/losses.py

Lines changed: 27 additions & 37 deletions
```diff
@@ -25,6 +25,8 @@
 from brax.training.agents.sac import networks as sac_networks
 from brax.training.types import Params, PRNGKey
 
+from ss2r.algorithms.sac.robustness import QTransformation, SACBase
+
 Transition: TypeAlias = types.Transition
 
 
@@ -70,10 +72,12 @@ def critic_loss(
     alpha: jnp.ndarray,
     transitions: Transition,
     key: PRNGKey,
-    exploration_bonus: bool = True,
     safe: bool = False,
+    target_q_fn: QTransformation = SACBase(),
 ) -> jnp.ndarray:
-    domain_params = transitions.extras.get("domain_parameters", None)
+    domain_params = transitions.extras["state_extras"].get(
+        "domain_parameters", None
+    )
     if domain_params is not None:
         action = jnp.concatenate([transitions.action, domain_params], axis=-1)
     else:
@@ -83,46 +87,30 @@ def critic_loss(
     q_old_action = q_network.apply(
         normalizer_params, q_params, transitions.observation, action
     )
-    next_dist_params = policy_network.apply(
-        normalizer_params, policy_params, transitions.next_observation
-    )
-    next_action = parametric_action_distribution.sample_no_postprocessing(
-        next_dist_params, key
-    )
-    next_log_prob = parametric_action_distribution.log_prob(
-        next_dist_params, next_action
-    )
-    next_action = parametric_action_distribution.postprocess(next_action)
-    if domain_params is not None:
-        next_action = jnp.concatenate([next_action, domain_params], axis=-1)
-    next_q = q_network.apply(
-        normalizer_params,
-        target_q_params,
-        transitions.next_observation,
-        next_action,
-    )
-    if safe:
-        next_v = jnp.mean(next_q, axis=-1)
-    else:
-        next_v = jnp.min(next_q, axis=-1)
-    if exploration_bonus:
-        next_v -= alpha * next_log_prob
-    reward = transitions.reward
-    if safe:
-        assert "imagined_cost" in transitions.extras or "cost" in transitions.extras
-        reward = transitions.extras.get(
-            "imagined_cost",
-            transitions.extras.get("cost", jnp.zeros_like(transitions.reward)),
+
+    def policy(obs: jax.Array) -> tuple[jax.Array, jax.Array]:
+        next_dist_params = policy_network.apply(
+            normalizer_params, policy_params, obs
+        )
+        next_action = parametric_action_distribution.sample_no_postprocessing(
+            next_dist_params, key
         )
-    target_q = jax.lax.stop_gradient(
-        reward * reward_scaling + transitions.discount * gamma * next_v
+        next_log_prob = parametric_action_distribution.log_prob(
+            next_dist_params, next_action
+        )
+        next_action = parametric_action_distribution.postprocess(next_action)
+        return next_action, next_log_prob
+
+    q_fn = lambda obs, action: q_network.apply(
+        normalizer_params, target_q_params, obs, action
+    )
+    target_q = target_q_fn(
+        transitions, q_fn, policy, gamma, domain_params, alpha, reward_scaling
     )
     q_error = q_old_action - jnp.expand_dims(target_q, -1)
-
     # Better bootstrapping for truncated episodes.
     truncation = transitions.extras["state_extras"]["truncation"]
     q_error *= jnp.expand_dims(1 - truncation, -1)
-
     q_loss = 0.5 * jnp.mean(jnp.square(q_error))
     return q_loss
 
@@ -145,7 +133,9 @@ def actor_loss(
     )
     log_prob = parametric_action_distribution.log_prob(dist_params, action)
     action = parametric_action_distribution.postprocess(action)
-    domain_params = transitions.extras.get("domain_parameters", None)
+    domain_params = transitions.extras["state_extras"].get(
+        "domain_parameters", None
+    )
     if domain_params is not None:
         action = jnp.concatenate([action, domain_params], axis=-1)
     qr_action = qr_network.apply(
```
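With this change the critic loss no longer hard-codes the SAC target: any callable matching the `QTransformation` protocol (defined in `ss2r/algorithms/sac/robustness.py` below) can be passed as `target_q_fn`. A minimal sketch, assuming the call signature shown in the diff; `MeanTarget` is a hypothetical example, not part of this commit:

```python
import jax
import jax.numpy as jnp


# Hypothetical transform: bootstrap from the Q-ensemble mean and skip the
# entropy bonus. It could be plugged in via critic_loss(..., target_q_fn=MeanTarget())
# in place of the default SACBase().
class MeanTarget:
    def __call__(
        self,
        transitions,
        q_fn,
        policy,
        gamma,
        domain_params=None,
        alpha=None,
        reward_scaling=1.0,
    ):
        next_action, _ = policy(transitions.next_observation)
        if domain_params is not None:
            next_action = jnp.concatenate([next_action, domain_params], axis=-1)
        # Average over the Q ensemble instead of taking the pessimistic min.
        next_v = q_fn(transitions.next_observation, next_action).mean(axis=-1)
        return jax.lax.stop_gradient(
            transitions.reward * reward_scaling
            + transitions.discount * gamma * next_v
        )
```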

ss2r/algorithms/sac/robustness.py

Lines changed: 130 additions & 0 deletions
New file, full contents (`@@ -0,0 +1,130 @@`):

```python
from typing import Callable, Protocol

import jax
import jax.numpy as jnp
from brax.training.types import Params, Transition


class QTransformation(Protocol):
    def __call__(
        self,
        transitions: Transition,
        q_fn: Callable[[Params, jax.Array], jax.Array],
        policy: Callable[[jax.Array], tuple[jax.Array, jax.Array]],
        gamma: float,
        domain_params: jax.Array | None = None,
        alpha: jax.Array | None = None,
        reward_scaling: float = 1.0,
    ):
        ...


class LCB(QTransformation):
    def __init__(self, lambda_: float) -> None:
        self.lambda_ = lambda_

    def __call__(
        self,
        transitions: Transition,
        q_fn: Callable[[Params, jax.Array], jax.Array],
        policy: Callable[[jax.Array], tuple[jax.Array, jax.Array]],
        gamma: float,
        domain_params: jax.Array | None = None,
        alpha: jax.Array | None = None,
        reward_scaling: float = 1.0,
    ):
        next_obs = transitions.extras["state_extras"]["state_propagation"]["next_obs"]
        next_action, _ = policy(next_obs)
        if domain_params is not None:
            domain_params = jnp.tile(
                domain_params[:, None], (1, next_action.shape[1], 1)
            )
            next_action = jnp.concatenate([next_action, domain_params], axis=-1)
        next_q = q_fn(next_obs, next_action)
        next_v = next_q.mean(axis=-1)
        std = jnp.std(next_v, axis=-1)
        cost = transitions.extras["state_extras"]["cost"]
        cost += self.lambda_ * std
        target_q = jax.lax.stop_gradient(
            cost * reward_scaling + transitions.discount * gamma * next_v
        )
        return target_q


class CVaR(QTransformation):
    def __init__(self, confidence: float) -> None:
        self.confidence = confidence

    def __call__(
        self,
        transitions: Transition,
        q_fn: Callable[[Params, jax.Array], jax.Array],
        policy: Callable[[jax.Array], tuple[jax.Array, jax.Array]],
        gamma: float,
        domain_params: jax.Array | None = None,
        alpha: jax.Array | None = None,
        reward_scaling: float = 1.0,
    ):
        next_obs = transitions.extras["state_extras"]["state_propagation"]["next_obs"]
        next_action, _ = policy(next_obs)
        if domain_params is not None:
            domain_params = jnp.tile(
                domain_params[:, None], (1, next_action.shape[1], 1)
            )
            next_action = jnp.concatenate([next_action, domain_params], axis=-1)
        next_q = q_fn(next_obs, next_action)
        next_v = next_q.mean(axis=-1)
        sort_next_v = jnp.sort(next_v, axis=-1)
        cvar_index = int((1 - self.confidence) * next_v.shape[1])
        next_v = jnp.mean(sort_next_v[:, :cvar_index], axis=-1)
        cost = transitions.extras["state_extras"]["cost"]
        target_q = jax.lax.stop_gradient(
            cost * reward_scaling + transitions.discount * gamma * next_v
        )
        return target_q


class SACBase(QTransformation):
    def __call__(
        self,
        transitions: Transition,
        q_fn: Callable[[Params, jax.Array], jax.Array],
        policy: Callable[[jax.Array], tuple[jax.Array, jax.Array]],
        gamma: float,
        domain_params: jax.Array | None = None,
        alpha: jax.Array | None = None,
        reward_scaling: float = 1.0,
    ):
        next_action, next_log_prob = policy(transitions.next_observation)
        if domain_params is not None:
            next_action = jnp.concatenate([next_action, domain_params], axis=-1)
        next_q = q_fn(transitions.next_observation, next_action)
        next_v = next_q.min(axis=-1)
        next_v -= alpha * next_log_prob
        target_q = jax.lax.stop_gradient(
            transitions.reward * reward_scaling + transitions.discount * gamma * next_v
        )
        return target_q


class SACCost(QTransformation):
    def __call__(
        self,
        transitions: Transition,
        q_fn: Callable[[Params, jax.Array], jax.Array],
        policy: Callable[[jax.Array], tuple[jax.Array, jax.Array]],
        gamma: float,
        domain_params: jax.Array | None = None,
        alpha: jax.Array | None = None,
        reward_scaling: float = 1.0,
    ):
        next_action, _ = policy(transitions.next_observation)
        if domain_params is not None:
            next_action = jnp.concatenate([next_action, domain_params], axis=-1)
        next_q = q_fn(transitions.next_observation, next_action)
        next_v = next_q.mean(axis=-1)
        cost = transitions.extras["state_extras"]["cost"]
        target_q = jax.lax.stop_gradient(
            cost * reward_scaling + transitions.discount * gamma * next_v
        )
        return target_q
```

ss2r/algorithms/sac/train.py

Lines changed: 21 additions & 24 deletions
```diff
@@ -35,11 +35,8 @@
 
 import ss2r.algorithms.sac.losses as sac_losses
 import ss2r.algorithms.sac.networks as sac_networks
-from ss2r.algorithms.sac.wrappers import (
-    DomainRandomizationParams,
-    StatePropagation,
-    std_bonus,
-)
+from ss2r.algorithms.sac.robustness import SACCost
+from ss2r.algorithms.sac.wrappers import DomainRandomizationParams, StatePropagation
 from ss2r.rl.evaluation import ConstraintsEvaluator
 
 Metrics: TypeAlias = types.Metrics
@@ -172,6 +169,8 @@ def train(
     lagrange_multiplier: float = 1e-9,
     penalty_multiplier: float = 1.0,
     penalty_multiplier_factor: float = 1.0,
+    cost_q_transform: str | None = None,
+    cvar_confidence: float = 0.95,
 ):
     """SAC training."""
     process_id = jax.process_index()
@@ -247,12 +246,7 @@
     else:
         domain_parameters = None
     if propagation is not None:
-        cost_penalty_fn = (
-            functools.partial(std_bonus, lambda_=cost_penalty)
-            if cost_penalty is not None
-            else None
-        )
-        env = StatePropagation(env, cost_penalty_fn=cost_penalty_fn)
+        env = StatePropagation(env)
 
     obs_size = env.observation_size
     action_size = env.action_size
@@ -287,12 +281,14 @@
         "policy_extras": {},
     }
     if domain_parameters is not None:
-        extras["domain_parameters"] = domain_parameters[0]
+        extras["state_extras"]["domain_parameters"] = domain_parameters[0]  # type: ignore
     if safe:
-        if propagation is not None and cost_penalty is not None:
-            extras["imagined_cost"] = 0.0
-        else:
-            extras["cost"] = 0.0
+        if propagation is not None:
+            extras["state_extras"]["state_propagation"] = {  # type: ignore
+                "next_obs": jnp.tile(dummy_obs, (num_envs,) + (1,) * dummy_obs.ndim),
+                "rng": rng,
+            }
+        extras["state_extras"]["cost"] = 0.0  # type: ignore
 
     dummy_transition = Transition(  # pytype: disable=wrong-arg-types  # jax-ndarray
         observation=dummy_obs,
@@ -372,8 +368,8 @@ def sgd_step(
             alpha,
             transitions,
             key_critic,
-            False,
             True,
+            SACCost(),
             optimizer_state=training_state.qc_optimizer_state,
         )
         cost_metrics = {
@@ -457,9 +453,13 @@ def get_experience(
         ReplayBufferState,
     ]:
         policy = make_policy((normalizer_params, policy_params))
-        extra_fields = ("truncation",) + tuple(
-            key for key in extras.keys() if key not in ["state_extras", "policy_extras"]
-        )
+        extra_fields = ("truncation",)
+        if domain_parameters is not None:
+            extra_fields += ("domain_parameters",)  # type: ignore
+        if propagation is not None:
+            extra_fields += ("state_propagation",)  # type: ignore
+        if safe:
+            extra_fields += ("cost",)  # type: ignore
         step = lambda state: acting.actor_step(
            env, state, policy, key, extra_fields=extra_fields
        )
@@ -469,10 +469,7 @@ def get_experience(
            normalizer_params, transitions.observation, pmap_axis_name=_PMAP_AXIS_NAME
        )
        if transitions.observation.ndim == 3:
-            transitions = jax.tree_util.tree_map(
-                lambda x: x[0],
-                transitions,
-            )
+            transitions = jax.tree_util.tree_map(lambda x: x[0], transitions)
        buffer_state = replay_buffer.insert(buffer_state, transitions)
        return normalizer_params, env_state, buffer_state
 
```
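After these changes, everything the actor step records lives under `state_extras` in the dummy `extras` pytree used to size the replay buffer. A rough sketch of its layout inside `train()` for a safe run with both domain randomization and state propagation; field names are taken from the diff, while the `truncation` entry and exact shapes are assumptions carried over from the upstream brax SAC trainer:

```python
# Sketch only; dtypes and shapes depend on the environment and num_envs.
extras = {
    "state_extras": {
        "truncation": 0.0,  # assumed, as in the upstream brax SAC trainer
        "domain_parameters": domain_parameters[0],
        "state_propagation": {
            "next_obs": jnp.tile(dummy_obs, (num_envs,) + (1,) * dummy_obs.ndim),
            "rng": rng,
        },
        "cost": 0.0,
    },
    "policy_extras": {},
}
```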

ss2r/algorithms/sac/wrappers.py

Lines changed: 15 additions & 14 deletions
```diff
@@ -28,39 +28,40 @@ class StatePropagation(Wrapper):
     This wrapper assumes that the environment is wrapped before with a VmapWrapper or DomainRandomizationVmapWrapper
     """
 
-    def __init__(self, env, propagation_fn=ts1, cost_penalty_fn=None):
+    def __init__(self, env, propagation_fn=ts1):
         super().__init__(env)
-        self.cost_penalty_fn = cost_penalty_fn
         self.propagation_fn = propagation_fn
         self.num_envs = None
 
     def reset(self, rng: jax.Array) -> State:
         if self.num_envs is None:
             self.num_envs = rng.shape[0]
         state = self.env.reset(rng)
-        if "propagation_rng" in state.info:
-            propagation_rng = state.info["propagation_rng"]
-        else:
-            propagation_rng = jax.random.split(rng[0])[1]
+        propagation_rng = jax.random.split(rng[0])[1]
         n_key, key = jax.random.split(propagation_rng)
-        state.info["propagation_rng"] = jax.random.split(n_key, self.num_envs)
-        state.info["imagined_cost"] = jnp.zeros(self.num_envs)
-        return self.propagation_fn(state, key)
+        state.info["state_propagation"] = {}
+        state.info["state_propagation"]["rng"] = jax.random.split(n_key, self.num_envs)
+        orig_next_obs = state.obs
+        state = self.propagation_fn(state, key)
+        state.info["state_propagation"]["next_obs"] = orig_next_obs
+        return state
 
     def step(self, state: State, action: jax.Array) -> State:
         # The order here matters, the tree_map changes the dimensions of
         # the propgattion_rng
-        propagation_rng = state.info["propagation_rng"]
+        propagation_rng = state.info["state_propagation"]["rng"]
         tile = lambda tree: jax.tree_map(
             lambda x: jnp.tile(x, (self.num_envs,) + (1,) * x.ndim), tree
         )
         state, action = tile(state), tile(action)
         nstate = self.env.step(state, action)
         n_key, key = jax.random.split(propagation_rng)
-        nstate.info["propagation_rng"] = jax.random.split(n_key, self.num_envs)
-        if self.cost_penalty_fn is not None:
-            nstate.info["imagined_cost"] += self.cost_penalty_fn(nstate)
-        return self.propagation_fn(nstate, key)
+        orig_next_obs = nstate.obs
+        nstate.info["state_propagation"]["rng"] = jax.random.split(n_key, self.num_envs)
+        nstate.info["state_propagation"]["next_obs"] = nstate.obs
+        nstate = self.propagation_fn(nstate, key)
+        nstate.info["state_propagation"]["next_obs"] = orig_next_obs
+        return nstate
 
 
 def get_randomized_values(sys_v, in_axes):
```
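With the cost penalty removed from the wrapper, `StatePropagation` now only tracks the propagation RNG and the pre-propagation observations under `state.info["state_propagation"]`. A minimal usage sketch, assuming an env already wrapped as the docstring requires; `vmapped_env` is a placeholder, not a name from this repository:

```python
import jax
from ss2r.algorithms.sac.wrappers import StatePropagation

num_envs = 8
# vmapped_env stands in for an env wrapped with a VmapWrapper or
# DomainRandomizationVmapWrapper, as the class docstring requires.
env = StatePropagation(vmapped_env)
rng = jax.random.split(jax.random.PRNGKey(0), num_envs)
state = env.reset(rng)
# Per-randomization observations recorded before propagation_fn collapses them;
# this is the particle dimension later consumed by the LCB/CVaR transforms.
next_obs = state.info["state_propagation"]["next_obs"]
```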

ss2r/configs/agent/sac.yaml

Lines changed: 3 additions & 1 deletion
```diff
@@ -19,4 +19,6 @@ cost_penalty: null
 propagation: standard
 lagrange_multiplier: 0.0001
 penalty_multiplier: 5e-8
-penalty_multiplier_factor: 8e-6
+penalty_multiplier_factor: 8e-6
+cost_q_transform: cvar
+cvar_confidence: 0.95
```
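The new config keys mirror the `cost_q_transform` and `cvar_confidence` arguments added to `train()`. Note that the displayed train.py hunk still passes `SACCost()` directly to the cost-critic update, so how the string is routed to `CVaR` or `LCB` is not visible in this diff; the following is a purely illustrative sketch of such a selection, with a hypothetical helper name and the assumption that `cost_penalty` would serve as the LCB weight:

```python
# Hypothetical wiring, not shown in this commit's hunks.
from ss2r.algorithms.sac.robustness import CVaR, LCB, SACCost


def make_cost_q_transform(cost_q_transform, cvar_confidence, cost_penalty=None):
    if cost_q_transform == "cvar":
        return CVaR(confidence=cvar_confidence)
    if cost_q_transform == "lcb":
        return LCB(lambda_=cost_penalty)
    return SACCost()
```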
