diff --git a/torchrl/modules/models/model_based.py b/torchrl/modules/models/model_based.py
index 20a5f9c4e9a..3e2035f1cd9 100644
--- a/torchrl/modules/models/model_based.py
+++ b/torchrl/modules/models/model_based.py
@@ -236,17 +236,55 @@ class RSSMRollout(TensorDictModuleBase):
     Args:
        rssm_prior (TensorDictModule): Prior network.
        rssm_posterior (TensorDictModule): Posterior network.
+       use_scan (bool, optional): If True, uses torch._higher_order_ops.scan for
+           the rollout loop. This is more torch.compile friendly but may have
+           different performance characteristics. Defaults to False.
+       compile_step (bool, optional): If True, compiles the individual step function.
+           Only used when use_scan=False. Defaults to False.
+       compile_backend (str, optional): Backend to use for compilation.
+           Defaults to "inductor".
+       compile_mode (str, optional): Mode to use for compilation.
+           Defaults to None (uses PyTorch default).

    """

-    def __init__(self, rssm_prior: TensorDictModule, rssm_posterior: TensorDictModule):
+    def __init__(
+        self,
+        rssm_prior: TensorDictModule,
+        rssm_posterior: TensorDictModule,
+        use_scan: bool = False,
+        compile_step: bool = False,
+        compile_backend: str = "inductor",
+        compile_mode: str | None = None,
+    ):
         super().__init__()
         _module = TensorDictSequential(rssm_prior, rssm_posterior)
         self.in_keys = _module.in_keys
         self.out_keys = _module.out_keys
         self.rssm_prior = rssm_prior
         self.rssm_posterior = rssm_posterior
+        self.use_scan = use_scan
+        self.compile_step = compile_step
+        self.compile_backend = compile_backend
+        self.compile_mode = compile_mode
+        self._compiled_step = None
+
+    def _get_step_fn(self):
+        """Get the step function, optionally compiled."""
+        if self.compile_step and self._compiled_step is None:
+            self._compiled_step = torch.compile(
+                self._step,
+                backend=self.compile_backend,
+                mode=self.compile_mode,
+            )
+        return self._compiled_step if self.compile_step else self._step
+
+    def _step(self, _tensordict):
+        """Single RSSM step: prior + posterior."""
+        self.rssm_prior(_tensordict)
+        self.rssm_posterior(_tensordict)
+        return _tensordict

     def forward(self, tensordict):
         """Runs a rollout of simulated transitions in the latent space given a sequence of actions and environment observations.
@@ -267,25 +305,23 @@ def forward(self, tensordict):
             which amends to q(s_{t+1} | s_t, a_t, o_{t+1})

         """
-        # from torchrl.envs.utils import step_mdp
+        if self.use_scan:
+            return self._forward_scan(tensordict)
+        return self._forward_loop(tensordict)
+
+    def _forward_loop(self, tensordict):
+        """Traditional loop-based forward."""
         tensordict_out = []
         *batch, time_steps = tensordict.shape

         update_values = tensordict.exclude(*self.out_keys).unbind(-1)
         _tensordict = update_values[0]
-        for t in range(time_steps):
-            # samples according to p(s_{t+1} | s_t, a_t, b_t)
-            # ["state", "belief", "action"] -> [("next", "prior_mean"), ("next", "prior_std"), "_", ("next", "belief")]
-            with timeit("rssm_rollout/time-rssm_prior"):
-                self.rssm_prior(_tensordict)
+        step_fn = self._get_step_fn()

-            # samples according to p(s_{t+1} | s_t, a_t, o_{t+1}) = p(s_t | b_t, o_t)
-            # [("next", "belief"), ("next", "encoded_latents")] -> [("next", "posterior_mean"), ("next", "posterior_std"), ("next", "state")]
-            with timeit("rssm_rollout/time-rssm_posterior"):
-                self.rssm_posterior(_tensordict)
+        for t in range(time_steps):
+            _tensordict = step_fn(_tensordict)
             tensordict_out.append(_tensordict)
-            # _tensordict = step_mdp(_tensordict, keep_other=True)
             if t < time_steps - 1:
                 # Translate ("next", *) to the non-next key required for the current step input
                 _tensordict = _tensordict.select(*self.in_keys, strict=False)
@@ -294,6 +330,36 @@ def forward(self, tensordict):
         out = torch.stack(tensordict_out, tensordict.ndim - 1)
         return out

+    def _forward_scan(self, tensordict):
+        """Scan-based forward using torch._higher_order_ops.scan.
+
+        This is more torch.compile friendly as it avoids Python control flow.
+        """
+        from torch._higher_order_ops.scan import scan
+
+        *batch, time_steps = tensordict.shape
+
+        update_values = tensordict.exclude(*self.out_keys).unbind(-1)
+        init_td = update_values[0]
+
+        # Stack the update values for scan input
+        stacked_updates = torch.stack(list(update_values), dim=0)
+
+        def scan_fn(carry, x):
+            # carry is the current tensordict, x is the update for this step
+            _td = x.update(carry.select(*self.in_keys, strict=False))
+            self.rssm_prior(_td)
+            self.rssm_posterior(_td)
+            # Return (next carry, per-step output)
+            return _td, _td
+
+        # Run scan; carry/x mirror the pytree structure of init_td/stacked_updates
+        _, outputs = scan(scan_fn, init_td, stacked_updates)
+
+        # outputs are stacked along dim 0; move that scan dim to the time position
+        out = outputs.permute(*range(1, outputs.ndim), 0)
+        return out
+

 class RSSMPrior(nn.Module):
     """The prior network of the RSSM.

@@ -356,7 +422,19 @@ def __init__(
         self.rnn_hidden_dim = rnn_hidden_dim
         self.action_shape = action_spec.shape

-    def forward(self, state, belief, action):
+    def forward(self, state, belief, action, noise=None):
+        """Forward pass through the prior network.
+
+        Args:
+            state: Previous stochastic state.
+            belief: Previous deterministic belief.
+            action: Action to condition on.
+            noise: Optional pre-sampled noise for the prior state.
+                If None, samples from standard normal. Used for deterministic testing.
+
+        Returns:
+            Tuple of (prior_mean, prior_std, state, belief).
+ """ projector_input = torch.cat([state, action], dim=-1) action_state = self.action_state_projector(projector_input) unsqueeze = False @@ -377,7 +455,9 @@ def forward(self, state, belief, action): belief = belief.squeeze(0) prior_mean, prior_std = self.rnn_to_prior_projector(belief) - state = prior_mean + torch.randn_like(prior_std) * prior_std + if noise is None: + noise = torch.randn_like(prior_std) + state = prior_mean + noise * prior_std return prior_mean, prior_std, state, belief @@ -424,9 +504,22 @@ def __init__(self, hidden_dim=200, state_dim=30, scale_lb=0.1, rnn_hidden_dim=No ) self.hidden_dim = hidden_dim - def forward(self, belief, obs_embedding): + def forward(self, belief, obs_embedding, noise=None): + """Forward pass through the posterior network. + + Args: + belief: Deterministic belief from the prior. + obs_embedding: Encoded observation. + noise: Optional pre-sampled noise for the posterior state. + If None, samples from standard normal. Used for deterministic testing. + + Returns: + Tuple of (posterior_mean, posterior_std, state). + """ posterior_mean, posterior_std = self.obs_rnn_to_post_projector( torch.cat([belief, obs_embedding], dim=-1) ) - state = posterior_mean + torch.randn_like(posterior_std) * posterior_std + if noise is None: + noise = torch.randn_like(posterior_std) + state = posterior_mean + noise * posterior_std return posterior_mean, posterior_std, state