diff --git a/torchrl/modules/llm/policies/transformers_wrapper.py b/torchrl/modules/llm/policies/transformers_wrapper.py
index f2e8981b566..77c759d7589 100644
--- a/torchrl/modules/llm/policies/transformers_wrapper.py
+++ b/torchrl/modules/llm/policies/transformers_wrapper.py
@@ -787,6 +787,10 @@ def _from_transformers_generate_history(self, td, cfg, out) -> TensorDictBase:
         response_struct = history.apply_chat_template(
             tokenizer=self.tokenizer, **tokenizer_kwargs
         )
+
+        if self._device is not None:
+            response_struct = response_struct.to(self._device)
+
         tokens_prompt_padded = response_struct.get(
             "input_ids",
             as_padded_tensor=True,