[FEAT] Add support for optimum-quanto #2000

Open: wants to merge 24 commits into main from feat-support-optimum-quanto

Commits (24), all by BenjaminBossan:

- a317cc5: [WIP][FEAT] Add support for optimum-quanto (Aug 9, 2024)
- caad385: Add some unit tests (Sep 2, 2024)
- f88edc1: Merge branch 'main' into feat-support-optimum-quanto (Sep 2, 2024)
- 44d77b4: More progress on tests, but still many fail (Sep 3, 2024)
- d334eb3: Skip merge tests that are not LoRA (Sep 4, 2024)
- 6d8b071: Add tests for int2 (Sep 4, 2024)
- c4cc6da: Add test for conv2d (Sep 4, 2024)
- c50c7c6: Add some quanto docs (Sep 4, 2024)
- 8cece29: More fixes to quanto tests, should now pass (Sep 5, 2024)
- 4b02c8a: Better transformers "emulation" (Sep 6, 2024)
- 095da1f: Merge branch 'main' into feat-support-optimum-quanto (Sep 19, 2024)
- b16b98c: Merge branch 'main' into feat-support-optimum-quanto (Oct 28, 2024)
- 573583f: Rework tests to use QuantoConfig (Oct 28, 2024)
- 252e045: Enable mixed batch inference for Linear (Oct 28, 2024)
- 2773b17: Remove obsolete comment (Oct 28, 2024)
- f240c1c: Refactor merging to make tests pass (Oct 29, 2024)
- 63e5cdb: Optimum-quanto import check and install for CI (Oct 29, 2024)
- 3862b41: Fix import check (Oct 29, 2024)
- 85d096f: Apply test filter where appropriate (Oct 29, 2024)
- 1538cac: Skip MacOS, comment a segfaulting test (Oct 29, 2024)
- 46ee134: Merge branch 'main' into feat-support-optimum-quanto (Jan 10, 2025)
- c86cee0: Some fixes for quanto + hf_device_map (Jan 10, 2025)
- d84a2c0: Merge branch 'main' into feat-support-optimum-quanto (Jan 10, 2025)
- 56a3889: Merge branch 'main' into feat-support-optimum-quanto (Feb 3, 2025)
21 changes: 21 additions & 0 deletions docs/source/developer_guides/quantization.md
@@ -237,6 +237,27 @@ model = get_peft_model(base_model, peft_config)
- DoRA only works with `quant_type = "int8_weight_only"` at the moment.
- There is explicit support for torchao when used with LoRA. However, when torchao quantizes a layer, its class does not change, only the type of the underlying tensor. For this reason, PEFT methods other than LoRA will generally also work with torchao, even if not explicitly supported. Be aware, however, that **merging only works correctly with LoRA and with `quant_type = "int8_weight_only"`**. If you use a different PEFT method or dtype, merging will likely result in an error, and even if it doesn't, the results will still be incorrect.

## Optimum-quanto

PEFT supports models quantized with [optimum-quanto](https://github.com/huggingface/optimum-quanto). This has been tested with 2bit, 4bit, and 8bit int quantization. Optimum-quanto also works on CPU and MPS.

```python
from transformers import AutoModelForCausalLM, QuantoConfig
from peft import LoraConfig, get_peft_model

model_id = ...
quantization_config = QuantoConfig(weights="int4")  # or "int2" or "int8"
base_model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
peft_config = LoraConfig(...)
model = get_peft_model(base_model, peft_config)
```

### Caveats:

- Use optimum-quanto v0.2.5 or above, otherwise saving and loading won't work properly.
- If you want to use optimum-quanto via transformers, install transformers v4.46.0 or above.
- Float8 is discouraged as it can easily produce NaNs.
- There is explicit support for optimum-quanto when used with LoRA. However, when optimum-quanto quantizes a layer, it remains a subclass of the corresponding torch class (e.g., quanto's `QLinear` is a subclass of `nn.Linear`). For this reason, non-LoRA methods will generally also work with optimum-quanto, even if not explicitly supported. Be aware, however, that **merging only works correctly with LoRA**. If you use a method other than LoRA, merging may not raise an error but the results will be incorrect.
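
For LoRA, merging uses the standard PEFT API. A minimal sketch, assuming `model` is the LoRA-wrapped, quanto-quantized model from the example above and has already been trained:

```python
# Merge the trained LoRA weights back into the quantized base model and
# return the unwrapped model. Per the caveat above, this is only expected
# to produce correct results for LoRA adapters.
merged_model = model.merge_and_unload()
```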

## Other Supported PEFT Methods

Besides LoRA, the following PEFT methods also support quantization:
1 change: 1 addition & 0 deletions setup.py
@@ -38,6 +38,7 @@
"scipy",
"protobuf",
"sentencepiece",
"optimum-quanto",
]

setup(
7 changes: 7 additions & 0 deletions src/peft/import_utils.py
@@ -160,3 +160,10 @@ def is_xpu_available(check_device=False):
except RuntimeError:
return False
return hasattr(torch, "xpu") and torch.xpu.is_available()


@lru_cache
def is_quanto_available():
return (importlib.util.find_spec("optimum") is not None) and (
importlib.util.find_spec("optimum.quanto") is not None
)
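
As an illustration (not code from this PR), an availability check like this is typically consumed to guard optional imports or to skip quanto-specific tests when the library is not installed, e.g.:

```python
import pytest

from peft.import_utils import is_quanto_available

# Hypothetical test marker: skip quanto tests when optimum-quanto is missing.
requires_quanto = pytest.mark.skipif(not is_quanto_available(), reason="requires optimum-quanto")


@requires_quanto
def test_lora_on_quanto_model():
    ...
```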
2 changes: 2 additions & 0 deletions src/peft/tuners/lora/model.py
@@ -52,6 +52,7 @@
from .gptq import dispatch_gptq
from .hqq import dispatch_hqq
from .layer import Conv2d, LoraLayer, dispatch_default
from .quanto import dispatch_quanto
from .torchao import dispatch_torchao
from .tp_layer import dispatch_megatron

@@ -331,6 +332,7 @@ def dynamic_dispatch_func(target, adapter_name, lora_config, **kwargs):
dispatch_gptq,
dispatch_hqq,
dispatch_torchao,
dispatch_quanto,
dispatch_megatron,
dispatch_default,
]
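
For context, each function in this list follows the same dispatch pattern: it receives the target layer and either returns a LoRA replacement module or `None`, letting the next dispatcher try. A rough sketch of that pattern for quanto follows; it is illustrative only and not the PR's actual `dispatch_quanto` implementation:

```python
from peft.import_utils import is_quanto_available


def dispatch_quanto_sketch(target, adapter_name, lora_config, **kwargs):
    # Return a quanto-aware LoRA module if the target layer was quantized by
    # optimum-quanto, otherwise return None so the next dispatcher is tried.
    new_module = None
    if is_quanto_available():
        from optimum.quanto import QLinear

        if isinstance(target, QLinear):
            # The real dispatch_quanto (in lora/quanto.py, not shown in this
            # diff) would construct its quanto-specific LoRA layer here and
            # assign it to new_module.
            ...
    return new_module
```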