diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py
index cc713fb56b2..f68a66c8fa0 100644
--- a/torchrl/modules/models/model_based.py
+++ b/torchrl/modules/models/model_based.py
@@ -18,12 +18,25 @@
 # from torchrl.modules.tensordict_module.rnn import GRUCell
 from torch.nn import GRUCell
+from torchrl._utils import _maybe_record_function_decorator
 from torchrl.modules.models.models import MLP
 
 UNSQUEEZE_RNN_INPUT = version.parse(torch.__version__) < version.parse("1.11")
 
 
+class _Contiguous(nn.Module):
+    """Helper module that makes a tensor contiguous.
+
+    This is useful inside nn.Sequential for torch.compile inductor compatibility:
+    inductor sometimes needs explicit contiguous() calls after reshape operations
+    to get an efficient memory layout.
+    """
+
+    def forward(self, x):
+        return x.contiguous()
+
+
 class DreamerActor(nn.Module):
     """Dreamer actor network.
 
@@ -129,16 +142,18 @@ def __init__(
             k = k * 2
         self.encoder = nn.Sequential(*layers)
 
+    @_maybe_record_function_decorator("ObsEncoder.forward")
     def forward(self, observation):
         *batch_sizes, C, H, W = observation.shape
-        if len(batch_sizes) == 0:
-            end_dim = 0
-        else:
-            end_dim = len(batch_sizes) - 1
-        observation = torch.flatten(observation, start_dim=0, end_dim=end_dim)
+        # Flatten all batch dimensions into one for the convs;
+        # contiguous() is needed for inductor compatibility
+        observation = observation.flatten(
+            0, len(batch_sizes) - 1 if batch_sizes else 0
+        ).contiguous()
         obs_encoded = self.encoder(observation)
-        latent = obs_encoded.reshape(*batch_sizes, -1)
-        return latent
+        # Reshape back to the original batch dims with a flat latent dim
+        latent = obs_encoded.reshape(*batch_sizes, -1)
+        return latent.contiguous()
 
 
 class ObsDecoder(nn.Module):
@@ -232,14 +247,25 @@ def __init__(
         self.decoder = nn.Sequential(*layers)
         self._depth = channels
 
+    @_maybe_record_function_decorator("ObsDecoder.forward")
     def forward(self, state, rnn_hidden):
+        # Concatenate state and hidden state, then project to latent space
         latent = self.state_to_latent(torch.cat([state, rnn_hidden], dim=-1))
         *batch_sizes, D = latent.shape
-        latent = latent.view(-1, D, 1, 1)
+        # Flatten batch dims for the convs (add a batch dim of 1 when unbatched)
+        if batch_sizes:
+            latent = latent.flatten(0, len(batch_sizes) - 1)
+        else:
+            latent = latent.unsqueeze(0)
+        latent = latent.unsqueeze(-1).unsqueeze(-1).contiguous()
         obs_decoded = self.decoder(latent)
         _, C, H, W = obs_decoded.shape
-        obs_decoded = obs_decoded.view(*batch_sizes, C, H, W)
-        return obs_decoded
+        # Restore original batch dims (drop the added batch dim when unbatched)
+        if batch_sizes:
+            obs_decoded = obs_decoded.unflatten(0, batch_sizes)
+        else:
+            obs_decoded = obs_decoded.squeeze(0)
+        return obs_decoded.contiguous()
 
 
 class RSSMRollout(TensorDictModuleBase):
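
The patch defines _Contiguous without showing a call site in this excerpt, and the batch-dim handling is split across two forward methods. Below is a minimal, self-contained sketch of both patterns; the conv_over_batch_dims helper, the layer sizes, and the 64x64 input are illustrative assumptions, not code from torchrl.

import torch
from torch import nn


class _Contiguous(nn.Module):
    """Force a contiguous layout (same role as the module in the patch)."""

    def forward(self, x):
        return x.contiguous()


def conv_over_batch_dims(module, x):
    """Collapse leading batch dims, run a conv stack, restore the dims.

    Mirrors the flatten/unflatten pattern in ObsEncoder/ObsDecoder: 2D convs
    accept a single batch dim, so a (B, T, C, H, W) input is processed as
    (B * T, C, H, W) and reshaped back afterwards.
    """
    *batch_sizes, C, H, W = x.shape
    x = x.flatten(0, len(batch_sizes) - 1) if batch_sizes else x.unsqueeze(0)
    y = module(x.contiguous())
    return y.unflatten(0, batch_sizes) if batch_sizes else y.squeeze(0)


# Illustrative encoder: channel sizes and input shape are assumptions.
encoder = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=4, stride=2),
    nn.ReLU(),
)
out = conv_over_batch_dims(encoder, torch.randn(8, 10, 3, 64, 64))
print(out.shape)  # torch.Size([8, 10, 32, 31, 31])

# _Contiguous slots into nn.Sequential after reshape-like layers so that
# torch.compile's inductor sees an explicit contiguous() call:
head = nn.Sequential(nn.Flatten(start_dim=1), _Contiguous(), nn.LazyLinear(256))

Collapsing the leading batch dims into one keeps the conv stack oblivious to however many batch dimensions the input carries, and the explicit contiguous() hands inductor a clean memory layout to fuse against.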