Commit 7cfff96

qwen2_bugfix, add adamround vision UT (#281)
Signed-off-by: Zhang, Weiwei1 <[email protected]>
1 parent afa9e26 commit 7cfff96

20 files changed (+97 −18 lines)

auto_round/autoround.py

Lines changed: 4 additions & 3 deletions

```diff
@@ -720,8 +720,8 @@ def forward(m, hidden_states, *positional_args, **kwargs):
                     self.inputs[name][key].extend(list(torch.split(alibi.to("cpu"), 1, dim=0)))
                 else:
                     self.inputs[name][key] = list(torch.split(alibi.to("cpu"), 1, dim=0))
-            elif "position_ids" in key or 'cache_position' in key:
-                if self.train_bs == 1 and self.not_share_rotary_pos_emb_flag:
+            elif "position_ids" in key or 'cache_position' in key or 'position_embeddings' in key:
+                if self.train_bs == 1 and self.not_share_position_ids_flag:
                     if key not in self.inputs[name].keys():
                         self.inputs[name][key] = [to_device(kwargs[key], device=torch.device("cpu"))]
                     else:
@@ -1104,7 +1104,7 @@ def quant_blocks(
                 input_others[key] = input_others[key].to(tmp_dtype)
             elif isinstance(input_others[key], list):
                 for i in range(len(input_others[key])):
-                    input_others[key][i].to(tmp_dtype)
+                    to_dtype(input_others[key][i], tmp_dtype)
         pbar = tqdm(range(0, len(block_names), nblocks))
         for i in pbar:
             if nblocks == 1:
@@ -1621,3 +1621,4 @@ def __init__(
         )
 
 
+
```

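The `quant_blocks` hunk above fixes a silent no-op: `Tensor.to()` returns a new tensor rather than converting in place, so the old loop left every list entry at its original dtype. The commit routes each element through auto_round's `to_dtype` helper; the sketch below only illustrates the underlying pitfall with plain tensors and is not the project's code.

```python
import torch

batch = [torch.zeros(2, 4), torch.zeros(2, 4)]  # float32 by default

# Buggy pattern: the converted tensor returned by .to() is discarded,
# so the list still holds float32 tensors afterwards.
for i in range(len(batch)):
    batch[i].to(torch.bfloat16)
print(batch[0].dtype)  # torch.float32

# One straightforward fix for plain tensors: keep the converted tensor.
for i in range(len(batch)):
    batch[i] = batch[i].to(torch.bfloat16)
print(batch[0].dtype)  # torch.bfloat16
```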
auto_round/export/export_to_autogptq/export.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -113,13 +113,16 @@ def save_quantized_as_autogptq(output_dir, inplace=True, backend="auto_gptq:exll
     """Export the model to autogptq format to easily leverage cuda kernel."""
 
     model = kwargs["model"]
-    tokenizer = kwargs["tokenizer"]
     supported_types = kwargs["supported_types"]
     safe_serialization = True if 'safe_serialization' not in kwargs.keys() else kwargs["safe_serialization"]
     quant_block_list = kwargs["quant_block_list"]
     logger.info("Saving quantized model to autogptq format, this may take a while...")
+    tokenizer = kwargs.get("tokenizer", None)
+    processor = kwargs.get("processor", None)
     if tokenizer is not None:
         tokenizer.save_pretrained(output_dir)
+    if processor is not None:
+        processor.save_pretrained(output_dir)
     ##check module quantized in block, this may have bug for mixed precision quantization
     if bool(quant_block_list):
         all_blocks = quant_block_list
@@ -200,3 +203,4 @@ def save(model: torch.nn.Module, save_dir: str, max_shard_size: str = "5GB", saf
     if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
         with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
             json.dump(model.config.quantization_config, f, indent=2)
+
```

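All of the exporters touched by this commit follow the same pattern: `tokenizer` and `processor` are now read with `kwargs.get(..., None)` instead of a hard `kwargs["tokenizer"]` lookup, so text-only callers keep working while vision-language callers can pass a processor to be stored next to the quantized weights. A minimal standalone sketch of that pattern (`_save_companions` is a hypothetical name used only for illustration):

```python
def _save_companions(output_dir, **kwargs):
    # Hypothetical helper for illustration; mirrors the kwargs.get() pattern
    # added to each exporter in this commit.
    tokenizer = kwargs.get("tokenizer", None)   # optional: may be absent
    processor = kwargs.get("processor", None)   # optional: set for vision-language models
    if tokenizer is not None:
        tokenizer.save_pretrained(output_dir)
    if processor is not None:
        processor.save_pretrained(output_dir)
```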
auto_round/export/export_to_autoround/export.py

Lines changed: 6 additions & 1 deletion

```diff
@@ -198,6 +198,8 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
     layer_config = kwargs["layer_config"]
     quantization_config = kwargs["serialization_dict"]
     quantization_config["quant_method"] = "intel/auto-round"
+    tokenizer = kwargs.get("tokenizer", None)
+    processor = kwargs.get("processor", None)
     if "awq" not in backend:
         quantization_config["backend"] = backend
     extra_config = {}
@@ -235,12 +237,14 @@ def wrapper(name):
     model.config.quantization_config = quantization_config
     if output_dir is None:
         return model
-    tokenizer = kwargs["tokenizer"]
+
     if output_dir is None:
         model.tokenizer = tokenizer
         return model
     if tokenizer is not None:
         tokenizer.save_pretrained(output_dir)
+    if processor is not None:
+        processor.save_pretrained(output_dir)
     modules_to_not_convert = []
     if "awq" not in backend:
         save(model, output_dir, safe_serialization=safe_serialization)
@@ -317,3 +321,4 @@ def save_awq(
     if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
         with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
             json.dump(quantization_config, f, indent=2)
+
```

auto_round/export/export_to_awq/export.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -91,11 +91,14 @@ def save_quantized_as_autoawq(output_dir, inplace=True, **kwargs):
     enable_minmax_tuning = kwargs["enable_minmax_tuning"]
     enable_quanted_input = kwargs["enable_quanted_input"]
     scale_dtype = kwargs["scale_dtype"]
-    tokenizer = kwargs["tokenizer"]
+    tokenizer = kwargs.get("tokenizer", None)
+    processor = kwargs.get("processor", None)
 
     logger.info("Saving quantized model to auto_awq format")
     if tokenizer is not None:
         tokenizer.save_pretrained(output_dir)
+    if processor is not None:
+        processor.save_pretrained(output_dir)
     ##check module quantized in block, this may have bug for mixed precision quantization
     modules_to_not_convert = []
     if inplace:
@@ -250,3 +253,4 @@ def get_module_name(model, module_to_find):
         if module is module_to_find:
             return name
     return None
+
```

auto_round/export/export_to_itrex/export.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -121,7 +121,8 @@ def save_quantized_as_itrex_xpu(output_dir, inplace=True, **kwargs):
     enable_minmax_tuning = kwargs["enable_minmax_tuning"]
     enable_quanted_input = kwargs["enable_quanted_input"]
     scale_dtype = kwargs["scale_dtype"]
-    tokenizer = kwargs["tokenizer"]
+    tokenizer = kwargs.get("tokenizer", None)
+    processor = kwargs.get("processor", None)
 
     compressed_model = pack_model(inplace=inplace, **kwargs)
     if output_dir is None:
@@ -149,6 +150,8 @@ def save_quantized_as_itrex_xpu(output_dir, inplace=True, **kwargs):
         compressed_model.save_pretrained(output_dir, safe_serialization=True)
         if tokenizer is not None:
             tokenizer.save_pretrained(output_dir)
+        if processor is not None:
+            processor.save_pretrained(output_dir)
         logger.info("Saved config file and weights of quantized model to {}.".format(output_dir))
     except IOError as e:  # pragma: no cover
         logger.error("Fail to save configure file and weights due to {}.".format(e))
@@ -252,3 +255,4 @@ def pack_model(
     return compressed_model
 
 
+
```

auto_round/utils.py

Lines changed: 7 additions & 2 deletions

```diff
@@ -366,14 +366,18 @@ def sampling_inputs(input_ids, input_others, indices, seqlen,
     current_input_others = {"positional_inputs": input_others["positional_inputs"]}
     for key in input_others.keys():
         if not share_attention_mask_flag and ("attention_mask" in key or "alibi" in key) \
-                or (not_share_position_ids_flag and ("position_ids" in key or "cache_position" in key)) \
+                or (not_share_position_ids_flag and ("position_ids" in key or \
+                    "cache_position" in key or "position_embeddings" in key)) \
                 or (not_share_rotary_pos_emb_flag and ("rotary_pos_emb" in key or 'cu_seqlens' in key)) \
                 or "cross_attention_states" in key:
             current_input_others[key] = None
             if input_others[key] is not None:
                 current_input_others[key] = [input_others[key][i] for i in indices]
                 if not isinstance(current_input_others[key], torch.Tensor):
-                    current_input_others[key] = torch.cat(current_input_others[key], dim=0)
+                    if len(current_input_others[key]) == 1:
+                        current_input_others[key] = current_input_others[key][0]
+                    else:
+                        current_input_others[key] = torch.cat(current_input_others[key], dim=0)
         else:
             current_input_others[key] = input_others[key]
 
@@ -973,3 +977,4 @@ def get_autogptq_packing_qlinear(backend, bits=4, group_size=128, sym=False):
     return QuantLinear
 
 
+
```

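The `sampling_inputs` hunk does two things: `position_embeddings` joins the keys that are cached per sample, and a cached list with a single entry is now returned as-is rather than concatenated, presumably because such an entry need not be a plain tensor (recent models can cache rotary `position_embeddings` as a (cos, sin) tuple, which `torch.cat` cannot handle). A hedged sketch of the new selection behaviour (`gather_per_sample` is a hypothetical helper, not part of auto_round):

```python
import torch

def gather_per_sample(values, indices):
    # Hypothetical illustration of the updated selection logic: pick the cached
    # per-sample entries for this mini-batch, return a lone entry unchanged
    # (it may be a tuple such as (cos, sin)), and concatenate otherwise.
    picked = [values[i] for i in indices]
    if len(picked) == 1:
        return picked[0]
    return torch.cat(picked, dim=0)
```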
examples/multimodal-modeling/Qwen-VL/README.md renamed to examples/multimodal-modeling/Common_model/README.md

Lines changed: 32 additions & 0 deletions

````diff
@@ -193,6 +193,37 @@ print(output_text)
 ```
 
 
+- Llama-3.2-11B-Vision-Instruct inference
+
+```python
+import requests
+import torch
+from PIL import Image
+from transformers import MllamaForConditionalGeneration, AutoProcessor
+from auto_round.auto_quantizer import AutoHfQuantizer
+quantized_model_path="./tmp_autoround"
+model = MllamaForConditionalGeneration.from_pretrained(
+    quantized_model_path, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
+processor = AutoProcessor.from_pretrained(quantized_model_path)
+
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
+image = Image.open(requests.get(url, stream=True).raw)
+
+prompt = "<|image|><|begin_of_text|>If I had to write a haiku for this one"
+inputs = processor(image, prompt, return_tensors="pt", truncation=True).to(model.device)
+
+output = model.generate(**inputs, max_new_tokens=30)
+print(processor.decode(output[0]))
+
+# <|begin_of_text|><|image|><|begin_of_text|>If I had to write a haiku for this one, it would be:
+
+# Rabbit in a coat
+# Dressed up in style for the day
+# Country charm abounds
+
+# The image depicts a rabbit
+```
+
 ## 4. Results
 Using [COCO 2017](https://cocodataset.org/) and [LLaVA-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) datasets for quantization calibration, and TextVQA dataset for evaluation. please follow the [recipe](./run_autoround.sh) and [evaluate script](./run_eval.sh). The results for Qwen-VL are as follows:
 | Metric | bf16 | INT4 |
@@ -241,3 +272,4 @@ If you find SignRound useful for your research, please cite our paper:
 
 
 
+
````

examples/multimodal-modeling/Qwen-VL/main.py renamed to examples/multimodal-modeling/Common_model/main.py

Lines changed: 7 additions & 6 deletions

```diff
@@ -34,7 +34,7 @@ def DataFormating(raw_data, image_folder=None, model_type='qwen'):
             sentence['value'] = sentence['value'].strip()
             if 'qwen2' in model_type:  # for Qwen2-vl
                 replace_token = '<|vision_start|><|image_pad|><|vision_end|>'
-            if 'mllama' in model_type:
+            elif 'mllama' in model_type:
                 replace_token = '<|image|>'
             else:
                 replace_img = os.path.join(image_folder, os.path.basename(source["image"]))
@@ -422,7 +422,7 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat
     model_type = config.model_type
     if "mllama" in model_type:
         from transformers import MllamaForConditionalGeneration
-        model = MllamaForConditionalGeneration.from_pretrained(args.model_name,
+        model = MllamaForConditionalGeneration.from_pretrained(args.model_name, attn_implementation="eager",
                                                                trust_remote_code=not args.disable_trust_remote_code)  # torch_dtype=torch.bfloat16
         processor = AutoProcessor.from_pretrained(args.model_name)
         tokenizer.processor = processor
@@ -534,17 +534,17 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat
         for gpu_format in gpu_formats:
             if "round" in gpu_format:
                 eval_folder = f'{export_dir}-round'
-                autoround.save_quantized(eval_folder, format=gpu_format, use_triton=False, inplace=inplace)
+                autoround.save_quantized(eval_folder, format=gpu_format, use_triton=False, inplace=inplace, processor=processor)
             elif "gptq" in gpu_format:
                 eval_folder = f'{export_dir}-gpu'
-                autoround.save_quantized(eval_folder, format=gpu_format, use_triton=False, inplace=inplace)
+                autoround.save_quantized(eval_folder, format=gpu_format, use_triton=False, inplace=inplace, processor=processor)
 
     if 'xpu' in deployment_device:
         autoround.save_quantized(f'{export_dir}-xpu', format="itrex_xpu", use_triton=True, inplace=inplace,
                                  compression_dtype=torch.int8, compression_dim=0, use_optimum_format=False,
-                                 device="xpu")
+                                 device="xpu", processor=processor)
     if "cpu" in deployment_device:
-        autoround.save_quantized(output_dir=f'{export_dir}-cpu', format='itrex', inplace=inplace)
+        autoround.save_quantized(output_dir=f'{export_dir}-cpu', format='itrex', inplace=inplace, processor=processor)
     if "fake" in deployment_device:
         model = model.to("cpu")
         model.save_pretrained(output_dir)
@@ -580,3 +580,4 @@ def get_train_dataloader(train_dataset, model, data_collator=default_data_collat
 
 
 
+
```
