diff --git a/unsloth/models/rl_replacements.py b/unsloth/models/rl_replacements.py index a207514e7..e1d20a340 100644 --- a/unsloth/models/rl_replacements.py +++ b/unsloth/models/rl_replacements.py @@ -431,7 +431,11 @@ def _get_per_token_logps_and_entropies(self, model, input_ids, attention_mask, l unwrapped_model = self.accelerator.unwrap_model(model, keep_fp32_wrapper=False) with torch.amp.autocast(device_type = 'cuda', dtype = self._autocast_dtype): - with torch.inference_mode(): + # aten.t() for tensor subclasses raises RuntimeError: + # Cannot set version_counter for inference tensor + # Turning off inference mode for now + # See https://github.com/pytorch/pytorch/issues/164872 for more details + with torch.inference_mode(False): if pixel_values is None: attention_mask = input_ids != self.processing_class.pad_token_id attention_mask = attention_mask.to(attention_mask.dtype)