Commit
add max_input
SunMarc committed Aug 31, 2023
1 parent 234a427 commit 539c4be
Showing 3 changed files with 52 additions and 10 deletions.
optimum/gptq/quantizer.py (21 changes: 17 additions & 4 deletions)
@@ -42,6 +42,7 @@

if is_auto_gptq_available():
from auto_gptq.modeling._utils import autogptq_post_init
from auto_gptq import exllama_set_max_input_length
from auto_gptq.quantization import GPTQ
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

@@ -69,6 +70,7 @@ def __init__(
batch_size: int = 1,
pad_token_id: Optional[int] = None,
disable_exllama: bool = False,
max_input_length: Optional[int] = None,
*args,
**kwargs,
):
@@ -107,6 +109,9 @@ def __init__(
The pad token id. Needed to prepare the dataset when `batch_size` > 1.
disable_exllama (`bool`, defaults to `False`):
Whether to use exllama backend. Only works with `bits` = 4.
max_input_length (`Optional[int]`, defaults to `None`):
The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input length.
It is specific to the exllama backend with act-order.
"""

self.bits = bits
@@ -123,6 +128,7 @@ def __init__(
self.batch_size = batch_size
self.pad_token_id = pad_token_id
self.disable_exllama = disable_exllama
self.max_input_length = max_input_length

if self.bits not in [2, 3, 4, 8]:
raise ValueError("only support quantize to [2,3,4,8] bits.")
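The constructor simply stores the new value; it is consumed later in post_init_model. A minimal sketch of how a caller might pass the argument, assuming the usual bits/dataset arguments of GPTQQuantizer (only disable_exllama and max_input_length appear in this hunk):

from optimum.gptq import GPTQQuantizer

quantizer = GPTQQuantizer(
    bits=4,                 # exllama only works with 4-bit quantization
    dataset="c4",           # assumed calibration-dataset argument, not shown in this diff
    disable_exllama=False,  # keep the exllama backend enabled
    max_input_length=4096,  # pre-size the exllama buffer for long prompts
)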
@@ -470,8 +476,11 @@ def post_init_model(self, model):
"Found modules on cpu/disk. Using Exllama backend requires all the modules to be on GPU."
"You can deactivate exllama backend by setting `disable_exllama=True` in the quantization config object"
)

return autogptq_post_init(model, use_act_order=self.desc_act)
model.quantize_config.desc_act = self.desc_act
model = autogptq_post_init(model, use_act_order=self.desc_act)
if self.desc_act and not self.disable_exllama:
model = exllama_set_max_input_length(model,self.max_input_length)
return model

def pack_model(
self,
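With act-order (desc_act) and the exllama backend both enabled, post_init_model now resizes the exllama buffers to max_input_length. As written, the call is made even when no length was supplied; a slightly more defensive variant of the same tail of the method, shown only as an illustration and not as the committed code:

model = autogptq_post_init(model, use_act_order=self.desc_act)
# Resize the exllama buffers only when act-order is on, exllama is enabled,
# and the caller actually supplied a maximum input length.
if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
    model = exllama_set_max_input_length(model, self.max_input_length)
return model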
@@ -540,7 +549,6 @@ def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", sa
)

os.makedirs(save_dir, exist_ok=True)
model = model.to("cpu")
# save model and config
accelerator = Accelerator()
accelerator.save_model(model, save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
@@ -560,6 +568,7 @@ def load_quantized_model(
offload_buffers: Optional[str] = None,
offload_state_dict: bool = False,
disable_exllama: bool = False,
max_input_length: Optional[int] = None,
):
"""
Load quantized weights from the save_folder into the converted model and dispatch the weights according to the device_map.
@@ -593,7 +602,10 @@ def load_quantized_model(
the weight of the CPU state dict + the biggest shard does not fit. Will default to `True` if the device map
picked contains `"disk"` values.
disable_exllama (`bool`, defaults to `False`):
Whether to use exllama backend. Only works with `bits` = 4.
Whether to use exllama backend. Only works with `bits` = 4.
max_input_length (`Optional[int]`, defaults to `None`):
The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input length.
It is specific to the exllama backend with act-order.
Returns:
`nn.Module`: The quantized model
@@ -615,6 +627,7 @@ def load_quantized_model(
quantize_config_dict = json.load(f)
quantizer = GPTQQuantizer.from_dict(quantize_config_dict)
quantizer.disable_exllama = disable_exllama
quantizer.max_input_length = max_input_length

model = quantizer.convert_model(model)
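On the loading side, the new keyword is simply forwarded to the quantizer before the weights are dispatched. A hedged end-to-end sketch of the call path, mirroring the test further below; save_dir is a placeholder for a folder produced by GPTQQuantizer.save together with the saved model config:

import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from optimum.gptq import load_quantized_model

save_dir = "path/to/quantized-model"  # placeholder, not a real checkpoint
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_config(
        AutoConfig.from_pretrained(save_dir), torch_dtype=torch.float16
    )
empty_model.tie_weights()
quantized_model = load_quantized_model(
    empty_model,
    save_folder=save_dir,
    device_map={"": 0},
    disable_exllama=False,
    max_input_length=4096,  # sized for prompts longer than 2048 tokens (see the test below)
)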

optimum/utils/import_utils.py (11 changes: 9 additions & 2 deletions)
@@ -35,6 +35,7 @@
TORCH_MINIMUM_VERSION = packaging.version.parse("1.11.0")
TRANSFORMERS_MINIMUM_VERSION = packaging.version.parse("4.25.0")
DIFFUSERS_MINIMUM_VERSION = packaging.version.parse("0.18.0")
AUTOGPTQ_MINIMUM_VERSION = packaging.version.parse("0.4.2")


# This is the minimal required version to support some ONNX Runtime features
@@ -102,8 +103,14 @@ def is_diffusers_available():


def is_auto_gptq_available():
return _auto_gptq_available

if _auto_gptq_available:
version_autogptq = packaging.version.parse(importlib_metadata.version('auto_gptq'))
if AUTOGPTQ_MINIMUM_VERSION <= version_autogptq:
return True
else:
raise ImportError(
f"Found an incompatible version of auto-gptq. Found version {version_autogptq}, but only {AUTOGPTQ_MINIMUM_VERSION} and above are supported"
)

@contextmanager
def check_if_pytorch_greater(target_version: str, message: str):
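Note that, as committed, is_auto_gptq_available falls through and implicitly returns None when auto-gptq is not installed at all. An explicit-return sketch of the same version gate, reusing the module-level imports of this file; this is an illustration, not the committed code:

def is_auto_gptq_available():
    if not _auto_gptq_available:
        # auto-gptq is not installed at all
        return False
    version_autogptq = packaging.version.parse(importlib_metadata.version("auto_gptq"))
    if version_autogptq >= AUTOGPTQ_MINIMUM_VERSION:
        return True
    raise ImportError(
        f"Found an incompatible version of auto-gptq. Found version {version_autogptq}, "
        f"but only {AUTOGPTQ_MINIMUM_VERSION} and above are supported"
    )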
tests/gptq/test_quantization.py (30 changes: 26 additions & 4 deletions)
@@ -18,7 +18,7 @@

import torch
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from transformers.testing_utils import slow

from optimum.gptq import GPTQQuantizer, load_quantized_model
Expand All @@ -35,8 +35,11 @@ class GPTQTest(unittest.TestCase):
input_text = "Hello my name is"
EXPECTED_OUTPUTS = set()
EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I")
EXPECTED_OUTPUTS.add("Hello my name is John, I am a professional photographer and I")
EXPECTED_OUTPUTS.add("Hello my name is jay and i am a student at university.")
EXPECTED_OUTPUTS.add("Hello my name is John, I am a student in the University of")
EXPECTED_OUTPUTS.add("Hello my name is John and I am a very good looking man.")

# this seems a little small considering that we are doing 4bit quant but we have a small model and we don't quantize the embeddings
EXPECTED_RELATIVE_DIFFERENCE = 1.664253062

@@ -126,15 +129,34 @@ def test_serialization(self):
self.quantizer.save(self.quantized_model, tmpdirname)
self.quantized_model.config.save_pretrained(tmpdirname)
with init_empty_weights():
empty_model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=torch.float16)
empty_model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(self.model_name), torch_dtype=torch.float16)
empty_model.tie_weights()
quantized_model_from_saved = load_quantized_model(empty_model, save_folder=tmpdirname, device_map={"": 0})
self.check_inference_correctness(quantized_model_from_saved)


class GPTQTestExllama(GPTQTest):
disable_exllama = False
desc_act=True
def test_exllama_max_input_length(self):
from accelerate import init_empty_weights
max_input_length = 4028

with tempfile.TemporaryDirectory() as tmpdirname:
self.quantizer.save(self.quantized_model, tmpdirname)
self.quantized_model.config.save_pretrained(tmpdirname)
with init_empty_weights():
empty_model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(self.model_name), torch_dtype=torch.float16)
empty_model.tie_weights()
quantized_model_from_saved = load_quantized_model(empty_model, save_folder=tmpdirname, device_map={"": 0})

prompt = "I am in Paris and" * 450

inp = self.tokenizer(prompt, return_tensors="pt").to(0)
self.assertTrue(inp["input_ids"].shape[1] > 2048)

with self.assertRaises(RuntimeError) as cm:
res = quantized_model_from_saved.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)
self.assertTrue("temp_state buffer is too small" in str(cm.exception))

class GPTQUtilsTest(unittest.TestCase):
"""
