diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py
index 6634f04c7cb..23a31b800a6 100644
--- a/torchrl/objectives/dreamer.py
+++ b/torchrl/objectives/dreamer.py
@@ -11,7 +11,7 @@
 from tensordict.nn import TensorDictModule
 from tensordict.utils import NestedKey
 
-from torchrl._utils import timeit
+from torchrl._utils import _maybe_record_function_decorator, _maybe_timeit
 from torchrl.envs.model_based.dreamer import DreamerEnv
 from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
 from torchrl.objectives.common import LossModule
@@ -123,38 +123,48 @@ def __init__(
     def _forward_value_estimator_keys(self, **kwargs) -> None:
         pass
 
+    @_maybe_record_function_decorator("world_model_loss/forward")
     def forward(self, tensordict: TensorDict) -> torch.Tensor:
-        tensordict = tensordict.clone(recurse=False)
+        tensordict = tensordict.copy()
         tensordict.rename_key_(
             ("next", self.tensor_keys.reward),
             ("next", self.tensor_keys.true_reward),
         )
+
         tensordict = self.world_model(tensordict)
-        # compute model loss
+
+        prior_mean = tensordict.get(("next", self.tensor_keys.prior_mean))
+        prior_std = tensordict.get(("next", self.tensor_keys.prior_std))
+        posterior_mean = tensordict.get(("next", self.tensor_keys.posterior_mean))
+        posterior_std = tensordict.get(("next", self.tensor_keys.posterior_std))
+
         kl_loss = self.kl_loss(
-            tensordict.get(("next", self.tensor_keys.prior_mean)),
-            tensordict.get(("next", self.tensor_keys.prior_std)),
-            tensordict.get(("next", self.tensor_keys.posterior_mean)),
-            tensordict.get(("next", self.tensor_keys.posterior_std)),
+            prior_mean, prior_std, posterior_mean, posterior_std,
         ).unsqueeze(-1)
+
+        # Ensure contiguous layout for torch.compile compatibility
+        # The gradient from distance_loss flows back through decoder convolutions
+        pixels = tensordict.get(("next", self.tensor_keys.pixels)).contiguous()
+        reco_pixels = tensordict.get(("next", self.tensor_keys.reco_pixels)).contiguous()
         reco_loss = distance_loss(
-            tensordict.get(("next", self.tensor_keys.pixels)),
-            tensordict.get(("next", self.tensor_keys.reco_pixels)),
+            pixels,
+            reco_pixels,
             self.reco_loss,
         )
         if not self.global_average:
             reco_loss = reco_loss.sum((-3, -2, -1))
         reco_loss = reco_loss.mean().unsqueeze(-1)
 
+        true_reward = tensordict.get(("next", self.tensor_keys.true_reward))
+        pred_reward = tensordict.get(("next", self.tensor_keys.reward))
         reward_loss = distance_loss(
-            tensordict.get(("next", self.tensor_keys.true_reward)),
-            tensordict.get(("next", self.tensor_keys.reward)),
+            true_reward,
+            pred_reward,
             self.reward_loss,
         )
         if not self.global_average:
             reward_loss = reward_loss.squeeze(-1)
         reward_loss = reward_loss.mean().unsqueeze(-1)
-        # import ipdb; ipdb.set_trace()
 
         td_out = TensorDict(
             loss_model_kl=self.lambda_kl * kl_loss,
@@ -162,10 +172,8 @@ def forward(self, tensordict: TensorDict) -> torch.Tensor:
             loss_model_reward=self.lambda_reward * reward_loss,
         )
         self._clear_weakrefs(tensordict, td_out)
-        return (
-            td_out,
-            tensordict.detach(),
-        )
+
+        return (td_out, tensordict.data)
 
     @staticmethod
     def normal_log_probability(x, mean, std):
@@ -275,10 +283,11 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
             value=self._tensor_keys.value,
         )
 
+    @_maybe_record_function_decorator("actor_loss/forward")
     def forward(self, tensordict: TensorDict) -> tuple[TensorDict, TensorDict]:
-        tensordict = tensordict.select("state", self.tensor_keys.belief).detach()
+        tensordict = tensordict.select("state", self.tensor_keys.belief).data
 
-        with timeit("actor_loss/time-rollout"), hold_out_net(
+        with _maybe_timeit("actor_loss/time-rollout"), hold_out_net(
             self.model_based_env
         ), set_exploration_type(ExplorationType.RANDOM):
             tensordict = self.model_based_env.reset(tensordict.copy())
@@ -288,7 +297,6 @@ def forward(self, tensordict: TensorDict) -> tuple[TensorDict, TensorDict]:
                 auto_reset=False,
                 tensordict=tensordict,
             )
-
             next_tensordict = step_mdp(fake_data, keep_other=True)
             with hold_out_net(self.value_model):
                 next_tensordict = self.value_model(next_tensordict)
@@ -308,7 +316,8 @@ def forward(self, tensordict: TensorDict) -> tuple[TensorDict, TensorDict]:
         actor_loss = -lambda_target.sum((-2, -1)).mean()
         loss_tensordict = TensorDict({"loss_actor": actor_loss}, [])
         self._clear_weakrefs(tensordict, loss_tensordict)
-        return loss_tensordict, fake_data.detach()
+
+        return loss_tensordict, fake_data.data
 
     def lambda_target(self, reward: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
         done = torch.zeros(reward.shape, dtype=torch.bool, device=reward.device)
@@ -420,14 +429,15 @@ def __init__(
     def _forward_value_estimator_keys(self, **kwargs) -> None:
         pass
 
+    @_maybe_record_function_decorator("value_loss/forward")
     def forward(self, fake_data) -> torch.Tensor:
         lambda_target = fake_data.get("lambda_target")
+
         tensordict_select = fake_data.select(*self.value_model.in_keys, strict=False)
         self.value_model(tensordict_select)
+
         if self.discount_loss:
-            discount = self.gamma * torch.ones_like(
-                lambda_target, device=lambda_target.device
-            )
+            discount = self.gamma * torch.ones_like(lambda_target, device=lambda_target.device)
             discount[..., 0, :] = 1
             discount = discount.cumprod(dim=-2)
             value_loss = (
@@ -452,6 +462,8 @@ def forward(self, fake_data) -> torch.Tensor:
                 .sum((-1, -2))
                 .mean()
             )
+
         loss_tensordict = TensorDict({"loss_value": value_loss})
         self._clear_weakrefs(fake_data, loss_tensordict)
+
         return loss_tensordict, fake_data
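
Side note, not part of the patch: _maybe_timeit and _maybe_record_function_decorator are only imported here from torchrl._utils; their definitions are outside this diff. The sketch below is an illustrative assumption of the usual pattern such "maybe" helpers follow (emit a named profiler range when a switch is on, fall back to a no-op otherwise). The PROFILING_ENABLED flag and the _sketch-suffixed names are hypothetical, not torchrl's actual API.

# Illustrative sketch only; the real torchrl._utils helpers may differ.
import functools
from contextlib import nullcontext

import torch

PROFILING_ENABLED = False  # assumed switch; torchrl uses its own mechanism


def _maybe_timeit_sketch(name):
    # Return a named profiler range when enabled, otherwise a no-op context
    # so the hot path is left untouched.
    if PROFILING_ENABLED:
        return torch.profiler.record_function(name)
    return nullcontext()


def _maybe_record_function_decorator_sketch(name):
    # Wrap a function body in the same optional profiler range.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            with _maybe_timeit_sketch(name):
                return fn(*args, **kwargs)

        return wrapper

    return decorator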