-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FEAT] Add support for optimum-quanto #2000
base: main
Are you sure you want to change the base?
[FEAT] Add support for optimum-quanto #2000
Conversation
This is unfinished, only pure implementations are provided. TODOs: - [ ] Documentation - [ ] Tests (should work on CPU!) - [ ] Whether Conv2d works is not verified yet - [ ] Optional: DoRA support - [ ] Optional: Mixed adapter batches support
This is what I used for "testing" so far and the results look correct: import torch
from peft import LoraConfig, set_peft_model_state_dict, get_peft_model
from optimum.quanto import quantize, freeze, qint8
from transformers import AutoModelForCausalLM
torch.manual_seed(0)
inputs = torch.arange(5).view(-1, 1)
print("loading model")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").eval()
with torch.inference_mode():
output_base = model(inputs).logits
print("output_base")
print(output_base[0, 0, :5])
# Step 3: Quantize the Model
print("quantizing model")
quantize(model, weights=qint8)
print("freezing model")
freeze(model)
with torch.inference_mode():
output_quantized = model(inputs).logits
print("output_quantized")
print(output_quantized[0, 0, :5])
config = LoraConfig(r=64, lora_alpha=1280, lora_dropout=0.1, init_lora_weights=False)
print("adding adapter (random)")
model = get_peft_model(model, config)
model.eval()
with torch.inference_mode():
output_lora = model(inputs).logits
print("output_lora")
print(output_lora[0, 0, :5])
with model.disable_adapter():
output_disabled = model(inputs).logits
print("output_disabled")
print(output_disabled[0, 0, :5])
output_after_disabled = model(inputs).logits
print("output_after_disabled")
print(output_after_disabled[0, 0, :5])
model.merge_adapter()
with torch.inference_mode():
output_merged = model(inputs).logits
print("output_merged")
print(output_merged[0, 0, :5])
model.unmerge_adapter()
with torch.inference_mode():
output_unmerged = model(inputs).logits
print("output_unmerged")
print(output_unmerged[0, 0, :5])
unloaded = model.merge_and_unload()
with torch.inference_mode():
output_unloaded = unloaded(inputs).logits
print("output_unloaded")
print(output_unloaded[0, 0, :5]) If someone wants to test this, they can checkout this branch or they can copy-paste the layer definitions and then dynamically dispatch to the new layers using the normal PEFT release: from optimum.quanto import QConv2d, QLinear
# copy code for QuantoLoraLinear and QuantoLoraConv2d
custom_module_mapping = {QConv2d: QuantoLoraConv2d, QLinear: QuantoLoraLinear}
config = LoraConfig(...)
config._register_custom_module(custom_module_mapping) |
i pulled the new peft build from your branch and applied the mapping to the LoraConfig, but I still see this error when it comes time to loading the state dict. I think the problem is on |
Other methods would have to explicitly support quanto for these tests to pass.
float8 produces nans, so not used right now.
Status update: Optimum-quanto v0.2.5 is released and is the minimum version for this to work. Moreover, huggingface/transformers#31732 is merged but it's not part of the latest transformers release yet. As we don't want to depend on an unreleased transformers version and as we're not in a huge hurry, let's wait for the next transformers release. |
There is an issue with quanto not working with torch.inference_mode that the test needs to work around.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM ! Sorry for the delay !
Failing Windows CI is caused by a known issue in quanto that's already been fixed, but there was no release yet. |
This is unfinished, only pure implementations are provided.
Resolves #1997
TODOs:
QuantoLoraConv2d
works_data
and_scales
, overriding.data
did not have any effect.State of unit tests
Since quanto layers are subclasses of their respective torch equivalents, they will generally work with PEFT methods, even if not supported explicitly. E.g. BOFT will "just work" with quanto. However, some merging etc. won't work properly, as this requires special handling for quanto. Therefore, these tests are skipped.
It could be argued that we should explicitly raise when trying to use a non-supported method with quanto. However, we don't do that in general, as we assume that a subclass relationship should mean that the method works with that module. We could do strict checking of type (not subclass), but who knows how much existing code would break for no reason because of that.
Merging tests had to be relaxed,
torch.allclose
would require quite a high tolerance to pass. Therefore, instead now measure that correlation is > 0.97, which is more robust to outliers.Moreover, a bunch of tests needed to be skipped, e.g. because quanto does not support deepcopy-ing, and either the PEFT functionaliy (layer replication) or the test itself depends on copying. Also, quanto does not allow to convert the dtype (like calling
model.half()
).