Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions torchrl/modules/models/model_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,25 @@

# from torchrl.modules.tensordict_module.rnn import GRUCell
from torch.nn import GRUCell
from torchrl._utils import _maybe_record_function_decorator

from torchrl.modules.models.models import MLP

UNSQUEEZE_RNN_INPUT = version.parse(torch.__version__) < version.parse("1.11")


class _Contiguous(nn.Module):
"""Helper module that makes a tensor contiguous.

This is useful inside nn.Sequential for torch.compile inductor compatibility.
Inductor sometimes needs explicit contiguous() calls after reshape operations
for efficient memory layout.
"""

def forward(self, x):
return x.contiguous()


class DreamerActor(nn.Module):
"""Dreamer actor network.

Expand Down Expand Up @@ -129,16 +142,18 @@ def __init__(
k = k * 2
self.encoder = nn.Sequential(*layers)

@_maybe_record_function_decorator("ObsEncoder.forward")
def forward(self, observation):
*batch_sizes, C, H, W = observation.shape
if len(batch_sizes) == 0:
end_dim = 0
else:
end_dim = len(batch_sizes) - 1
observation = torch.flatten(observation, start_dim=0, end_dim=end_dim)
# Flatten all batch dimensions into one for conv
# Use contiguous() for inductor compatibility
observation = observation.flatten(
0, len(batch_sizes) - 1 if batch_sizes else 0
).contiguous()
obs_encoded = self.encoder(observation)
latent = obs_encoded.reshape(*batch_sizes, -1)
return latent
# Reshape back to original batch dims + latent
latent = obs_encoded.unflatten(0, batch_sizes) if batch_sizes else obs_encoded
return latent.reshape(*batch_sizes, -1).contiguous()


class ObsDecoder(nn.Module):
Expand Down Expand Up @@ -232,14 +247,25 @@ def __init__(
self.decoder = nn.Sequential(*layers)
self._depth = channels

@_maybe_record_function_decorator("ObsDecoder.forward")
def forward(self, state, rnn_hidden):
# Concatenate and project to latent space
latent = self.state_to_latent(torch.cat([state, rnn_hidden], dim=-1))
*batch_sizes, D = latent.shape
latent = latent.view(-1, D, 1, 1)
# Flatten batch dimensions and reshape for conv
latent = (
latent.flatten(0, len(batch_sizes) - 1 if batch_sizes else 0)
.unsqueeze(-1)
.unsqueeze(-1)
.contiguous()
)
obs_decoded = self.decoder(latent)
_, C, H, W = obs_decoded.shape
obs_decoded = obs_decoded.view(*batch_sizes, C, H, W)
return obs_decoded
# Unflatten back to original batch dims
obs_decoded = (
obs_decoded.unflatten(0, batch_sizes) if batch_sizes else obs_decoded
)
return obs_decoded.contiguous()


class RSSMRollout(TensorDictModuleBase):
Expand Down
Loading