
Commit 7927cd4

Add penalizer to mbpo (#131)
* Training of behavior qc
* Rename to backup and behavior
* Don't terminate or stop rewards if penalizer is on
1 parent 330d510 commit 7927cd4

10 files changed (+173, -80 lines)
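
The ss2r.algorithms.penalizers module itself is not touched by this commit; only its call sites appear below. From those call sites (the actor loss in losses.py and the plumbing in train.py), the penalizer is a callable of the form (loss, constraint, params) -> (penalized_loss, aux, new_params). The block below is a minimal Lagrangian-style sketch of that interface, written only to make the diffs easier to read; LagrangianParams, lagrangian_penalizer, and multiplier_lr are illustrative names, not the repository's implementation.

```python
from typing import Any, NamedTuple

import jax.numpy as jnp


class LagrangianParams(NamedTuple):
    log_multiplier: jnp.ndarray  # log of the Lagrange multiplier


def lagrangian_penalizer(
    loss: jnp.ndarray,
    constraint: jnp.ndarray,
    params: LagrangianParams,
    multiplier_lr: float = 1e-2,
) -> tuple[jnp.ndarray, dict[str, Any], LagrangianParams]:
    """Penalize the loss when the constraint (budget - cost) is violated."""
    multiplier = jnp.exp(params.log_multiplier)
    # constraint > 0 means the policy is within budget; < 0 means the budget is exceeded.
    penalized_loss = loss - multiplier * constraint
    # Dual ascent on the multiplier: it grows while the constraint is violated.
    new_log_multiplier = params.log_multiplier - multiplier_lr * constraint
    aux = {"lagrange_multiplier": multiplier}
    return penalized_loss, aux, LagrangianParams(new_log_multiplier)
```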

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -100,7 +100,6 @@ skip-magic-trailing-comma = false
 line-ending = "auto"
 
 [tool.mypy]
-plugins = ["numpy.typing.mypy_plugin"]
 ignore_missing_imports = true
 show_column_numbers = true
 disallow_untyped_defs = false

ss2r/algorithms/mbpo/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,6 +1,7 @@
 import functools
 
 import ss2r.algorithms.mbpo.networks as mbpo_networks
+from ss2r.algorithms.penalizers import get_penalizer
 from ss2r.algorithms.sac.data import get_collection_fn
 from ss2r.algorithms.sac.q_transforms import (
     get_cost_q_transform,
@@ -56,6 +57,7 @@ def get_train_fn(cfg, checkpoint_path, restore_checkpoint_path):
         value_obs_key=value_obs_key,
         policy_obs_key=policy_obs_key,
     )
+    penalizer, penalizer_params = get_penalizer(cfg)
     reward_q_transform = get_reward_q_transform(cfg)
     cost_q_transform = get_cost_q_transform(cfg)
     data_collection = get_collection_fn(cfg)
@@ -67,6 +69,8 @@ def get_train_fn(cfg, checkpoint_path, restore_checkpoint_path):
         checkpoint_logdir=checkpoint_path,
         reward_q_transform=reward_q_transform,
         cost_q_transform=cost_q_transform,
+        penalizer=penalizer,
+        penalizer_params=penalizer_params,
         get_experience_fn=data_collection,
         restore_checkpoint_path=restore_checkpoint_path,
     )
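
A hedged sketch of what the get_penalizer(cfg) factory used above might return, matching the (penalizer, penalizer_params) pair that get_train_fn forwards into train. The cfg.agent.penalizer field and the "lagrangian" option are assumptions for illustration, and the sketch reuses the LagrangianParams / lagrangian_penalizer stand-ins from the earlier block.

```python
import jax.numpy as jnp


def get_penalizer(cfg):
    """Return (penalizer, penalizer_params), or (None, None) when disabled.

    cfg.agent.penalizer is an assumed config field, not taken from the repo.
    """
    name = getattr(cfg.agent, "penalizer", None)
    if name is None:
        return None, None
    if name == "lagrangian":
        # Reuses the illustrative LagrangianParams / lagrangian_penalizer above.
        return lagrangian_penalizer, LagrangianParams(log_multiplier=jnp.zeros(()))
    raise ValueError(f"Unknown penalizer: {name}")
```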

ss2r/algorithms/mbpo/losses.py

Lines changed: 22 additions & 1 deletion
@@ -26,6 +26,7 @@
 from brax.training.types import Params, PRNGKey
 
 from ss2r.algorithms.mbpo.networks import MBPONetworks
+from ss2r.algorithms.penalizers import Penalizer
 from ss2r.algorithms.sac.q_transforms import QTransformation
 
 Transition: TypeAlias = types.Transition
@@ -45,6 +46,7 @@ def make_losses(
     target_entropy = -0.5 * action_size if init_alpha is None else init_alpha
     policy_network = mbpo_network.policy_network
     qr_network = mbpo_network.qr_network
+    qc_network = mbpo_network.qc_network
     parametric_action_distribution = mbpo_network.parametric_action_distribution
 
     def alpha_loss(
@@ -122,9 +124,13 @@ def actor_loss(
         policy_params: Params,
         normalizer_params: Any,
         qr_params: Params,
+        qc_params: Params | None,
         alpha: jnp.ndarray,
         transitions: Transition,
         key: PRNGKey,
+        safety_budget: float,
+        penalizer: Penalizer | None,
+        penalizer_params: Any,
     ) -> jnp.ndarray:
         dist_params = policy_network.apply(
             normalizer_params, policy_params, transitions.observation
@@ -143,8 +149,23 @@ def actor_loss(
         qr = jnp.min(qr_action, axis=-1)
         actor_loss = -qr.mean()
         exploration_loss = (alpha * log_prob).mean()
+        aux = {}
+        if penalizer is not None:
+            assert qc_network is not None
+            qc_action = qc_network.apply(
+                normalizer_params, qc_params, transitions.observation, action
+            )
+            mean_qc = jnp.mean(qc_action, axis=-1)
+            constraint = safety_budget - mean_qc.mean() / cost_scaling
+            actor_loss, penalizer_aux, penalizer_params = penalizer(
+                actor_loss, constraint, jax.lax.stop_gradient(penalizer_params)
+            )
+            aux["constraint_estimate"] = constraint
+            aux["cost"] = mean_qc.mean() / cost_scaling
+            aux["penalizer_params"] = penalizer_params
+            aux |= penalizer_aux
         actor_loss += exploration_loss
-        return actor_loss
+        return actor_loss, aux
 
     def compute_model_loss(model_params, normalizer_params, data, obs_key="state"):
         model_apply = jax.vmap(mbpo_network.model_network.apply, (None, 0, None, None))
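
To read the new branch in actor_loss in isolation, the same pattern is pulled out into a standalone function below. cost_scaling is a closure variable of make_losses that does not appear in this hunk, so it is passed explicitly here, and qc_values stands in for the output of qc_network.apply(...); the rest mirrors the diff. Because actor_loss now returns (loss, aux), the actor gradient update in train.py switches to has_aux=True (see that file below).

```python
import jax
import jax.numpy as jnp


def penalized_actor_loss(
    actor_loss, qc_values, safety_budget, cost_scaling, penalizer, penalizer_params
):
    """Apply a safety penalizer to the SAC-style actor loss."""
    aux = {}
    if penalizer is not None:
        mean_qc = jnp.mean(qc_values, axis=-1)  # average over the Qc ensemble
        cost = mean_qc.mean() / cost_scaling    # estimated (scaled) episode cost
        constraint = safety_budget - cost       # > 0 means within budget
        # The penalizer updates its own state (e.g. a Lagrange multiplier), but the
        # actor gradient does not flow through that state.
        actor_loss, penalizer_aux, penalizer_params = penalizer(
            actor_loss, constraint, jax.lax.stop_gradient(penalizer_params)
        )
        aux["constraint_estimate"] = constraint
        aux["cost"] = cost
        aux["penalizer_params"] = penalizer_params
        aux |= penalizer_aux
    return actor_loss, aux
```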

ss2r/algorithms/mbpo/model_env.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ def __init__(
         self.model_network = mbpo_network.model_network
         self.model_params = training_state.model_params
         self.qc_network = mbpo_network.qc_network
-        self.qc_params = training_state.qc_params
+        self.backup_qc_params = training_state.backup_qc_params
         self.qr_network = mbpo_network.qr_network
         self.backup_qr_params = training_state.backup_qr_params
         self.policy_network = mbpo_network.policy_network
@@ -86,7 +86,7 @@ def step(self, state: base.State, action: jax.Array) -> base.State:
         expected_cost_for_traj = prev_cumulative_cost + self.scaling_fn(
             self.qc_network.apply(
                 self.normalizer_params,
-                self.qc_params,
+                self.backup_qc_params,
                 state.obs,
                 action,
             ).mean(axis=-1)
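
The only change in the model environment is which critic scores the candidate action: the rollout cost estimate now uses the backup safety critic instead of the behavior one. A minimal sketch of that estimate, with qc_apply standing in for self.qc_network.apply and scaling_fn for self.scaling_fn (both taken as given, not reimplemented):

```python
def expected_trajectory_cost(
    prev_cumulative_cost, qc_apply, normalizer_params, backup_qc_params, obs, action, scaling_fn
):
    """Cost accumulated so far plus the backup critic's (scaled) estimate of cost-to-go."""
    qc_ensemble = qc_apply(normalizer_params, backup_qc_params, obs, action)
    return prev_cumulative_cost + scaling_fn(qc_ensemble.mean(axis=-1))
```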

ss2r/algorithms/mbpo/safe_rollout.py

Lines changed: 3 additions & 3 deletions
@@ -13,12 +13,12 @@ def get_inference_policy_params(safe: bool, safety_budget=float("inf")) -> Any:
     def get_params(training_state: TrainingState) -> Any:
         if safe:
             return (
-                training_state.policy_params,
-                training_state.qc_params,
+                training_state.behavior_policy_params,
+                training_state.backup_qc_params,
                 safety_budget,
             )
         else:
-            return training_state.policy_params
+            return training_state.behavior_policy_params
 
     return get_params
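
After the rename, the safe inference path pairs the behavior policy with the backup safety critic. A small usage sketch of the selector above; TrainingStateStub is a stand-in for ss2r.algorithms.mbpo.types.TrainingState and the budget value is a placeholder.

```python
from typing import Any, NamedTuple

from ss2r.algorithms.mbpo.safe_rollout import get_inference_policy_params


class TrainingStateStub(NamedTuple):
    """Stand-in exposing only the fields the selector touches."""
    behavior_policy_params: Any
    backup_qc_params: Any


training_state = TrainingStateStub(behavior_policy_params={}, backup_qc_params={})

get_params = get_inference_policy_params(safe=True, safety_budget=100.0)
policy_params, qc_params, budget = get_params(training_state)
# The behavior policy proposes actions; the backup Qc and the budget are what
# the safe inference wrapper uses to filter them.
```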

ss2r/algorithms/mbpo/train.py

Lines changed: 36 additions & 28 deletions
@@ -41,6 +41,7 @@
 )
 from ss2r.algorithms.mbpo.training_step import make_training_step
 from ss2r.algorithms.mbpo.types import TrainingState
+from ss2r.algorithms.penalizers import Params, Penalizer
 from ss2r.algorithms.sac import gradients
 from ss2r.algorithms.sac.data import collect_single_step
 from ss2r.algorithms.sac.q_transforms import QTransformation, SACBase, SACCost
@@ -87,6 +88,7 @@ def _init_training_state(
     qc_optimizer: optax.GradientTransformation,
     model_optimizer: optax.GradientTransformation,
     model_ensemble_size: int,
+    penalizer_params: Params | None,
 ) -> TrainingState:
     """Inits the training state and replicates it over devices."""
     key_policy, key_qr, key_model = jax.random.split(key, 3)
@@ -101,16 +103,14 @@
     model_params = init_model_ensemble(model_keys)
     model_optimizer_state = model_optimizer.init(model_params)
     if mbpo_network.qc_network is not None:
-        qc_params = mbpo_network.qc_network.init(key_qr)
+        backup_qc_params = mbpo_network.qc_network.init(key_qr)
         assert qc_optimizer is not None
-        qc_optimizer_state = qc_optimizer.init(qc_params)
+        backup_qc_optimizer_state = qc_optimizer.init(backup_qc_params)
         backup_qr_params = qr_params
-        backup_qr_optimizer_state = qr_optimizer_state
     else:
-        qc_params = None
-        qc_optimizer_state = None
+        backup_qc_params = None
+        backup_qc_optimizer_state = None
         backup_qr_params = None
-        backup_qr_optimizer_state = None
     if isinstance(obs_size, Mapping):
         obs_shape = {
             k: specs.Array(v, jnp.dtype("float32")) for k, v in obs_size.items()
@@ -119,24 +119,27 @@
         obs_shape = specs.Array((obs_size,), jnp.dtype("float32"))
     normalizer_params = running_statistics.init_state(obs_shape)
     training_state = TrainingState(
-        policy_optimizer_state=policy_optimizer_state,
-        policy_params=policy_params,
+        behavior_policy_optimizer_state=policy_optimizer_state,
+        behavior_policy_params=policy_params,
         backup_policy_params=policy_params,
-        qr_optimizer_state=qr_optimizer_state,
-        qr_params=qr_params,
-        backup_qr_optimizer_state=backup_qr_optimizer_state,
+        behavior_qr_optimizer_state=qr_optimizer_state,
+        behavior_qr_params=qr_params,
         backup_qr_params=backup_qr_params,
-        qc_optimizer_state=qc_optimizer_state,
-        qc_params=qc_params,
-        target_qr_params=qr_params,
-        target_qc_params=qc_params,
+        behavior_qc_optimizer_state=backup_qc_optimizer_state,
+        behavior_qc_params=backup_qc_params,
+        behavior_target_qr_params=qr_params,
+        behavior_target_qc_params=backup_qc_params,
+        backup_qc_params=backup_qc_params,
+        backup_qc_optimizer_state=backup_qc_optimizer_state,
+        backup_target_qc_params=backup_qc_params,
         model_params=model_params,
         model_optimizer_state=model_optimizer_state,
         gradient_steps=jnp.zeros(()),
         env_steps=jnp.zeros(()),
         alpha_optimizer_state=alpha_optimizer_state,
         alpha_params=log_alpha,
         normalizer_params=normalizer_params,
+        penalizer_params=penalizer_params,
     )  # type: ignore
     return training_state
 
@@ -188,6 +191,8 @@ def train(
     eval_env: Optional[envs.Env] = None,
     safe: bool = False,
     safety_budget: float = float("inf"),
+    penalizer: Penalizer | None = None,
+    penalizer_params: Params | None = None,
     reward_q_transform: QTransformation = SACBase(),
     cost_q_transform: QTransformation = SACCost(),
     use_bro: bool = True,
@@ -302,6 +307,7 @@ def train(
         qc_optimizer=qc_optimizer,
         model_optimizer=model_optimizer,
         model_ensemble_size=model_ensemble_size,
+        penalizer_params=penalizer_params,
     )
     del global_key
     local_key, model_rb_key, actor_critic_rb_key, env_key, eval_key = jax.random.split(
@@ -318,13 +324,13 @@
     ts_normalizer_params = params[0]
     training_state = training_state.replace(  # type: ignore
         normalizer_params=ts_normalizer_params,
-        policy_params=params[1],
+        behavior_policy_params=params[1],
         backup_policy_params=params[1],
-        qr_params=params[3],
+        behavior_qr_params=params[3],
         backup_qr_params=params[3],
-        qc_params=params[4] if safe else None,
+        behavior_qc_params=params[4] if safe else None,
+        backup_qc_params=params[4] if safe else None,
     )
-
     make_planning_policy = mbpo_networks.make_inference_fn(mbpo_network)
     if safe:
         make_rollout_policy = make_safe_inference_fn(
@@ -386,7 +392,7 @@
     )
     actor_update = (
         gradients.gradient_update_fn(  # pytype: disable=wrong-arg-types  # jax-ndarray
-            actor_loss, policy_optimizer, pmap_axis_name=None
+            actor_loss, policy_optimizer, pmap_axis_name=None, has_aux=True
        )
     )
     extra_fields = ("truncation",)
@@ -402,7 +408,7 @@
         safety_budget=safety_budget,
         cost_discount=safety_discounting,
         scaling_fn=budget_scaling_fun,
-        use_termination=use_termination,
+        use_termination=penalizer is not None and use_termination,
     )
     training_step = make_training_step(
         env,
@@ -434,7 +440,9 @@
         pessimism,
         model_to_real_data_ratio,
         budget_scaling_fun,
-        use_termination=use_termination,
+        use_termination,
+        penalizer,
+        safety_budget,
     )
 
     def prefill_replay_buffer(
@@ -635,9 +643,9 @@ def training_epoch_with_timing(
             # Save current policy.
            params = (
                 training_state.normalizer_params,
-                training_state.policy_params,
-                training_state.qr_params,
-                training_state.qc_params,
+                training_state.behavior_policy_params,
+                training_state.behavior_qr_params,
+                training_state.backup_qc_params,
                 training_state.model_params,
             )
             if store_buffer:
@@ -660,9 +668,9 @@ def training_epoch_with_timing(
     assert total_steps >= num_timesteps
     params = (
         training_state.normalizer_params,
-        training_state.policy_params,
-        training_state.qr_params,
-        training_state.qc_params,
+        training_state.behavior_policy_params,
+        training_state.behavior_qr_params,
+        training_state.backup_qc_params,
         training_state.model_params,
     )
     if store_buffer:
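
Finally, a hedged sketch of how the new arguments could be threaded into train from the call site in ss2r/algorithms/mbpo/__init__.py. The config object and the budget value are placeholders; the facts taken from the diff are the two new keyword arguments and that, with has_aux=True, the actor update now returns the aux dictionary (which carries the updated penalizer_params) alongside the loss.

```python
import functools

from ss2r.algorithms.mbpo.train import train
from ss2r.algorithms.penalizers import get_penalizer

penalizer, penalizer_params = get_penalizer(cfg)  # cfg is a placeholder config object
train_fn = functools.partial(
    train,
    safe=True,
    safety_budget=100.0,                # placeholder budget
    penalizer=penalizer,                # callable: (loss, constraint, params) -> (loss, aux, params)
    penalizer_params=penalizer_params,  # initial penalizer state, e.g. a Lagrange multiplier
)
```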
