diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 6abd49222..48628a8d0 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -1787,6 +1787,11 @@ def unsloth_fast_generate(
 
 
 class FastLlamaModel:
+    @staticmethod
+    def _prepare_for_qat(model, qat_scheme):
+        model = _prepare_model_for_qat(model, qat_scheme)
+        return model
+
     @staticmethod
     def pre_patch():
         init_name, function = patch_llama_rope_scaling(
@@ -2668,7 +2673,8 @@ def get_peft_model(
         # Apply QAT + LoRA if specified
         if qat_scheme is not None:
             print("Unsloth: Applying QAT to mitigate quantization degradation")
-            model = _prepare_model_for_qat(model, qat_scheme)
+            model = FastLlamaModel._prepare_for_qat(model, qat_scheme)
+        pass
 
         model._saved_temp_tokenizer = _saved_temp_tokenizer
 
diff --git a/unsloth/models/loader.py b/unsloth/models/loader.py
index 5f92999f4..f31d82106 100644
--- a/unsloth/models/loader.py
+++ b/unsloth/models/loader.py
@@ -569,6 +569,12 @@ def from_pretrained(
 
 
 class FastModel(FastBaseModel):
+
+    @staticmethod
+    def _prepare_for_qat(model, qat_scheme):
+        model = _prepare_model_for_qat(model, qat_scheme)
+        return model
+
     @staticmethod
     def from_pretrained(
         model_name = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
@@ -1070,7 +1076,7 @@ def from_pretrained(
         # Apply QAT if specified
         if qat_scheme is not None:
            print("Unsloth: Applying QAT to mitigate quantization degradation")
-            model = _prepare_model_for_qat(model, qat_scheme)
+            model = FastModel._prepare_for_qat(model, qat_scheme)
         pass
 
         return model, tokenizer
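
Context for the change: QAT preparation now goes through a class-level _prepare_for_qat hook on FastLlamaModel and FastModel instead of calling _prepare_model_for_qat directly. A minimal sketch of why that indirection can be useful, assuming FastModel is importable from the top-level unsloth package as shown: because from_pretrained looks the helper up as FastModel._prepare_for_qat, the step can be intercepted in one place, e.g. to log the requested scheme. logged_prepare_for_qat is an illustrative name, not an Unsloth API.

    # Hypothetical sketch only: wraps the hook added in this diff so every
    # qat_scheme that reaches FastModel.from_pretrained is logged first.
    from unsloth import FastModel

    _original_prepare = FastModel._prepare_for_qat

    def logged_prepare_for_qat(model, qat_scheme):
        # Log the scheme, then defer to the original preparation step.
        print(f"QAT scheme requested: {qat_scheme}")
        return _original_prepare(model, qat_scheme)

    FastModel._prepare_for_qat = staticmethod(logged_prepare_for_qat)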