 )
 from ss2r.algorithms.mbpo.training_step import make_training_step
 from ss2r.algorithms.mbpo.types import TrainingState
+from ss2r.algorithms.penalizers import Params, Penalizer
 from ss2r.algorithms.sac import gradients
 from ss2r.algorithms.sac.data import collect_single_step
 from ss2r.algorithms.sac.q_transforms import QTransformation, SACBase, SACCost
@@ -87,6 +88,7 @@ def _init_training_state(
     qc_optimizer: optax.GradientTransformation,
     model_optimizer: optax.GradientTransformation,
     model_ensemble_size: int,
+    penalizer_params: Params | None,
 ) -> TrainingState:
     """Inits the training state and replicates it over devices."""
     key_policy, key_qr, key_model = jax.random.split(key, 3)
@@ -101,16 +103,14 @@ def _init_training_state(
     model_params = init_model_ensemble(model_keys)
     model_optimizer_state = model_optimizer.init(model_params)
     if mbpo_network.qc_network is not None:
-        qc_params = mbpo_network.qc_network.init(key_qr)
+        backup_qc_params = mbpo_network.qc_network.init(key_qr)
         assert qc_optimizer is not None
-        qc_optimizer_state = qc_optimizer.init(qc_params)
+        backup_qc_optimizer_state = qc_optimizer.init(backup_qc_params)
         backup_qr_params = qr_params
-        backup_qr_optimizer_state = qr_optimizer_state
     else:
-        qc_params = None
-        qc_optimizer_state = None
+        backup_qc_params = None
+        backup_qc_optimizer_state = None
         backup_qr_params = None
-        backup_qr_optimizer_state = None
     if isinstance(obs_size, Mapping):
         obs_shape = {
             k: specs.Array(v, jnp.dtype("float32")) for k, v in obs_size.items()
@@ -119,24 +119,27 @@ def _init_training_state(
         obs_shape = specs.Array((obs_size,), jnp.dtype("float32"))
     normalizer_params = running_statistics.init_state(obs_shape)
     training_state = TrainingState(
-        policy_optimizer_state=policy_optimizer_state,
-        policy_params=policy_params,
+        behavior_policy_optimizer_state=policy_optimizer_state,
+        behavior_policy_params=policy_params,
         backup_policy_params=policy_params,
-        qr_optimizer_state=qr_optimizer_state,
-        qr_params=qr_params,
-        backup_qr_optimizer_state=backup_qr_optimizer_state,
+        behavior_qr_optimizer_state=qr_optimizer_state,
+        behavior_qr_params=qr_params,
         backup_qr_params=backup_qr_params,
-        qc_optimizer_state=qc_optimizer_state,
-        qc_params=qc_params,
-        target_qr_params=qr_params,
-        target_qc_params=qc_params,
+        behavior_qc_optimizer_state=backup_qc_optimizer_state,
+        behavior_qc_params=backup_qc_params,
+        behavior_target_qr_params=qr_params,
+        behavior_target_qc_params=backup_qc_params,
+        backup_qc_params=backup_qc_params,
+        backup_qc_optimizer_state=backup_qc_optimizer_state,
+        backup_target_qc_params=backup_qc_params,
         model_params=model_params,
         model_optimizer_state=model_optimizer_state,
         gradient_steps=jnp.zeros(()),
         env_steps=jnp.zeros(()),
         alpha_optimizer_state=alpha_optimizer_state,
         alpha_params=log_alpha,
         normalizer_params=normalizer_params,
+        penalizer_params=penalizer_params,
     )  # type: ignore
     return training_state

@@ -188,6 +191,8 @@ def train(
     eval_env: Optional[envs.Env] = None,
     safe: bool = False,
     safety_budget: float = float("inf"),
+    penalizer: Penalizer | None = None,
+    penalizer_params: Params | None = None,
     reward_q_transform: QTransformation = SACBase(),
     cost_q_transform: QTransformation = SACCost(),
     use_bro: bool = True,
@@ -302,6 +307,7 @@ def train(
         qc_optimizer=qc_optimizer,
         model_optimizer=model_optimizer,
         model_ensemble_size=model_ensemble_size,
+        penalizer_params=penalizer_params,
     )
     del global_key
     local_key, model_rb_key, actor_critic_rb_key, env_key, eval_key = jax.random.split(
@@ -318,13 +324,13 @@ def train(
             ts_normalizer_params = params[0]
         training_state = training_state.replace(  # type: ignore
             normalizer_params=ts_normalizer_params,
-            policy_params=params[1],
+            behavior_policy_params=params[1],
             backup_policy_params=params[1],
-            qr_params=params[3],
+            behavior_qr_params=params[3],
             backup_qr_params=params[3],
-            qc_params=params[4] if safe else None,
+            behavior_qc_params=params[4] if safe else None,
+            backup_qc_params=params[4] if safe else None,
         )
-
     make_planning_policy = mbpo_networks.make_inference_fn(mbpo_network)
     if safe:
         make_rollout_policy = make_safe_inference_fn(
@@ -386,7 +392,7 @@ def train(
     )
     actor_update = (
         gradients.gradient_update_fn(  # pytype: disable=wrong-arg-types  # jax-ndarray
-            actor_loss, policy_optimizer, pmap_axis_name=None
+            actor_loss, policy_optimizer, pmap_axis_name=None, has_aux=True
         )
     )
     extra_fields = ("truncation",)
@@ -402,7 +408,7 @@ def train(
         safety_budget=safety_budget,
         cost_discount=safety_discounting,
         scaling_fn=budget_scaling_fun,
-        use_termination=use_termination,
+        use_termination=penalizer is not None and use_termination,
     )
     training_step = make_training_step(
         env,
@@ -434,7 +440,9 @@ def train(
         pessimism,
         model_to_real_data_ratio,
         budget_scaling_fun,
-        use_termination=use_termination,
+        use_termination,
+        penalizer,
+        safety_budget,
     )

     def prefill_replay_buffer(
@@ -635,9 +643,9 @@ def training_epoch_with_timing(
             # Save current policy.
             params = (
                 training_state.normalizer_params,
-                training_state.policy_params,
-                training_state.qr_params,
-                training_state.qc_params,
+                training_state.behavior_policy_params,
+                training_state.behavior_qr_params,
+                training_state.backup_qc_params,
                 training_state.model_params,
             )
             if store_buffer:
@@ -660,9 +668,9 @@ def training_epoch_with_timing(
     assert total_steps >= num_timesteps
     params = (
         training_state.normalizer_params,
-        training_state.policy_params,
-        training_state.qr_params,
-        training_state.qc_params,
+        training_state.behavior_policy_params,
+        training_state.behavior_qr_params,
+        training_state.backup_qc_params,
         training_state.model_params,
     )
     if store_buffer: