diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 5605393e6..d473e1bf3 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -58,7 +58,6 @@ from auto_round.formats import OutputFormat, get_formats from auto_round.logger import logger from auto_round.schemes import ( - SPECIAL_SCHEMES, QuantizationScheme, _handle_special_schemes, get_gguf_scheme, @@ -610,8 +609,7 @@ def _parse_and_set(scheme, kwargs): scheme = scheme.strip("'\" ") res = scheme scheme = scheme.upper() - if scheme in SPECIAL_SCHEMES: - self.layer_config = _handle_special_schemes(scheme, self.layer_config, self.model) + self.layer_config = _handle_special_schemes(scheme, self.layer_config, self.model) scheme = asdict(preset_name_to_scheme(scheme)) scheme_keys = [f.name for f in fields(QuantizationScheme)] for key in scheme_keys: diff --git a/auto_round/export/export_to_gguf/config.py b/auto_round/export/export_to_gguf/config.py index d4e53ed5b..56f2af953 100644 --- a/auto_round/export/export_to_gguf/config.py +++ b/auto_round/export/export_to_gguf/config.py @@ -193,6 +193,7 @@ class ModelType(IntEnum): # GGUF_CONFIG["gguf:fp16"]["mostly"]= "gguf:fp16" # GGUF_CONFIG["gguf:bf16"] = GGUF_INNER_CONFIG["gguf:fp16"] # GGUF_CONFIG["gguf:bf16"]["mostly"]= "gguf:bf16" +GGUF_CONFIG["gguf:q2_k_mixed"] = GGUF_INNER_CONFIG["gguf:q2_k"] QK_K = 256 diff --git a/auto_round/export/export_to_gguf/convert.py b/auto_round/export/export_to_gguf/convert.py index 1b7e0735d..ed99b1485 100644 --- a/auto_round/export/export_to_gguf/convert.py +++ b/auto_round/export/export_to_gguf/convert.py @@ -229,51 +229,45 @@ def _quant_data(cls, data_torch, data_qtype, name, modify_name, new_name, bid, d """ suffix = ".weight" device = data_torch.device if device is None else device - if suffix in name: + + if name.endswith(suffix): layer_name = name[: -len(suffix)] - module = get_module(cls.model, layer_name) - kwargs = { - "scale": None, - "zp": None, - "d_scale": None, - "d_wmin": None, - "wmin": None, - "imatrix": None, - } - if hasattr(module, "scale"): + else: + layer_name = name + module = get_module(cls.model, layer_name) + kwargs = { + "scale": None, + "zp": None, + "d_scale": None, + "d_wmin": None, + "wmin": None, + "imatrix": None, + } + # support for MOE model with cls eexperts not linear + # if hasattr(module, "scale") or ("exps" in new_name and len(data_torch.shape) == 3): + for attr in ["scale", "zp", "w_d_scale", "w_d_wmin", "w_wmin"]: + if hasattr(module, attr) and getattr(module, attr) is not None: + attr_tensor = getattr(module, attr) + if not isinstance(attr_tensor, torch.Tensor): + continue if hasattr(cls, "permute"): bs = module.weight.shape[0] - for attr in ["scale", "zp", "w_d_scale", "w_d_wmin", "w_wmin"]: - if hasattr(module, attr) and getattr(module, attr) is not None: - attr_tensor = getattr(module, attr) - if not isinstance(attr_tensor, torch.Tensor): - continue - attr_tensors_dict = dict(cls.modify_tensors(attr_tensor.reshape(bs, -1), modify_name, bid)) - attr_tensor = attr_tensors_dict[new_name] - if attr in kwargs: - kwargs[attr] = attr_tensor.to(torch.float32) - else: - kwargs[attr.replace("w_", "")] = attr_tensor.to(torch.float32) - data_torch = data_torch.to(torch.float32) + attr_tensors_dict = dict(cls.modify_tensors(attr_tensor.reshape(bs, -1), modify_name, bid)) + attr_tensor = attr_tensors_dict[new_name] + if attr in kwargs: + kwargs[attr] = attr_tensor.to(torch.float32) + else: + kwargs[attr.replace("w_", "")] = attr_tensor.to(torch.float32) + data_torch = 
data_torch.to(torch.float32) - data = ggml_quant(data_torch, data_qtype.name.lower(), device=device, **kwargs) - else: - # if data_torch.dtype ==torch.float32: - # data_qtype = gguf.GGMLQuantizationType.F32 - # else: - # data_qtype = gguf.GGMLQuantizationType.F16 - data_qtype = gguf.GGMLQuantizationType.F32 ##FP16 has issues at inference - data = data_torch.to(torch.float32).squeeze().cpu().numpy() - else: - # for Llama-4 - # if data_torch.dtype == torch.float32: - # data_qtype = gguf.GGMLQuantizationType.F32 - # else: - # data_qtype = gguf.GGMLQuantizationType.F16 - # data = data_torch.squeeze().cpu().numpy() - # data_qtype = gguf.GGMLQuantizationType.F32 - # data = data_torch.to(torch.float32).squeeze().cpu().numpy() - data = ggml_quant(data_torch, data_qtype.name.lower(), device=device) + data = ggml_quant(data_torch, data_qtype.name.lower(), device=device, **kwargs) + # else: + # # if data_torch.dtype ==torch.float32: + # # data_qtype = gguf.GGMLQuantizationType.F32 + # # else: + # # data_qtype = gguf.GGMLQuantizationType.F16 + # data_qtype = gguf.GGMLQuantizationType.F32 ##FP16 has issues at inference + # data = data_torch.to(torch.float32).squeeze().cpu().numpy() return data, data_qtype @@ -419,7 +413,9 @@ def prepare_tensors(cls): break if skip: continue - data = data_torch.squeeze() + # sync with new version of gguf + # data = data_torch.squeeze() + data = data_torch n_dims = len(data.shape) data_qtype: gguf.GGMLQuantizationType | bool = cls.tensor_force_quant(name, new_name, bid, n_dims) @@ -529,17 +525,30 @@ def prepare_tensors(cls): elif data_qtype == gguf.GGMLQuantizationType.Q6_K: data_qtype = gguf.GGMLQuantizationType.Q8_0 + from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES + + if data_qtype.name.lower() in GGML_QUANT_SIZES: + block_size, type_size = GGML_QUANT_SIZES[data_qtype.name.lower()] + if data_torch.shape[-1] % block_size != 0: + logger.warning( + f"{new_name}: Can't quantize tensor with shape {data_torch.shape} to {data_qtype.name}," + " fallback to F16" + ) + data_qtype = gguf.GGMLQuantizationType.F16 + if isinstance(data_qtype, bool) or data_qtype in [ gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16, gguf.GGMLQuantizationType.F32, ]: - data = data_torch.squeeze().cpu().numpy() + # sync with new version of gguf + # data = data_torch.squeeze().cpu().numpy() # if data ends up empty, it means data_torch was a scalar tensor -> restore - if len(data.shape) == 0: + if len(data_torch.shape) == 0: data = data_torch.numpy() try: + data = data_torch.cpu().numpy() data = gguf.quants.quantize(data, data_qtype) except gguf.QuantError as e: logger.warning("%s, %s", e, "falling back to F16") diff --git a/auto_round/export/export_to_gguf/packing.py b/auto_round/export/export_to_gguf/packing.py index 9a48cfa0d..c64066932 100644 --- a/auto_round/export/export_to_gguf/packing.py +++ b/auto_round/export/export_to_gguf/packing.py @@ -22,6 +22,7 @@ def register_qtype(name): + def register(cls): GGML_QUANT_TYPE[name] = cls return cls @@ -109,7 +110,6 @@ def ggml_quant( else: new_data = np.concatenate(results, axis=0) new_data = new_data.reshape(*shape[:-1], shape[-1] // block_size * type_size) # Check shape correctness - new_data = new_data.reshape(*shape[:-1], -1) return new_data diff --git a/auto_round/formats.py b/auto_round/formats.py index b2f172f40..15d8b34a6 100644 --- a/auto_round/formats.py +++ b/auto_round/formats.py @@ -76,8 +76,8 @@ def _check_compatibility(formats: list[str], ar: BaseCompressor): ) gguf_format_name = get_gguf_scheme(ar.scheme) 
if gguf_format_name: - if gguf_format_name.lower().endswith("mixed"): - gguf_format_name = gguf_format_name.lower().replace("_mixed", "_s") + # if gguf_format_name.lower().endswith("mixed"): + # gguf_format_name = gguf_format_name.lower().replace("_mixed", "_s") if any([f.lower() not in ["fake", gguf_format_name.lower()] for f in formats]): tmp_format_name = gguf_format_name.lower() if "fake" not in formats else f"{gguf_format_name.lower()},fake" logger.warning( @@ -98,7 +98,7 @@ def remove_duplicates(lst): seen = set() return [x for x in lst if not (x in seen or seen.add(x))] - formats = format.replace("q*_", f"q{ar.bits}_").replace(" ", "").split(",") + formats = format.lower().replace("q*_", f"q{ar.bits}_").replace(" ", "").split(",") formats = remove_duplicates(formats) # need the keep origin order formats = _check_compatibility(formats, ar) @@ -650,6 +650,7 @@ class GGUFFormat(OutputFormat): "GGUF:Q5_K_M", "GGUF:Q6_K", "GGUF:Q8_0", + "GGUF:Q2_K_MIXED", ] format_name = "gguf" @@ -658,13 +659,22 @@ def __init__(self, format: str, ar: BaseCompressor): self.gguf_args_check(ar, format, model_type=ModelType.TEXT) if ar.mllm: self.gguf_args_check(ar, format, model_type=ModelType.MMPROJ) - ar.scheme = format.upper() self.output_format = "gguf" self.backend_cls = GGUFFormat self.backend = GGUFFormat(format.split(":")[-1], ar) else: - self.output_format = f"gguf:{format}" + scheme = ar.scheme + gguf_format = f"gguf:{format.lower()}" + if format.lower().endswith("_mixed"): + from auto_round.schemes import _handle_special_schemes + + ar.layer_config = _handle_special_schemes(scheme, ar.layer_config, ar.model) + gguf_format = gguf_format.lower().replace("_mixed", "_s") + if isinstance(scheme, str) and scheme.lower() != gguf_format: + logger.warning(f"reset scheme {scheme.lower()} to {gguf_format} for gguf format export") + ar.scheme = gguf_format + self.output_format = gguf_format self.backend = None self.mllm = ar.mllm diff --git a/auto_round/schemes.py b/auto_round/schemes.py index 0ca7710e0..95078742c 100644 --- a/auto_round/schemes.py +++ b/auto_round/schemes.py @@ -295,16 +295,13 @@ def is_preset_scheme(name: str) -> bool: value.pop("lm_head", None) PRESET_SCHEMES[key.upper()] = QuantizationScheme.from_dict(value) -SPECIAL_SCHEMES = {"GGUF:Q2_K_MIXED": PRESET_SCHEMES["GGUF:Q2_K_S"]} -PRESET_SCHEMES.update(SPECIAL_SCHEMES) - def _handle_special_schemes(scheme_name: str, layer_config: dict, model: torch.nn.Module) -> dict: """handle special schemes, like GGUF:Q2_K_MIXED. Provide some special auto_round recipes. """ - if scheme_name == "GGUF:Q2_K_MIXED": + if scheme_name.lower() == "gguf:q2_k_mixed": for n, m in model.named_modules(): if n in layer_config: continue diff --git a/docs/step_by_step.md b/docs/step_by_step.md index 06e65c56a..654d55d50 100644 --- a/docs/step_by_step.md +++ b/docs/step_by_step.md @@ -149,12 +149,14 @@ adopted within the community, **only 4-bits quantization is supported**. Please **LLM-Compressor Format**: **NVFP4, MXFP4(kernel in WIP), MXFP8 are supported**. Please set `--format llm_compressor` #### Format and scheme support matrix +> Italics indicates the absence of a kernel or the presence of only an inefficient/reference kernel. 
+ |export format | supported scheme | |--------------|------------------| -|**auto_round** | W4A16, W2A16, W3A16, W8A16, MXFP4, MXFP8, NVFP4, FPW8A16, W2A16G64, W2A16G32, FP8_STATIC, BF16| +|**auto_round** | W4A16, W2A16, W3A16, W8A16, *MXFP4*, *MXFP8*, NVFP4, FPW8A16, W2A16G64, W2A16G32, *FP8_STATIC*, BF16| |**auto_awq / auto_round:auto_awq** | W4A16| |**auto_gptq / auto_round:auto_gptq / auto_round:gptqmodel**|W4A16, W2A16, W3A16, W8A16, BF16, W2A16G64, W2A16G32| -|**llm_compressor / auto_round:llm_compressor** | MXFP4, MXFP8, NVFP4, FPW8A16, FP8_STATIC | +|**llm_compressor / auto_round:llm_compressor** | *MXFP4*, *MXFP8*, NVFP4, FPW8A16, *FP8_STATIC* | |**gguf** | GGUF:Q4_0, GGUF:Q4_1, GGUF:Q5_0, GGUF:Q5_1, GGUF:Q2_K_S, GGUF:Q3_K_S, GGUF:Q3_K_M, GGUF:Q3_K_L, GGUF:Q4_K_S, GGUF:Q4_K_M, GGUF:Q5_K_S, GGUF:Q5_K_M, GGUF:Q6_K, GGUF:Q8_0 | |**fake** | all scheme| ### Hardware Compatibility diff --git a/test/helpers.py b/test/helpers.py index 6bd5ad30c..e8630e1fa 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -70,6 +70,18 @@ def slice_layers(module): if hasattr(model.config, "num_hidden_layers"): model.config.num_hidden_layers = num_layers + if hasattr(model.config, "text_config"): + n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"] + for key in n_block_keys: + if hasattr(model.config.text_config, key): + setattr(model.config.text_config, key, num_layers) + if hasattr(model.config.text_config, "layer_types"): + model.config.text_config.layer_types = model.config.text_config.layer_types[:num_layers] + if hasattr(model.config, "vision_config"): + n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"] + for key in n_block_keys: + if hasattr(model.config.vision_config, key): + setattr(model.config.vision_config, key, num_layers) if hasattr(model.config, "layer_types"): model.config.layer_types = model.config.layer_types[:num_layers] diff --git a/test/test_cpu/export/test_gguf_format.py b/test/test_cpu/export/test_gguf_format.py index 7304f22c7..feeddcd0f 100644 --- a/test/test_cpu/export/test_gguf_format.py +++ b/test/test_cpu/export/test_gguf_format.py @@ -189,19 +189,26 @@ def test_all_format(self, tiny_qwen_model_path): assert False, "cmd line test fail, please have a check" shutil.rmtree("../../tmp_autoround", ignore_errors=True) + res = os.system( + f"PYTHONPATH='AUTO_ROUND_PATH:$PYTHONPATH' {python_path} -m auto_round --model {model_name}" + f" --bs 16 --iters 0 --nsamples 1 --seqlen 16 --format gguf:q2_k_mixed" + ) + if res > 0 or res == -1: + assert False, "cmd line test fail, please have a check" + shutil.rmtree("../../tmp_autoround", ignore_errors=True) + def test_vlm_gguf(self): + from ...helpers import save_tiny_model + model_name = get_model_path("Qwen/Qwen2-VL-2B-Instruct") + tiny_model_path = save_tiny_model(model_name, "./tmp/tiny_qwen_vl_model_path", num_layers=3, is_mllm=True) from auto_round import AutoRoundMLLM - from auto_round.utils import mllm_load_model - model, processor, tokenizer, image_processor = mllm_load_model(model_name) autoround = AutoRoundMLLM( - model, - tokenizer=tokenizer, - processor=processor, - image_processor=image_processor, + tiny_model_path, iters=0, nsamples=8, + disable_opt_rtn=True, ) quantized_model_path = "./saved" autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_0") @@ -209,10 +216,11 @@ def test_vlm_gguf(self): for file_name in os.listdir(quantized_model_path): file_size = 
os.path.getsize(os.path.join(quantized_model_path, file_name)) / 1024**2 if file_name == "mmproj-model.gguf": - assert abs(file_size - 2537) < 5.0 + assert abs(file_size - 56) < 5.0 else: - assert abs(file_size - 892) < 5.0 + assert abs(file_size - 264) < 5.0 shutil.rmtree("./saved", ignore_errors=True) + shutil.rmtree(tiny_model_path, ignore_errors=True) def test_qtype_setting(self): # Qwen2.5-0.5B-Instruct no output, token_embed q6_k fallbakc to q8_0 336M diff --git a/test/test_cuda/export/test_gguf.py b/test/test_cuda/export/test_gguf.py index 4146cb938..a6089a086 100644 --- a/test/test_cuda/export/test_gguf.py +++ b/test/test_cuda/export/test_gguf.py @@ -140,46 +140,52 @@ def test_all_format(self): shutil.rmtree(self.save_dir, ignore_errors=True) @require_gguf - def test_vlm_gguf(self): - model_name = "/models/Qwen2-VL-2B-Instruct" - from auto_round import AutoRoundMLLM - from auto_round.utils import mllm_load_model + def test_special_model(self): + from ...helpers import save_tiny_model - model, processor, tokenizer, image_processor = mllm_load_model(model_name) - autoround = AutoRoundMLLM( - model, - tokenizer=tokenizer, - processor=processor, - image_processor=image_processor, - device="auto", + model_name = get_model_path("ibm-granite/granite-4.0-h-tiny") + tiny_model_path = save_tiny_model(model_name, "tiny_model_path", num_layers=2) + from auto_round import AutoRound + + autoround = AutoRound( + tiny_model_path, iters=0, + nsamples=8, + disable_opt_rtn=True, ) quantized_model_path = "./saved" autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_0") - assert "mmproj-model.gguf" in os.listdir("./saved") - file_size = os.path.getsize("./saved/Qwen2-VL-2B-Instruct-Q4_0.gguf") / 1024**2 - assert abs(file_size - 894) < 5.0 - file_size = os.path.getsize("./saved/mmproj-model.gguf") / 1024**2 - assert abs(file_size - 2580) < 5.0 + file_name = os.listdir(quantized_model_path)[0] + file_size = os.path.getsize(os.path.join(quantized_model_path, file_name)) / 1024**2 + assert abs(file_size - 307) < 5.0 shutil.rmtree("./saved", ignore_errors=True) + shutil.rmtree(tiny_model_path, ignore_errors=True) - model_name = "/models/gemma-3-12b-it" + @require_gguf + def test_vlm_gguf(self): + from ...helpers import save_tiny_model - model, processor, tokenizer, image_processor = mllm_load_model(model_name) - autoround = AutoRoundMLLM( - model, - tokenizer=tokenizer, - processor=processor, - image_processor=image_processor, + model_name = "/models/gemma-3-4b-it" + tiny_model_path = save_tiny_model(model_name, "tiny_model_path", num_layers=3, is_mllm=True) + from auto_round import AutoRound + + autoround = AutoRound( + tiny_model_path, device="auto", nsamples=32, iters=0, + disable_opt_rtn=True, ) quantized_model_path = "./saved" autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_k_m") assert "mmproj-model.gguf" in os.listdir("./saved") - file_size = os.path.getsize("./saved/gemma-3-12B-it-Q4_K_M.gguf") / 1024**2 - assert abs(file_size - 6568) < 5.0 - file_size = os.path.getsize("./saved/mmproj-model.gguf") / 1024**2 - assert abs(file_size - 1599) < 5.0 + for file in os.listdir("./saved"): + print(f"{file}: {os.path.getsize(os.path.join('./saved', file)) / 1024**2} MB") + file_size = os.path.getsize(os.path.join("./saved", file)) / 1024**2 + if "mmproj-model.gguf" in file: + assert abs(file_size - 75) < 5.0 + else: + assert abs(file_size - 690) < 5.0 + shutil.rmtree(quantized_model_path, ignore_errors=True) + shutil.rmtree(tiny_model_path, 
ignore_errors=True)
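
For reference, a minimal sketch of driving the new gguf:q2_k_mixed export path from Python. The model path and output directory are placeholders, and the call pattern mirrors the test code in this patch (AutoRound with iters=0, then quantize_and_save) rather than documenting a definitive API.

from auto_round import AutoRound

# Placeholder model path; any small HF causal LM directory should exercise the same path.
ar = AutoRound("Qwen/Qwen2.5-0.5B-Instruct", iters=0, nsamples=8)

# Per the formats.py change above, GGUFFormat maps "gguf:q2_k_mixed" onto the
# gguf:q2_k_s inner config and applies the per-layer overrides from
# _handle_special_schemes() before packing.
ar.quantize_and_save(output_dir="./saved", format="gguf:q2_k_mixed")

# Command-line equivalent, as exercised by the new CLI test above:
#   python -m auto_round --model <model_path> --iters 0 --format gguf:q2_k_mixed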