Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 109 additions & 16 deletions torchrl/modules/models/model_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,17 +236,55 @@ class RSSMRollout(TensorDictModuleBase):
Args:
rssm_prior (TensorDictModule): Prior network.
rssm_posterior (TensorDictModule): Posterior network.
use_scan (bool, optional): If True, uses torch._higher_order_ops.scan for
the rollout loop. This is more torch.compile friendly but may have
different performance characteristics. Defaults to False.
compile_step (bool, optional): If True, compiles the individual step function.
Only used when use_scan=False. Defaults to False.
compile_backend (str, optional): Backend to use for compilation.
Defaults to "inductor".
compile_mode (str, optional): Mode to use for compilation.
Defaults to None (uses PyTorch default).


"""

def __init__(self, rssm_prior: TensorDictModule, rssm_posterior: TensorDictModule):
def __init__(
    self,
    rssm_prior: TensorDictModule,
    rssm_posterior: TensorDictModule,
    use_scan: bool = False,
    compile_step: bool = False,
    compile_backend: str = "inductor",
    compile_mode: str | None = None,
):
    """Store the two RSSM networks and the rollout/compilation options.

    Args:
        rssm_prior (TensorDictModule): Prior network.
        rssm_posterior (TensorDictModule): Posterior network.
        use_scan (bool, optional): If True, ``forward`` dispatches to the
            scan-based rollout instead of the Python loop. Defaults to False.
        compile_step (bool, optional): If True, the single-step function is
            wrapped with ``torch.compile`` on first use. Defaults to False.
        compile_backend (str, optional): Backend passed to ``torch.compile``.
            Defaults to "inductor".
        compile_mode (str, optional): Mode passed to ``torch.compile``.
            Defaults to None.
    """
    super().__init__()

    # The sequential wrapper is built only to derive the combined key sets
    # of the prior -> posterior chain; it is not kept on the instance.
    chained = TensorDictSequential(rssm_prior, rssm_posterior)
    self.in_keys, self.out_keys = chained.in_keys, chained.out_keys

    # Sub-networks executed, in this order, at every rollout step.
    self.rssm_prior = rssm_prior
    self.rssm_posterior = rssm_posterior

    # Rollout strategy and compilation options.
    self.use_scan = use_scan
    self.compile_step = compile_step
    self.compile_backend = compile_backend
    self.compile_mode = compile_mode

    # Filled lazily by ``_get_step_fn`` the first time compilation is requested.
    self._compiled_step = None

def _get_step_fn(self):
    """Return the callable used for one rollout step.

    When ``compile_step`` is False the plain ``_step`` method is returned.
    Otherwise ``_step`` is wrapped with ``torch.compile`` (using
    ``compile_backend`` and ``compile_mode``) the first time this is called,
    and the wrapped callable is cached in ``_compiled_step`` for later calls.
    """
    if not self.compile_step:
        return self._step

    if self._compiled_step is None:
        # Build the compiled wrapper once and reuse it afterwards.
        self._compiled_step = torch.compile(
            self._step,
            backend=self.compile_backend,
            mode=self.compile_mode,
        )
    return self._compiled_step

def _step(self, _tensordict):
    """Run one RSSM step on ``_tensordict``: prior first, then posterior.

    Both networks are called for their effect on ``_tensordict``; their return
    values are ignored and the same ``_tensordict`` object is returned.
    """
    # Order matters: the posterior runs on what the prior has just written.
    for network in (self.rssm_prior, self.rssm_posterior):
        network(_tensordict)
    return _tensordict

def forward(self, tensordict):
"""Runs a rollout of simulated transitions in the latent space given a sequence of actions and environment observations.
Expand All @@ -267,25 +305,23 @@ def forward(self, tensordict):
which amends to q(s_{t+1} | s_t, a_t, o_{t+1})

"""
# from torchrl.envs.utils import step_mdp
if self.use_scan:
return self._forward_scan(tensordict)
return self._forward_loop(tensordict)

def _forward_loop(self, tensordict):
"""Traditional loop-based forward."""
tensordict_out = []
*batch, time_steps = tensordict.shape

update_values = tensordict.exclude(*self.out_keys).unbind(-1)
_tensordict = update_values[0]
for t in range(time_steps):
# samples according to p(s_{t+1} | s_t, a_t, b_t)
# ["state", "belief", "action"] -> [("next", "prior_mean"), ("next", "prior_std"), "_", ("next", "belief")]
with timeit("rssm_rollout/time-rssm_prior"):
self.rssm_prior(_tensordict)
step_fn = self._get_step_fn()

# samples according to p(s_{t+1} | s_t, a_t, o_{t+1}) = p(s_t | b_t, o_t)
# [("next", "belief"), ("next", "encoded_latents")] -> [("next", "posterior_mean"), ("next", "posterior_std"), ("next", "state")]
with timeit("rssm_rollout/time-rssm_posterior"):
self.rssm_posterior(_tensordict)
for t in range(time_steps):
_tensordict = step_fn(_tensordict)

tensordict_out.append(_tensordict)
# _tensordict = step_mdp(_tensordict, keep_other=True)
if t < time_steps - 1:
# Translate ("next", *) to the non-next key required for the current step input
_tensordict = _tensordict.select(*self.in_keys, strict=False)
Expand All @@ -294,6 +330,36 @@ def forward(self, tensordict):
out = torch.stack(tensordict_out, tensordict.ndim - 1)
return out

def _forward_scan(self, tensordict):
    """Scan-based forward using torch._higher_order_ops.scan.

    This is more torch.compile friendly as it avoids Python control flow.

    The input is split along its last (time) dimension, the per-step slices
    are re-stacked along a new leading dimension for ``scan`` to iterate over,
    and the stacked step outputs are returned with time moved back to the
    last batch dimension.

    Args:
        tensordict: Input whose last batch dimension is time.

    Returns:
        The per-step outputs with time as the last batch dimension, i.e. the
        same layout the loop-based rollout produces with
        ``torch.stack(..., tensordict.ndim - 1)``.
    """
    from torch._higher_order_ops.scan import scan

    # One slice per time step, with previously computed rollout outputs removed.
    update_values = tensordict.exclude(*self.out_keys).unbind(-1)
    init_td = update_values[0]

    # Stack the update values for scan input (time becomes dim 0).
    stacked_updates = torch.stack(list(update_values), dim=0)

    def scan_fn(carry, x):
        # carry is the previous step's tensordict, x is the update for this step.
        _td = x.update(carry.select(*self.in_keys, strict=False))
        self.rssm_prior(_td)
        self.rssm_posterior(_td)
        # Same object serves as the per-step output and the next carry.
        return _td, _td

    # NOTE(review): ``init`` and ``xs`` are passed wrapped in single-element
    # lists while ``scan_fn`` treats ``carry`` and ``x`` as tensordicts --
    # confirm against the scan API that this is the structure scan_fn receives.
    _, outputs = scan(scan_fn, [init_td], [stacked_updates])

    # outputs is stacked along dim 0. Rotate time to the last dimension while
    # keeping the batch dimensions in their original order; a plain
    # transpose(0, ndim - 1) would swap the first and last batch dims whenever
    # there is more than one batch dimension.
    return outputs.permute(*range(1, tensordict.ndim), 0)


class RSSMPrior(nn.Module):
"""The prior network of the RSSM.
Expand Down Expand Up @@ -356,7 +422,19 @@ def __init__(
self.rnn_hidden_dim = rnn_hidden_dim
self.action_shape = action_spec.shape

def forward(self, state, belief, action):
def forward(self, state, belief, action, noise=None):
"""Forward pass through the prior network.

Args:
state: Previous stochastic state.
belief: Previous deterministic belief.
action: Action to condition on.
noise: Optional pre-sampled noise for the prior state.
If None, samples from standard normal. Used for deterministic testing.

Returns:
Tuple of (prior_mean, prior_std, state, belief).
"""
projector_input = torch.cat([state, action], dim=-1)
action_state = self.action_state_projector(projector_input)
unsqueeze = False
Expand All @@ -377,7 +455,9 @@ def forward(self, state, belief, action):
belief = belief.squeeze(0)

prior_mean, prior_std = self.rnn_to_prior_projector(belief)
state = prior_mean + torch.randn_like(prior_std) * prior_std
if noise is None:
noise = torch.randn_like(prior_std)
state = prior_mean + noise * prior_std
return prior_mean, prior_std, state, belief


Expand Down Expand Up @@ -424,9 +504,22 @@ def __init__(self, hidden_dim=200, state_dim=30, scale_lb=0.1, rnn_hidden_dim=No
)
self.hidden_dim = hidden_dim

def forward(self, belief, obs_embedding):
def forward(self, belief, obs_embedding, noise=None):
    """Forward pass through the posterior network.

    Args:
        belief: Deterministic belief from the prior.
        obs_embedding: Encoded observation.
        noise: Optional pre-sampled noise for the posterior state.
            If None, samples from standard normal. Used for deterministic testing.

    Returns:
        Tuple of (posterior_mean, posterior_std, state).
    """
    posterior_mean, posterior_std = self.obs_rnn_to_post_projector(
        torch.cat([belief, obs_embedding], dim=-1)
    )
    # Sample only when the caller did not provide noise, so that a supplied
    # ``noise`` fully determines ``state`` and the global RNG is left untouched.
    if noise is None:
        noise = torch.randn_like(posterior_std)
    state = posterior_mean + noise * posterior_std
    return posterior_mean, posterior_std, state
Loading