diff --git a/torchrl/envs/model_based/common.py b/torchrl/envs/model_based/common.py index daff2db4dbb..171ea068428 100644 --- a/torchrl/envs/model_based/common.py +++ b/torchrl/envs/model_based/common.py @@ -161,12 +161,13 @@ def _step( else: tensordict_out = self.world_model(tensordict_out) # done can be missing, it will be filled by `step` - tensordict_out = tensordict_out.select( - *self.observation_spec.keys(), - *self.full_done_spec.keys(), - *self.full_reward_spec.keys(), - strict=False, + # Convert to list for torch.compile compatibility (dynamo can't unpack _CompositeSpecKeysView) + keys_to_select = ( + list(self.observation_spec.keys()) + + list(self.full_done_spec.keys()) + + list(self.full_reward_spec.keys()) ) + tensordict_out = tensordict_out.select(*keys_to_select, strict=False) return tensordict_out @abc.abstractmethod