diff --git a/docs/source/developer_guides/quantization.md b/docs/source/developer_guides/quantization.md index 5067156d8b..a6a064c299 100644 --- a/docs/source/developer_guides/quantization.md +++ b/docs/source/developer_guides/quantization.md @@ -237,6 +237,27 @@ model = get_peft_model(base_model, peft_config) - DoRA only works with `quant_type = "int8_weight_only"` at the moment. - There is explicit support for torchao when used with LoRA. However, when torchao quantizes a layer, its class does not change, only the type of the underlying tensor. For this reason, PEFT methods other than LoRA will generally also work with torchao, even if not explicitly supported. Be aware, however, that **merging only works correctly with LoRA and with `quant_type = "int8_weight_only"`**. If you use a different PEFT method or dtype, merging will likely result in an error, and even if it doesn't, the results will still be incorrect. +## Optimum-quanto + +PEFT supports models quantized with [optimum-quanto](https://github.com/huggingface/optimum-quanto). This has been tested with 2-bit, 4-bit, and 8-bit int quantization. Optimum-quanto also works on CPU and MPS. + +```python +from transformers import AutoModelForCausalLM, QuantoConfig +from peft import LoraConfig, get_peft_model + +model_id = ... +quantization_config = QuantoConfig(weights="int4")  # or "int2" or "int8" +base_model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config) +peft_config = LoraConfig(...) +model = get_peft_model(base_model, peft_config) +``` + +### Caveats: + +- Use optimum-quanto v0.2.5 or above; otherwise, saving and loading won't work properly. +- If you want to use optimum-quanto via transformers, install transformers v4.46.0 or above. +- Float8 is discouraged as it can easily produce NaNs. +- There is explicit support for optimum-quanto when used with LoRA. However, when optimum-quanto quantizes a layer, it remains a subclass of the corresponding torch class (e.g., quanto's `QLinear` is a subclass of `nn.Linear`). For this reason, non-LoRA methods will generally also work with optimum-quanto, even if not explicitly supported. Be aware, however, that **merging only works correctly with LoRA**. If you use a method other than LoRA, merging may not raise an error but the results will be incorrect. See the merging example below. 
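+ +The following is a minimal sketch of the merging workflow referenced in the caveat above. It assumes the `model` created in the snippet above and a placeholder `inputs` dict (for example, a tokenized batch); because the merged weights are re-quantized, expect small numerical differences compared to the unmerged model: + +```python +# `model` is the PeftModel from the example above; `inputs` stands in for a tokenized batch +model.merge_adapter()    # merge the LoRA weights into the quantized base weights +merged_output = model(**inputs) +model.unmerge_adapter()  # restore the original quantized base weights + +# alternatively, merge permanently and remove the PEFT wrapper modules +merged_model = model.merge_and_unload() +```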
+ ## Other Supported PEFT Methods Besides LoRA, the following PEFT methods also support quantization: diff --git a/setup.py b/setup.py index f55d5b3fe8..02e4999120 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ "scipy", "protobuf", "sentencepiece", + "optimum-quanto", ] setup( diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py index 97404aeb4b..dbc8c107e8 100644 --- a/src/peft/import_utils.py +++ b/src/peft/import_utils.py @@ -160,3 +160,10 @@ def is_xpu_available(check_device=False): except RuntimeError: return False return hasattr(torch, "xpu") and torch.xpu.is_available() + + +@lru_cache +def is_quanto_available(): + return (importlib.util.find_spec("optimum") is not None) and ( + importlib.util.find_spec("optimum.quanto") is not None + ) diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index cdbc41c652..35cdaea289 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -52,6 +52,7 @@ from .gptq import dispatch_gptq from .hqq import dispatch_hqq from .layer import Conv2d, LoraLayer, dispatch_default +from .quanto import dispatch_quanto from .torchao import dispatch_torchao from .tp_layer import dispatch_megatron @@ -331,6 +332,7 @@ def dynamic_dispatch_func(target, adapter_name, lora_config, **kwargs): dispatch_gptq, dispatch_hqq, dispatch_torchao, + dispatch_quanto, dispatch_megatron, dispatch_default, ] diff --git a/src/peft/tuners/lora/quanto.py b/src/peft/tuners/lora/quanto.py new file mode 100644 index 0000000000..0cd1e08f59 --- /dev/null +++ b/src/peft/tuners/lora/quanto.py @@ -0,0 +1,425 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import math +import warnings +from typing import Any, Optional + +import torch +from torch import nn +from torch.nn import functional as F + +from peft.import_utils import is_quanto_available +from peft.tuners.lora.layer import LoraLayer +from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge +from peft.utils.other import transpose + + +if is_quanto_available(): + # ensure that there are no quanto imports unless optimum.quanto is installed + from optimum.quanto import QConv2d, QLinear +else: + QConv2d, QLinear = None, None + + +class QuantoLoraLinear(torch.nn.Module, LoraLayer): + """LoRA layer implementation for quanto QLinear""" + + def __init__( + self, + base_layer, + adapter_name, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + init_lora_weights: bool = True, + use_rslora: bool = False, + use_dora: bool = False, + **kwargs, + ): + if use_dora: + raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + + super().__init__() + LoraLayer.__init__(self, base_layer) + self.fan_in_fan_out = fan_in_fan_out + + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) + + def _mixed_batch_forward( + self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any + ) -> torch.Tensor: + # This is a special method that handles the case when users pass the argument `adapter_names`. This is an + # extra argument that allows mixing different adapters in the same batch at inference time. + result = self.base_layer(x, *args, **kwargs) + + unique_adapters = set(adapter_names) + sub_batch_indices_list = [] + for adapter in unique_adapters: + sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + + for i, active_adapter in enumerate(unique_adapters): + if active_adapter == "__base__": + continue + if active_adapter not in self.lora_A.keys(): + continue + + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + compute_dtype = lora_A.weight.dtype + if x.dtype != compute_dtype: + x = x.to(compute_dtype) + + # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear + # layer output + sub_batch = x[sub_batch_indices_list[i]] + output = lora_B(lora_A(dropout(sub_batch))) * scaling + if requires_conversion: + output = output.to(expected_dtype) + result[sub_batch_indices_list[i]] += output + + return result + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = 
self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = x.to(lora_A.weight.dtype) + + output = lora_B(lora_A(dropout(x))) + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling + result = result + output + + return result + + def get_delta_weight(self, adapter): + return ( + transpose(self.lora_B[adapter].weight @ self.lora_A[adapter].weight, fan_in_fan_out=self.fan_in_fan_out) + * self.scaling[adapter] + ) + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + with torch.no_grad(): + new_module = torch.nn.Linear( + self.in_features, self.out_features, device=self.lora_A[adapter_names[0]].weight.device + ) + new_module.weight.zero_() + new_module.bias.zero_() + + base_layer = self.get_base_layer() + orig_weight = base_layer.qweight + new_module.weight.data += orig_weight + new_module.bias.data += base_layer.bias + + for active_adapter in adapter_names: + new_module.weight.data += self.get_delta_weight(active_adapter) + + quantized = base_layer.from_module(new_module, weights=base_layer.weight_qtype).qweight + if safe_merge and not torch.isfinite(quantized).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + base_layer.qweight._data = quantized._data + base_layer.qweight._scale = quantized._scale + self.merged_adapters.extend(adapter_names) + + def unmerge(self) -> None: + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + + with torch.no_grad(): + new_module = torch.nn.Linear( + self.in_features, self.out_features, device=self.lora_A[self.active_adapters[0]].weight.device + ) + new_module.weight.zero_() + new_module.bias.zero_() + + base_layer = self.get_base_layer() + orig_weight = base_layer.qweight + new_module.weight.data += orig_weight + new_module.bias.data += base_layer.bias + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + new_module.weight.data -= self.get_delta_weight(active_adapter) + + quantized = base_layer.from_module(new_module, weights=base_layer.weight_qtype).qweight + base_layer.qweight._data = quantized._data + base_layer.qweight._scale = quantized._scale + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." 
+ rep + + +class QuantoLoraConv2d(torch.nn.Module, LoraLayer): + """LoRA layer implementation for quanto QConv2d""" + + def __init__( + self, + base_layer, + adapter_name, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: bool = True, + use_rslora: bool = False, + use_dora: bool = False, + **kwargs, + ): + if use_dora: + raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + + super().__init__() + LoraLayer.__init__(self, base_layer) + + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora) + + def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora): + # same as lora.layer.Conv2d + if r <= 0: + raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") + + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout[adapter_name] = lora_dropout_layer + # Actual trainable parameters + base_layer = self.get_base_layer() + kernel_size = base_layer.kernel_size + stride = base_layer.stride + padding = base_layer.padding + self.lora_A[adapter_name] = nn.Conv2d(self.in_features, r, kernel_size, stride, padding, bias=False) + self.lora_B[adapter_name] = nn.Conv2d(r, self.out_features, (1, 1), (1, 1), bias=False) + if use_rslora: + self.scaling[adapter_name] = lora_alpha / math.sqrt(r) + else: + self.scaling[adapter_name] = lora_alpha / r + + if init_lora_weights == "loftq": + self.loftq_init(adapter_name) + elif init_lora_weights: + self.reset_lora_parameters(adapter_name, init_lora_weights) + + # call this before dora_init + self._move_adapter_to_device_of_base_layer(adapter_name) + + if use_dora: + # TODO: Implement DoRA + self.dora_init(adapter_name) + self.use_dora[adapter_name] = True + else: + self.use_dora[adapter_name] = False + + self.set_adapter(self.active_adapters) + + def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + adapter_names = kwargs.pop("adapter_names", None) + if adapter_names is not None: + raise ValueError(f"{self.__class__.__name__} does not support mixed_batch_forward yet.") + + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = x.to(lora_A.weight.dtype) + + output = lora_B(lora_A(dropout(x))) + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling + result = result + output + + return result + + def get_delta_weight(self, adapter): + # same as lora.layer.Conv2d + device = self.lora_B[adapter].weight.device + dtype = self.lora_A[adapter].weight.dtype + + # In case users want to merge the adapter weights that are in + # (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to + 
# (b)float16 because some CPUs have slow bf16/fp16 matmuls. + cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16) + + weight_A = self.lora_A[adapter].weight + weight_B = self.lora_B[adapter].weight + + if cast_to_fp32: + weight_A = weight_A.float() + weight_B = weight_B.float() + + # https://github.com/bmaltais/kohya_ss/blob/feb6728762a8f463d15ba936d189d4c3abfaa1ab/networks/lora.py#L117 + if self.get_base_layer().weight.size()[2:4] == (1, 1): + # conv2d 1x1 + output_tensor = (weight_B.squeeze(3).squeeze(2) @ weight_A.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze( + 3 + ) * self.scaling[adapter] + else: + # conv2d 3x3 + output_tensor = ( + F.conv2d( + weight_A.permute(1, 0, 2, 3), + weight_B, + ).permute(1, 0, 2, 3) + * self.scaling[adapter] + ) + + if cast_to_fp32: + output_tensor = output_tensor.to(dtype=dtype) + + # cast back the weights + self.lora_A[adapter].weight.data = weight_A.to(dtype) + self.lora_B[adapter].weight.data = weight_B.to(dtype) + + return output_tensor + + def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: + # same as lora.quanto.QuantoLoraLinear + from optimum.quanto import quantize_weight + + adapter_names = check_adapters_to_merge(self, adapter_names) + if not adapter_names: + # no adapter to merge + return + + base_layer = self.get_base_layer() + orig_weight = base_layer.qweight + + for active_adapter in adapter_names: + delta_weight = self.get_delta_weight(active_adapter) + # note: no in-place for safe_merge=False + new_weight_data = orig_weight + delta_weight + if safe_merge: + if not torch.isfinite(new_weight_data).all(): + raise ValueError( + f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" + ) + quantized = quantize_weight(new_weight_data, qtype=orig_weight.qtype, axis=orig_weight.axis) + base_layer.qweight._data = quantized._data + base_layer.qweight._scale = quantized._scale + self.merged_adapters.append(active_adapter) + + def unmerge(self) -> None: + # same as lora.quanto.QuantoLoraLinear + from optimum.quanto import quantize_weight + + if not self.merged: + warnings.warn("Already unmerged. Nothing to do.") + return + + while len(self.merged_adapters) > 0: + active_adapter = self.merged_adapters.pop() + if active_adapter not in self.lora_A.keys(): + continue + + base_layer = self.get_base_layer() + orig_weight = base_layer.weight + new_weight_data = orig_weight - self.get_delta_weight(active_adapter) + quantized = quantize_weight(new_weight_data, qtype=orig_weight.qtype, axis=orig_weight.axis) + base_layer.weight._data = quantized._data + base_layer.weight._scale = quantized._scale + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." 
+ rep + + +def dispatch_quanto( + target: torch.nn.Module, + adapter_name: str, + **kwargs: Any, +) -> Optional[torch.nn.Module]: + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if is_quanto_available() and isinstance(target_base_layer, QLinear): + new_module = QuantoLoraLinear(target, adapter_name, **kwargs) + target.weight = target_base_layer.weight + + if hasattr(target, "bias"): + target.bias = target_base_layer.bias + elif is_quanto_available() and isinstance(target_base_layer, QConv2d): + new_module = QuantoLoraConv2d(target, adapter_name, **kwargs) + target.weight = target_base_layer.weight + + if hasattr(target, "bias"): + target.bias = target_base_layer.bias + + return new_module diff --git a/src/peft/utils/integrations.py b/src/peft/utils/integrations.py index 5c23f404d8..2edd87480b 100644 --- a/src/peft/utils/integrations.py +++ b/src/peft/utils/integrations.py @@ -130,15 +130,22 @@ def get_layer_device_map(model): """ Derive the device map for the layers of the model. """ - main_device = [d for d in model.hf_device_map.values() if d not in ["cpu", "disk"]][0] + if not hasattr(model, "hf_device_map"): + return None + + if (len(model.hf_device_map) == 1) and hasattr(model, "device"): + # E.g. with quanto, when the model is loaded as: + # `model = AutoModel.from_pretrained(model_id, quantization_config=quanto_config)` + # Then the model.hf_device_map is set to {'': 'cpu'}, even if model.to(0) is called later. Thus we can't fully + # rely on the hf_device_map. + main_device = model.device + else: + main_device = [d for d in model.hf_device_map.values() if d not in ["cpu", "disk"]][0] execution_device_map = { name: main_device if device in ["cpu", "disk"] else device for name, device in model.hf_device_map.items() } - if execution_device_map is None: - return None - if len(execution_device_map) == 1 and "" in execution_device_map: return {idx: execution_device_map[""] for idx in range(model.config.num_hidden_layers)} @@ -168,6 +175,9 @@ def map_cache_to_layer_device_map(model, cache) -> None: return layer_device_map = get_layer_device_map(model) + if layer_device_map is None: + return + for idx in range(model.config.num_hidden_layers): layer_device = layer_device_map[idx] cache.key_cache[idx] = cache.key_cache[idx].to(layer_device) diff --git a/tests/test_quanto.py b/tests/test_quanto.py new file mode 100644 index 0000000000..5d10a8070a --- /dev/null +++ b/tests/test_quanto.py @@ -0,0 +1,720 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy +import platform +import shutil +import tempfile +import unittest +from unittest.mock import Mock, call, patch + +import pytest +import torch +from parameterized import parameterized +from torch import nn +from transformers import AutoModelForCausalLM, AutoTokenizer + +from peft import ( + LoraConfig, + PrefixTuningConfig, + PromptTuningConfig, + PromptTuningInit, + get_peft_model, +) + +from .testing_common import PeftCommonTester, PeftTestConfigManager + + +# only the PEFT methods that are explicitly supported will be tested for merging +PEFT_METHODS_SUPPORTING_MERGING = [LoraConfig] + + +def filter_supported_methods_supporting_merging(test_list): + return [test for test in test_list if any(test[2] is cls for cls in PEFT_METHODS_SUPPORTING_MERGING)] + + +# only test a single model, it's already slow as is +PEFT_DECODER_MODELS_TO_TEST = [ + "hf-internal-testing/tiny-random-OPTForCausalLM", +] + +FULL_GRID = { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "task_type": "CAUSAL_LM", +} + + +def make_automodel_proxy(weights: str): + """Instantiate a quanto-quantized transformers model.""" + from transformers import QuantoConfig + + class QuantoModelProxy: + @classmethod + def from_pretrained(self, *args, **kwargs): + quantization_config = QuantoConfig(weights=weights) + model = AutoModelForCausalLM.from_pretrained(*args, quantization_config=quantization_config, **kwargs) + return model + + return QuantoModelProxy + + +@unittest.skipIf(platform.system() == "Darwin", "Tests are skipped on macOS") +class BasePeftQuantoModelTester: + r"""Base class implementing tests for quanto-quantized models. + + This class is based on PeftDecoderModelTester with some quanto-specific edits, especially for the merging tests, + which are less precise due to the quantization. + + Subclasses should implement the attributes below. + """ + + # The weights argument for quanto, should be "int2", "int4", or "int8" + weights = "MISSING" + # transformers class should be make_automodel_proxy(weights=weights) + transformers_class = "MISSING" + # expected minimum correlation between logits before and after merging + # subclasses should override this with a float between 0 and 1 + min_correlation = "MISSING" + # the allowed tolerance for comparing the output tensors + tol = "MISSING" + + def _get_correlation_matrix(self, *tensors): + return torch.corrcoef(torch.stack([t.flatten() for t in tensors])) + + def check_tensors_approximately_equal(self, *tensors): + # Strict equality checks will fail due to the quantization, so we check: + # 1. The correlation between the tensors is high + # 2. 
Tensor equality after removing 1% of highest and lowest outliers + cc_matrix = self._get_correlation_matrix(*tensors) + assert cc_matrix.min() > self.min_correlation + + for tensor0, tensor1 in zip(tensors, tensors[1:]): + tensor0, tensor1 = tensor0.flatten(), tensor1.flatten() + diff = tensor0 - tensor1 + indices = torch.argsort(diff) + # remove 1% outliers on both ends + indices = indices[len(indices) // 100 : -len(indices) // 100] + tensor0, tensor1 = tensor0[indices], tensor1[indices] + assert torch.allclose(tensor0, tensor1, atol=self.tol, rtol=self.tol) + + def prepare_inputs_for_testing(self): + input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]]).to(self.torch_device) + attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device) + + input_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + return input_dict + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_attributes_parametrized(self, test_name, model_id, config_cls, config_kwargs): + self._test_model_attr(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_adapter_name(self, test_name, model_id, config_cls, config_kwargs): + self._test_adapter_name(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_prepare_for_training_parametrized(self, test_name, model_id, config_cls, config_kwargs): + self._test_prepare_for_training(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_prompt_tuning_text_prepare_for_training(self, test_name, model_id, config_cls, config_kwargs): + # Test that prompt tuning works with text init + if config_cls != PromptTuningConfig: + return pytest.skip(f"This test does not apply to {config_cls}") + + config_kwargs = config_kwargs.copy() + config_kwargs["prompt_tuning_init"] = PromptTuningInit.TEXT + config_kwargs["prompt_tuning_init_text"] = "This is a test prompt." 
+ config_kwargs["tokenizer_name_or_path"] = model_id + self._test_prepare_for_training(model_id, config_cls, config_kwargs) + + def test_prompt_tuning_text_tokenizer_kwargs(self): + # Allow users to pass additional arguments to Tokenizer.from_pretrained + # Fix for #1032 + mock = Mock() + orig_from_pretrained = AutoTokenizer.from_pretrained + + def mock_autotokenizer_from_pretrained(*args, **kwargs): + mock(*args, **kwargs) + return orig_from_pretrained(config.tokenizer_name_or_path) + + model_id = "hf-internal-testing/tiny-random-OPTForCausalLM" + config = PromptTuningConfig( + base_model_name_or_path=model_id, + tokenizer_name_or_path=model_id, + num_virtual_tokens=10, + prompt_tuning_init=PromptTuningInit.TEXT, + task_type="CAUSAL_LM", + prompt_tuning_init_text="This is a test prompt.", + tokenizer_kwargs={"trust_remote_code": True, "foo": "bar"}, + ) + model = self.transformers_class.from_pretrained(model_id).to(self.torch_device) + with patch("transformers.AutoTokenizer.from_pretrained", mock_autotokenizer_from_pretrained): + model = get_peft_model(model, config) + + expected_call = call(model_id, trust_remote_code=True, foo="bar") + assert mock.call_args == expected_call + + def test_prompt_tuning_config_invalid_args(self): + # Raise an error when tokenizer_kwargs is used with prompt_tuning_init!='TEXT', because this argument has no + # function in that case + model_id = "hf-internal-testing/tiny-random-OPTForCausalLM" + with pytest.raises(ValueError, match="tokenizer_kwargs only valid when using prompt_tuning_init='TEXT'."): + PromptTuningConfig( + base_model_name_or_path=model_id, + tokenizer_name_or_path=model_id, + num_virtual_tokens=10, + task_type="CAUSAL_LM", + prompt_tuning_init_text="This is a test prompt.", + prompt_tuning_init=PromptTuningInit.RANDOM, # <= should not be used together with tokenizer_kwargs + tokenizer_kwargs={"trust_remote_code": True, "foo": "bar"}, + ) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_save_pretrained(self, test_name, model_id, config_cls, config_kwargs): + self._test_save_pretrained(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_save_pretrained_pickle(self, test_name, model_id, config_cls, config_kwargs): + self._test_save_pretrained(model_id, config_cls, config_kwargs, safe_serialization=False) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_save_pretrained_selected_adapters(self, test_name, model_id, config_cls, config_kwargs): + self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_save_pretrained_selected_adapters_pickle(self, test_name, model_id, config_cls, config_kwargs): + self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs, safe_serialization=False) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_from_pretrained_config_construction(self, test_name, model_id, config_cls, config_kwargs): + self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "adalora_kwargs": {"init_lora_weights": [False]}, + "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, + 
"vera_kwargs": {"init_weights": [False]}, + "fourierft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + filter_params_func=filter_supported_methods_supporting_merging, + ) + ) + def test_merge_layers(self, test_name, model_id, config_cls, config_kwargs): + # Not using PeftCommonTester for merging tests as merging is too imprecise. So instead of checking we use a + # custom check that relies on correlation and outlier removal + torch.manual_seed(0) + + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + if config.is_prompt_learning: + pytest.skip("Prompt learning models do not support merging.") + + model = self.transformers_class.from_pretrained(model_id) + model = get_peft_model(model, config) + model = model.to(self.torch_device) + + dummy_input = self.prepare_inputs_for_testing() + model.eval() + logits = model(**dummy_input)[0] + + model.merge_adapter() + logits_merged = model(**dummy_input)[0] + model.unmerge_adapter() + logits_unmerged = model(**dummy_input)[0] + + model = model.merge_and_unload() + logits_merged_unloaded = model(**dummy_input)[0] + + self.check_tensors_approximately_equal(logits, logits_merged, logits_unmerged, logits_merged_unloaded) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, + "fourierft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + filter_params_func=filter_supported_methods_supporting_merging, + ) + ) + # TODO: enable if/when deepcopy-ing is supported + @pytest.mark.skip("Quanto does not work (yet) with deepcopy-ing") + def test_merge_layers_multi(self, test_name, model_id, config_cls, config_kwargs): + # Not using PeftCommonTester for merging tests as merging is too imprecise. 
So instead of checking we use a + # custom check that relies on correlation and outlier removal + # NOTE: don't use with `torch.inference_mode()`, see: https://github.com/huggingface/optimum-quanto/issues/304 + torch.manual_seed(0) + + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + if config.is_prompt_learning: + pytest.skip("Prompt learning models do not support merging.") + + model = self.transformers_class.from_pretrained(model_id) + model = get_peft_model(model, config) + + model = model.to(self.torch_device) + + dummy_input = self.prepare_inputs_for_testing() + model.eval() + + logits_adapter_1 = model(**dummy_input)[0] + + model.add_adapter("adapter-2", config) + model.set_adapter("adapter-2") + model.eval() + + logits_adapter_2 = model(**dummy_input)[0] + + assert not torch.allclose(logits_adapter_1, logits_adapter_2, atol=1e-3, rtol=1e-3) + + model.set_adapter("default") + + logits_adapter_1_after_set = model(**dummy_input)[0] + + self.check_tensors_approximately_equal(logits_adapter_1, logits_adapter_1_after_set) + + model_copy = copy.deepcopy(model) + model_copy_2 = copy.deepcopy(model) + model_merged_all = model.merge_and_unload(adapter_names=["adapter-2", "default"]) + + logits_merged_all = model_merged_all(**dummy_input)[0] + + assert not torch.allclose(logits_merged_all, logits_adapter_2, atol=1e-3, rtol=1e-3) + assert not torch.allclose(logits_merged_all, logits_adapter_1, atol=1e-3, rtol=1e-3) + + model_merged_adapter_2 = model_copy.merge_and_unload(adapter_names=["adapter-2"]) + + logits_merged_adapter_2 = model_merged_adapter_2(**dummy_input)[0] + + self.check_tensors_approximately_equal(logits_adapter_2, logits_merged_adapter_2) + + model_merged_adapter_default = model_copy_2.merge_and_unload(adapter_names=["default"]) + logits_merged_adapter_default = model_merged_adapter_default(**dummy_input)[0] + + self.check_tensors_approximately_equal(logits_adapter_1, logits_merged_adapter_default) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + filter_params_func=filter_supported_methods_supporting_merging, + ) + ) + def test_merge_layers_nan(self, test_name, model_id, config_cls, config_kwargs): + # Not using PeftCommonTester for merging tests as merging is too imprecise. 
So instead of checking we use a + # custom check that relies on correlation and outlier removal + torch.manual_seed(0) + + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + if config.is_prompt_learning: + pytest.skip("Prompt learning models do not support merging.") + + model = self.transformers_class.from_pretrained(model_id) + model = get_peft_model(model, config) + model = model.to(self.torch_device) + + dummy_input = self.prepare_inputs_for_testing() + + model.eval() + + # This should work + logits_unmerged = model(**dummy_input)[0] + + model = model.merge_and_unload() + logits_merged = model(**dummy_input)[0] + + self.check_tensors_approximately_equal(logits_unmerged, logits_merged) + + model = self.transformers_class.from_pretrained(model_id) + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + model = get_peft_model(model, config) + model = model.to(self.torch_device) + + prefixes = ["lora_A", "boft_R", "fourierft_spectrum", "hra_u", "hada_w1", "lokr_w1", "ia3_l", "oft_r"] + prefixes += ["vera_lambda_b"] + + for name, module in model.named_parameters(): + if any(prefix in name for prefix in prefixes): + module.data[0] = torch.nan + + with pytest.raises( + ValueError, match="NaNs detected in the merged weights. The adapter default seems to be broken" + ): + model = model.merge_and_unload(safe_merge=True) + + for name, module in model.named_parameters(): + if any(prefix in name for prefix in prefixes): + module.data[0] = torch.inf + + with pytest.raises( + ValueError, match="NaNs detected in the merged weights. The adapter default seems to be broken" + ): + model = model.merge_and_unload(safe_merge=True) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "adalora_kwargs": {"init_lora_weights": [False]}, + "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, + "fourierft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + filter_params_func=filter_supported_methods_supporting_merging, + ) + ) + @pytest.mark.xfail(strict=True) + def test_load_merge_and_unloaded_model(self, test_name, model_id, config_cls, config_kwargs): + # Saving and loading a quanto model that has been merged and unloaded does not work correctly. Here is the + # reason: Quanto requires its own save_pretrained method, which, among others, saves the quantization map. + # Without it, the model cannot be correctly loaded. To make use of this, we should thus use a quanto + # QuantizedModel instance instead of a PretrainedModel instance. However, the QuantizedModel instance cannot be + # used for anything else, e.g. it has no __call__ method. Therefore, we cannot use that in PEFT. Therefore, + # users need to pass the PretrainedModel instance to get_peft_model, thus we don't have the modified + # save_pretrained, thus loading the merged and unloaded model does not work. 
+ from optimum.quanto import QuantizedModelForCausalLM + + torch.manual_seed(0) + + model = self.transformers_class.from_pretrained(model_id) + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + model = get_peft_model(model, config) + model = model.to(self.torch_device) + model = model.merge_and_unload() + model.eval() + + dummy_input = self.prepare_inputs_for_testing() + logits = model(**dummy_input)[0] + + # model is a transformers model + tmp_dirname = tempfile.mkdtemp() + # note: not using the context manager here because it fails on Windows CI for some reason + try: + model.save_pretrained(tmp_dirname) + # Careful: must use QuantizedModelForCausalLM.from_pretrained, not AutoModelForCausalLM.from_pretrained + model_from_pretrained = QuantizedModelForCausalLM.from_pretrained(tmp_dirname).to(self.torch_device) + finally: + try: + shutil.rmtree(tmp_dirname) + except PermissionError: + # windows error + pass + + logits_merged_from_pretrained = model_from_pretrained(**dummy_input)[0] + self.check_tensors_approximately_equal(logits, logits_merged_from_pretrained) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + filter_params_func=filter_supported_methods_supporting_merging, + ) + ) + def test_mixed_adapter_batches(self, test_name, model_id, config_cls, config_kwargs): + self._test_mixed_adapter_batches(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_generate(self, test_name, model_id, config_cls, config_kwargs): + self._test_generate(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_generate_pos_args(self, test_name, model_id, config_cls, config_kwargs): + # positional args are supported for PeftModelForCausalLM + self._test_generate_pos_args(model_id, config_cls, config_kwargs, raises_err=False) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + FULL_GRID, + filter_params_func=filter_supported_methods_supporting_merging, + ) + ) + def test_merge_layers_fp16(self, test_name, model_id, config_cls, config_kwargs): + self._test_merge_layers_fp16(model_id, config_cls, config_kwargs) + + # this fails for a couple of methods (IA³, LoRA, prefix tuning) with segfault on GH CI + # @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + # def test_generate_half_prec(self, test_name, model_id, config_cls, config_kwargs): + # self._test_generate_half_prec(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + @pytest.mark.skip("Quanto raises an error when trying to convert the dtype, skipping test.") + def test_prefix_tuning_half_prec_conversion(self, test_name, model_id, config_cls, config_kwargs): + self._test_prefix_tuning_half_prec_conversion(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_training_decoders(self, test_name, model_id, config_cls, config_kwargs): + self._test_training(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_training_decoders_layer_indexing(self, test_name, model_id, config_cls, config_kwargs): + self._test_training_layer_indexing(model_id, config_cls, config_kwargs) + + 
@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_training_decoders_gradient_checkpointing(self, test_name, model_id, config_cls, config_kwargs): + self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_inference_safetensors(self, test_name, model_id, config_cls, config_kwargs): + self._test_inference_safetensors(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwargs): + self._test_peft_model_device_map(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs): + self._test_delete_adapter(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs): + self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs): + self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "adalora_kwargs": {"init_lora_weights": [False]}, + "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, + "fourierft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + filter_params_func=filter_supported_methods_supporting_merging, + ) + ) + def test_unload_adapter(self, test_name, model_id, config_cls, config_kwargs): + self._test_unload_adapter(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + ) + ) + def test_weighted_combination_of_adapters(self, test_name, model_id, config_cls, config_kwargs): + self._test_weighted_combination_of_adapters(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_training_prompt_learning_tasks(self, test_name, model_id, config_cls, config_kwargs): + self._test_training_prompt_learning_tasks(model_id, config_cls, config_kwargs) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "ia3_kwargs": {"init_ia3_weights": [False]}, + "adalora_kwargs": {"init_lora_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, + "fourierft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + filter_params_func=filter_supported_methods_supporting_merging, + ) + ) + def test_disable_adapter(self, test_name, model_id, 
config_cls, config_kwargs): + self._test_disable_adapter(model_id, config_cls, config_kwargs) + + @parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID)) + def test_passing_input_embeds_works(self, test_name, model_id, config_cls, config_kwargs): + self._test_passing_input_embeds_works(test_name, model_id, config_cls, config_kwargs) + + # TODO: enable if/when deepcopy-ing is supported + @pytest.mark.skip("Quanto does not work (yet) with deepcopy-ing") + def test_lora_layer_replication(self): + model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM" + config_kwargs = { + "target_modules": ["down_proj", "up_proj"], + "task_type": "CAUSAL_LM", + "lora_dropout": 0.0, + "layer_replication": [[0, 1], [0, 2], [1, 2]], + } + model = self.transformers_class.from_pretrained(model_id).to(self.torch_device) + config = LoraConfig( + base_model_name_or_path=model_id, + **config_kwargs, + ) + assert len(model.model.layers) == 2, "Expected 2 layers in original model." + model = get_peft_model(model, config) + layers = model.base_model.model.model.layers + assert len(layers) == 4, "Expected 4 layers in adapted model." + assert ( + layers[0].mlp.up_proj.base_layer.weight.data.storage().data_ptr() + == layers[1].mlp.up_proj.base_layer.weight.data.storage().data_ptr() + and layers[2].mlp.up_proj.base_layer.weight.data.storage().data_ptr() + == layers[3].mlp.up_proj.base_layer.weight.data.storage().data_ptr() + ), "Expected layers 0-1 and 2-3 to share weights" + assert ( + layers[0].mlp.up_proj.base_layer.weight.data.storage().data_ptr() + != layers[2].mlp.up_proj.base_layer.weight.data.storage().data_ptr() + ), "Expected layers 0 and 2 to have different weights" + assert ( + layers[0].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr() + != layers[1].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr() + and layers[2].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr() + != layers[3].mlp.up_proj.lora_A.default.weight.data.storage().data_ptr() + ), "Expected all LoRA adapters to have distinct weights" + assert len([n for n, _ in model.named_parameters() if ".lora_A." in n]) == 8, ( + "Expected 8 LoRA adapters since we are adding one each for up and down." 
+ ) + self._test_prepare_for_training(model_id, LoraConfig, config_kwargs) + self._test_generate(model_id, LoraConfig, config_kwargs) + + def test_prompt_learning_with_grouped_query_attention(self): + # See 1901, fixes a bug with handling GQA + model_id = "peft-internal-testing/tiny-dummy-qwen2" + base_model = AutoModelForCausalLM.from_pretrained(model_id) + peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM") + model = get_peft_model(base_model, peft_config) + x = torch.tensor([[1, 2, 3]]) + # does not raise + model(x) + + @parameterized.expand( + PeftTestConfigManager.get_grid_parameters( + { + "model_ids": PEFT_DECODER_MODELS_TO_TEST, + "lora_kwargs": {"init_lora_weights": [False]}, + "adalora_kwargs": {"init_lora_weights": [False]}, + "ia3_kwargs": {"init_ia3_weights": [False]}, + "boft_kwargs": {"init_weights": [False]}, + "vera_kwargs": {"init_weights": [False]}, + "fourierft_kwargs": {"init_weights": [False]}, + "hra_kwargs": {"init_weights": [False]}, + "task_type": "CAUSAL_LM", + }, + filter_params_func=filter_supported_methods_supporting_merging, + ) + ) + def test_quanto_merge_conv2d(self, test_name, model_id, config_cls, config_kwargs): + torch.manual_seed(0) + + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + if config.is_prompt_learning: + pytest.skip("Prompt learning models do not support merging.") + + config.target_modules = {"seq.0", "seq.2", "seq.4"} + config.task_type = None + + class ModelConv2D(nn.Module): + def __init__(self): + super().__init__() + self.seq = nn.Sequential( + nn.Conv2d(3, 8, 3), + nn.ReLU(), + nn.Conv2d(8, 8, 3), + nn.ReLU(), + nn.Conv2d(8, 8, 3), + nn.ReLU(), + nn.Flatten(), + nn.Linear(800, 64), + ) + + def forward(self, X): + return self.seq(X) + + model = ModelConv2D() + model = get_peft_model(model, config) + model = model.to(self.torch_device) + + dummy_input = torch.randn(5, 3, 16, 16).to(self.torch_device) + model.eval() + logits = model(dummy_input)[0] + + model.merge_adapter() + logits_merged = model(dummy_input)[0] + model.unmerge_adapter() + logits_unmerged = model(dummy_input)[0] + + model = model.merge_and_unload() + logits_merged_unloaded = model(dummy_input)[0] + + self.check_tensors_approximately_equal(logits, logits_merged, logits_unmerged, logits_merged_unloaded) + + +class PeftQuanto2bitModelTester(unittest.TestCase, PeftCommonTester, BasePeftQuantoModelTester): + weights = "int2" + transformers_class = make_automodel_proxy(weights=weights) + min_correlation = 0.9 + tol = 0.3 + + +class PeftQuanto4bitModelTester(unittest.TestCase, PeftCommonTester, BasePeftQuantoModelTester): + weights = "int4" + transformers_class = make_automodel_proxy(weights=weights) + min_correlation = 0.95 + tol = 1e-2 + + +class PeftQuanto8bitModelTester(unittest.TestCase, PeftCommonTester, BasePeftQuantoModelTester): + weights = "int8" + transformers_class = make_automodel_proxy(weights=weights) + min_correlation = 0.95 + tol = 1e-2 diff --git a/tests/testing_common.py b/tests/testing_common.py index 9e3fbdc667..9d5b9b9a63 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -20,6 +20,7 @@ import tempfile import warnings from collections import OrderedDict +from contextlib import nullcontext from dataclasses import replace import pytest @@ -713,6 +714,10 @@ def _test_merge_layers(self, model_id, config_cls, config_kwargs): if (config.peft_type in {"IA3", "LORA"}) and (model_id in conv_ids): # for some reason, the Conv introduces a larger error atol, rtol = 0.3, 0.01 + if 
quant_method := getattr(model, "quantization_method", None): + if quant_method.value == "quanto": + atol, rtol = 5e-3, 5e-3 + assert torch.allclose(logits, logits_merged, atol=atol, rtol=rtol) assert torch.allclose(logits, logits_unmerged, atol=atol, rtol=rtol) assert torch.allclose(logits, logits_merged_unloaded, atol=atol, rtol=rtol) @@ -886,6 +891,8 @@ def _test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs): if config_cls not in (LoraConfig,): return pytest.skip(f"Mixed adapter batches not supported for {config_cls}") + from transformers.quantizers.quantizer_quanto import QuantoHfQuantizer + config = config_cls( base_model_name_or_path=model_id, **config_kwargs, @@ -901,18 +908,27 @@ def _test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs): # ensure that we have at least 3 samples for this test dummy_input = {k: torch.cat([v for _ in range(3)]) for k, v in dummy_input.items()} - with torch.inference_mode(): + # Using quanto with inference model raises an error: + # > RuntimeError: Cannot set version_counter for inference tensor + # https://github.com/huggingface/optimum-quanto/issues/304 + # TODO: remove when/if this is fixed + if isinstance(getattr(model, "hf_quantizer", None), QuantoHfQuantizer): + inference_mode = nullcontext + else: + inference_mode = torch.inference_mode + + with inference_mode(): with model.disable_adapter(): output_base = model(**dummy_input)[0] logits_base = model.generate(**dummy_input, return_dict_in_generate=True, output_scores=True).scores[0] model.set_adapter("adapter0") - with torch.inference_mode(): + with inference_mode(): output_adapter0 = model(**dummy_input)[0] logits_adapter0 = model.generate(**dummy_input, return_dict_in_generate=True, output_scores=True).scores[0] model.set_adapter("adapter1") - with torch.inference_mode(): + with inference_mode(): output_adapter1 = model(**dummy_input)[0] logits_adapter1 = model.generate(**dummy_input, return_dict_in_generate=True, output_scores=True).scores[0] @@ -931,7 +947,7 @@ def _test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs): adapters = ["__base__", "adapter0", "adapter1"] dummy_input["adapter_names"] = [adapters[i % 3] for i in (range(len(dummy_input["input_ids"])))] - with torch.inference_mode(): + with inference_mode(): output_mixed = model(**dummy_input)[0] logits_mixed = model.generate(**dummy_input, return_dict_in_generate=True, output_scores=True).scores[0]