@@ -511,17 +511,28 @@ def _log_weight(
511511 # current log_prob of actions
512512 action = _maybe_get_or_select (tensordict , self .tensor_keys .action )
513513
514+ is_composite = None
515+ if all (key in tensordict for key in self .actor_network .dist_params_keys ):
516+ prev_dist = self .actor_network .build_dist_from_params (tensordict .detach ())
517+ kwargs , is_composite = _get_composite_kwargs (prev_dist )
518+ if is_composite :
519+ prev_log_prob = prev_dist .log_prob (tensordict , ** kwargs )
520+ else :
521+ prev_log_prob = prev_dist .log_prob (action , ** kwargs )
522+ print ('prev_log_prob' , prev_log_prob )
523+ else :
524+ try :
525+ prev_log_prob = _maybe_get_or_select (
526+ tensordict , self .tensor_keys .sample_log_prob
527+ )
528+ except KeyError as err :
529+ raise _make_lp_get_error (self .tensor_keys , tensordict , err )
530+
514531 with self .actor_network_params .to_module (
515532 self .actor_network
516533 ) if self .functional else contextlib .nullcontext ():
517- dist = self .actor_network .get_dist (tensordict )
534+ current_dist = self .actor_network .get_dist (tensordict )
518535
519- try :
520- prev_log_prob = _maybe_get_or_select (
521- tensordict , self .tensor_keys .sample_log_prob
522- )
523- except KeyError as err :
524- raise _make_lp_get_error (self .tensor_keys , tensordict , err )
525536
526537 if prev_log_prob .requires_grad :
527538 raise RuntimeError (
@@ -532,35 +543,12 @@ def _log_weight(
532543 raise RuntimeError (
533544 f"tensordict stored { self .tensor_keys .action } requires grad."
534545 )
535- if isinstance (dist , CompositeDistribution ):
536- is_composite = True
537- aggregate = dist .aggregate_probabilities
538- if aggregate is None :
539- aggregate = False
540- include_sum = dist .include_sum
541- if include_sum is None :
542- include_sum = False
543- kwargs = {
544- "inplace" : False ,
545- "aggregate_probabilities" : aggregate ,
546- "include_sum" : include_sum ,
547- }
548- else :
549- is_composite = False
550- kwargs = {}
551- if not is_composite :
552- log_prob = dist .log_prob (action )
546+ if isinstance (action , torch .Tensor ):
547+ log_prob = current_dist .log_prob (action )
553548 else :
554- log_prob : TensorDictBase = dist .log_prob (tensordict , ** kwargs )
555- if not is_tensor_collection (prev_log_prob ):
556- # this isn't great, in general multihead actions should have a composite log-prob too
557- warnings .warn (
558- "You are using a composite distribution, yet your log-probability is a tensor. "
559- "This usually happens whenever the CompositeDistribution has aggregate_probabilities=True "
560- "or include_sum=True. These options should be avoided: leaf log-probs should be written "
561- "independently and PPO will take care of the aggregation." ,
562- category = UserWarning ,
563- )
549+ if is_composite is None :
550+ kwargs , is_composite = _get_composite_kwargs (current_dist )
551+ log_prob : TensorDictBase = current_dist .log_prob (tensordict , ** kwargs )
564552 if (
565553 is_composite
566554 and not is_tensor_collection (prev_log_prob )
@@ -574,7 +562,7 @@ def _log_weight(
574562 if is_tensor_collection (kl_approx ):
575563 kl_approx = _sum_td_features (kl_approx )
576564
577- return log_weight , dist , kl_approx
565+ return log_weight , current_dist , kl_approx
578566
579567 def loss_critic (self , tensordict : TensorDictBase ) -> torch .Tensor :
580568 """Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``."""
@@ -650,6 +638,9 @@ def _cached_critic_network_params_detached(self):
650638 @dispatch
651639 def forward (self , tensordict : TensorDictBase ) -> TensorDictBase :
652640 tensordict = tensordict .clone (False )
641+
642+ log_weight , dist , kl_approx = self ._log_weight (tensordict )
643+
653644 advantage = tensordict .get (self .tensor_keys .advantage , None )
654645 if advantage is None :
655646 self .value_estimator (
@@ -663,7 +654,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
663654 scale = advantage .std ().clamp_min (1e-6 )
664655 advantage = (advantage - loc ) / scale
665656
666- log_weight , dist , kl_approx = self ._log_weight (tensordict )
667657 if is_tensor_collection (log_weight ):
668658 log_weight = _sum_td_features (log_weight )
669659 log_weight = log_weight .view (advantage .shape )
@@ -1305,3 +1295,22 @@ def _make_lp_get_error(tensor_keys, log_prob, err):
13051295 return KeyError (result )
13061296 result += "This is usually due to a missing call to loss.set_keys(sample_log_prob=<list_of_log_prob_keys>)."
13071297 return KeyError (result )
1298+
def _get_composite_kwargs(current_dist):
    """Build the ``log_prob`` keyword arguments matching a distribution.

    Args:
        current_dist: the distribution whose log-probability is about to be
            evaluated.

    Returns:
        A ``(kwargs, is_composite)`` tuple. ``is_composite`` is ``True`` only
        when ``current_dist`` is a ``CompositeDistribution``. In that case
        ``kwargs`` holds ``inplace=False`` together with the distribution's
        own ``aggregate_probabilities`` and ``include_sum`` settings, each
        replaced by ``False`` when it is ``None``. For any other distribution
        ``kwargs`` is an empty dict.
    """
    # Guard clause: non-composite distributions need no extra arguments.
    if not isinstance(current_dist, CompositeDistribution):
        return {}, False

    # Read both settings in the same order as before; ``None`` means "unset"
    # and is normalized to ``False``.
    aggregate_setting = current_dist.aggregate_probabilities
    sum_setting = current_dist.include_sum

    log_prob_kwargs = {
        "inplace": False,
        "aggregate_probabilities": (
            False if aggregate_setting is None else aggregate_setting
        ),
        "include_sum": False if sum_setting is None else sum_setting,
    }
    return log_prob_kwargs, True
0 commit comments