From 3771f5de8dd7098440def726e69de637df939b88 Mon Sep 17 00:00:00 2001 From: Daniel Han-Chen Date: Wed, 5 Jun 2024 20:57:08 +1000 Subject: [PATCH] Update llama.py --- unsloth/models/llama.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py index 7dec8624e..9c9b45dba 100644 --- a/unsloth/models/llama.py +++ b/unsloth/models/llama.py @@ -1853,9 +1853,11 @@ def for_inference(model): pass # Wrap model.generate - model._unwrapped_old_generate = model.generate - model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model) - + if model.generate.__name__ != "_fast_generate": + model._unwrapped_old_generate = model.generate + model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model) + pass + # Patch tokenizer to pad to the left internal_model = model while hasattr(internal_model, "model"):