diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index ea5dabf73..1e301acc9 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -791,7 +791,7 @@ def _CausalLM_fast_forward( *args, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: - if past_key_values is not None: + if past_key_values is not None and self.config.model_type != "qwen2": outputs = fast_forward_inference( self, input_ids,