4 changes: 1 addition & 3 deletions auto_round/compressors/base.py
@@ -58,7 +58,6 @@
from auto_round.formats import OutputFormat, get_formats
from auto_round.logger import logger
from auto_round.schemes import (
SPECIAL_SCHEMES,
QuantizationScheme,
_handle_special_schemes,
get_gguf_scheme,
@@ -610,8 +609,7 @@ def _parse_and_set(scheme, kwargs):
scheme = scheme.strip("'\" ")
res = scheme
scheme = scheme.upper()
if scheme in SPECIAL_SCHEMES:
self.layer_config = _handle_special_schemes(scheme, self.layer_config, self.model)
self.layer_config = _handle_special_schemes(scheme, self.layer_config, self.model)
scheme = asdict(preset_name_to_scheme(scheme))
scheme_keys = [f.name for f in fields(QuantizationScheme)]
for key in scheme_keys:
1 change: 1 addition & 0 deletions auto_round/export/export_to_gguf/config.py
@@ -193,6 +193,7 @@ class ModelType(IntEnum):
# GGUF_CONFIG["gguf:fp16"]["mostly"]= "gguf:fp16"
# GGUF_CONFIG["gguf:bf16"] = GGUF_INNER_CONFIG["gguf:fp16"]
# GGUF_CONFIG["gguf:bf16"]["mostly"]= "gguf:bf16"
GGUF_CONFIG["gguf:q2_k_mixed"] = GGUF_INNER_CONFIG["gguf:q2_k"]


QK_K = 256
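For context, a toy sketch of the aliasing pattern the new entry relies on (the dict contents below are illustrative, not the real `gguf:q2_k` inner config): the mixed name simply points at the existing `q2_k` inner config, so lookups by either key resolve to the same settings, while the mixed-specific per-layer tweaks are applied elsewhere (see `schemes.py` below).

```python
# Toy sketch of the aliasing pattern above; keys and values are illustrative only.
GGUF_INNER_CONFIG = {"gguf:q2_k": {"mostly": "gguf:q2_k"}}  # stand-in for the real inner config
GGUF_CONFIG = {}

# The new mixed entry is a plain alias: both names share one config object.
GGUF_CONFIG["gguf:q2_k_mixed"] = GGUF_INNER_CONFIG["gguf:q2_k"]
assert GGUF_CONFIG["gguf:q2_k_mixed"] is GGUF_INNER_CONFIG["gguf:q2_k"]
```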
97 changes: 53 additions & 44 deletions auto_round/export/export_to_gguf/convert.py
@@ -229,51 +229,45 @@ def _quant_data(cls, data_torch, data_qtype, name, modify_name, new_name, bid, d
"""
suffix = ".weight"
device = data_torch.device if device is None else device
if suffix in name:

if name.endswith(suffix):
layer_name = name[: -len(suffix)]
module = get_module(cls.model, layer_name)
kwargs = {
"scale": None,
"zp": None,
"d_scale": None,
"d_wmin": None,
"wmin": None,
"imatrix": None,
}
if hasattr(module, "scale"):
else:
layer_name = name
module = get_module(cls.model, layer_name)
kwargs = {
"scale": None,
"zp": None,
"d_scale": None,
"d_wmin": None,
"wmin": None,
"imatrix": None,
}
# support for MoE models whose experts cls is not Linear
# if hasattr(module, "scale") or ("exps" in new_name and len(data_torch.shape) == 3):
for attr in ["scale", "zp", "w_d_scale", "w_d_wmin", "w_wmin"]:
if hasattr(module, attr) and getattr(module, attr) is not None:
attr_tensor = getattr(module, attr)
if not isinstance(attr_tensor, torch.Tensor):
continue
if hasattr(cls, "permute"):
bs = module.weight.shape[0]
for attr in ["scale", "zp", "w_d_scale", "w_d_wmin", "w_wmin"]:
if hasattr(module, attr) and getattr(module, attr) is not None:
attr_tensor = getattr(module, attr)
if not isinstance(attr_tensor, torch.Tensor):
continue
attr_tensors_dict = dict(cls.modify_tensors(attr_tensor.reshape(bs, -1), modify_name, bid))
attr_tensor = attr_tensors_dict[new_name]
if attr in kwargs:
kwargs[attr] = attr_tensor.to(torch.float32)
else:
kwargs[attr.replace("w_", "")] = attr_tensor.to(torch.float32)
data_torch = data_torch.to(torch.float32)
attr_tensors_dict = dict(cls.modify_tensors(attr_tensor.reshape(bs, -1), modify_name, bid))
attr_tensor = attr_tensors_dict[new_name]
if attr in kwargs:
kwargs[attr] = attr_tensor.to(torch.float32)
else:
kwargs[attr.replace("w_", "")] = attr_tensor.to(torch.float32)
data_torch = data_torch.to(torch.float32)

data = ggml_quant(data_torch, data_qtype.name.lower(), device=device, **kwargs)
else:
# if data_torch.dtype ==torch.float32:
# data_qtype = gguf.GGMLQuantizationType.F32
# else:
# data_qtype = gguf.GGMLQuantizationType.F16
data_qtype = gguf.GGMLQuantizationType.F32 ##FP16 has issues at inference
data = data_torch.to(torch.float32).squeeze().cpu().numpy()
else:
# for Llama-4
# if data_torch.dtype == torch.float32:
# data_qtype = gguf.GGMLQuantizationType.F32
# else:
# data_qtype = gguf.GGMLQuantizationType.F16
# data = data_torch.squeeze().cpu().numpy()
# data_qtype = gguf.GGMLQuantizationType.F32
# data = data_torch.to(torch.float32).squeeze().cpu().numpy()
data = ggml_quant(data_torch, data_qtype.name.lower(), device=device)
data = ggml_quant(data_torch, data_qtype.name.lower(), device=device, **kwargs)
# else:
# # if data_torch.dtype ==torch.float32:
# # data_qtype = gguf.GGMLQuantizationType.F32
# # else:
# # data_qtype = gguf.GGMLQuantizationType.F16
# data_qtype = gguf.GGMLQuantizationType.F32 ##FP16 has issues at inference
# data = data_torch.to(torch.float32).squeeze().cpu().numpy()
return data, data_qtype


@@ -419,7 +413,9 @@ def prepare_tensors(cls):
break
if skip:
continue
data = data_torch.squeeze()
# sync with new version of gguf
# data = data_torch.squeeze()
data = data_torch
n_dims = len(data.shape)
data_qtype: gguf.GGMLQuantizationType | bool = cls.tensor_force_quant(name, new_name, bid, n_dims)

@@ -529,17 +525,30 @@ def prepare_tensors(cls):
elif data_qtype == gguf.GGMLQuantizationType.Q6_K:
data_qtype = gguf.GGMLQuantizationType.Q8_0

from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES

if data_qtype.name.lower() in GGML_QUANT_SIZES:
block_size, type_size = GGML_QUANT_SIZES[data_qtype.name.lower()]
if data_torch.shape[-1] % block_size != 0:
logger.warning(
f"{new_name}: Can't quantize tensor with shape {data_torch.shape} to {data_qtype.name},"
" fallback to F16"
)
data_qtype = gguf.GGMLQuantizationType.F16

if isinstance(data_qtype, bool) or data_qtype in [
gguf.GGMLQuantizationType.F16,
gguf.GGMLQuantizationType.BF16,
gguf.GGMLQuantizationType.F32,
]:
data = data_torch.squeeze().cpu().numpy()
# sync with new version of gguf
# data = data_torch.squeeze().cpu().numpy()

# if data ends up empty, it means data_torch was a scalar tensor -> restore
if len(data.shape) == 0:
if len(data_torch.shape) == 0:
data = data_torch.numpy()
try:
data = data_torch.cpu().numpy()
data = gguf.quants.quantize(data, data_qtype)
except gguf.QuantError as e:
logger.warning("%s, %s", e, "falling back to F16")
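The newly added divisibility guard is the key behavioral change in `prepare_tensors`: a tensor whose last dimension is not a multiple of the k-quant block size is kept in F16 instead of being quantized. A standalone sketch of that rule follows, using a toy size table; the 256/84 values mirror ggml's Q2_K layout but are only there for illustration.

```python
# Standalone sketch of the block-size fallback added above.
# Toy table: qtype name -> (block_size, type_size).
GGML_QUANT_SIZES = {"q2_k": (256, 84)}

def f16_fallback_needed(shape: tuple, qtype_name: str) -> bool:
    """Return True when the tensor's last dim cannot be split into whole quant blocks."""
    if qtype_name not in GGML_QUANT_SIZES:
        return False
    block_size, _type_size = GGML_QUANT_SIZES[qtype_name]
    return shape[-1] % block_size != 0

print(f16_fallback_needed((32, 1000), "q2_k"))  # True  -> keep the tensor as F16
print(f16_fallback_needed((32, 1024), "q2_k"))  # False -> quantize to q2_k as usual
```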
2 changes: 1 addition & 1 deletion auto_round/export/export_to_gguf/packing.py
@@ -22,6 +22,7 @@


def register_qtype(name):

def register(cls):
GGML_QUANT_TYPE[name] = cls
return cls
@@ -109,7 +110,6 @@ def ggml_quant(
else:
new_data = np.concatenate(results, axis=0)
new_data = new_data.reshape(*shape[:-1], shape[-1] // block_size * type_size) # Check shape correctness
new_data = new_data.reshape(*shape[:-1], -1)
return new_data


20 changes: 15 additions & 5 deletions auto_round/formats.py
@@ -76,8 +76,8 @@ def _check_compatibility(formats: list[str], ar: BaseCompressor):
)
gguf_format_name = get_gguf_scheme(ar.scheme)
if gguf_format_name:
if gguf_format_name.lower().endswith("mixed"):
gguf_format_name = gguf_format_name.lower().replace("_mixed", "_s")
# if gguf_format_name.lower().endswith("mixed"):
# gguf_format_name = gguf_format_name.lower().replace("_mixed", "_s")
if any([f.lower() not in ["fake", gguf_format_name.lower()] for f in formats]):
tmp_format_name = gguf_format_name.lower() if "fake" not in formats else f"{gguf_format_name.lower()},fake"
logger.warning(
@@ -98,7 +98,7 @@ def remove_duplicates(lst):
seen = set()
return [x for x in lst if not (x in seen or seen.add(x))]

formats = format.replace("q*_", f"q{ar.bits}_").replace(" ", "").split(",")
formats = format.lower().replace("q*_", f"q{ar.bits}_").replace(" ", "").split(",")
formats = remove_duplicates(formats) # need the keep origin order

formats = _check_compatibility(formats, ar)
@@ -650,6 +650,7 @@ class GGUFFormat(OutputFormat):
"GGUF:Q5_K_M",
"GGUF:Q6_K",
"GGUF:Q8_0",
"GGUF:Q2_K_MIXED",
]
format_name = "gguf"

@@ -658,13 +659,22 @@ def __init__(self, format: str, ar: BaseCompressor):
self.gguf_args_check(ar, format, model_type=ModelType.TEXT)
if ar.mllm:
self.gguf_args_check(ar, format, model_type=ModelType.MMPROJ)
ar.scheme = format.upper()

self.output_format = "gguf"
self.backend_cls = GGUFFormat
self.backend = GGUFFormat(format.split(":")[-1], ar)
else:
self.output_format = f"gguf:{format}"
scheme = ar.scheme
gguf_format = f"gguf:{format.lower()}"
if format.lower().endswith("_mixed"):
from auto_round.schemes import _handle_special_schemes

ar.layer_config = _handle_special_schemes(scheme, ar.layer_config, ar.model)
gguf_format = gguf_format.lower().replace("_mixed", "_s")
if isinstance(scheme, str) and scheme.lower() != gguf_format:
logger.warning(f"reset scheme {scheme.lower()} to {gguf_format} for gguf format export")
ar.scheme = gguf_format
self.output_format = gguf_format
self.backend = None
self.mllm = ar.mllm

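A minimal, string-only sketch of the new `_mixed` handling in `GGUFFormat.__init__`: the mixed name triggers the per-layer recipe (via `_handle_special_schemes`), while the exported file itself reuses the ordinary `*_s` pipeline. The helper below is an illustration, not the library code.

```python
# Illustration of the "_mixed" -> "_s" remapping performed above (strings only, no model needed).
def resolve_gguf_format(fmt: str) -> str:
    gguf_format = f"gguf:{fmt.lower()}"
    if fmt.lower().endswith("_mixed"):
        # The per-layer recipe is attached to ar.layer_config separately;
        # packing itself falls back to the plain *_s path.
        gguf_format = gguf_format.replace("_mixed", "_s")
    return gguf_format

assert resolve_gguf_format("Q2_K_MIXED") == "gguf:q2_k_s"
assert resolve_gguf_format("q4_k_m") == "gguf:q4_k_m"
```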
5 changes: 1 addition & 4 deletions auto_round/schemes.py
@@ -295,16 +295,13 @@ def is_preset_scheme(name: str) -> bool:
value.pop("lm_head", None)
PRESET_SCHEMES[key.upper()] = QuantizationScheme.from_dict(value)

SPECIAL_SCHEMES = {"GGUF:Q2_K_MIXED": PRESET_SCHEMES["GGUF:Q2_K_S"]}
PRESET_SCHEMES.update(SPECIAL_SCHEMES)


def _handle_special_schemes(scheme_name: str, layer_config: dict, model: torch.nn.Module) -> dict:
"""handle special schemes, like GGUF:Q2_K_MIXED.
Provide some special auto_round recipes.

"""
if scheme_name == "GGUF:Q2_K_MIXED":
if scheme_name.lower() == "gguf:q2_k_mixed":
for n, m in model.named_modules():
if n in layer_config:
continue
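For readers unfamiliar with the `layer_config` mechanism, here is a self-contained sketch of the pattern `_handle_special_schemes` follows: walk the model, leave user-supplied entries untouched, and attach per-layer overrides for the rest. The selection rule and bit width below are hypothetical placeholders; the real GGUF:Q2_K_MIXED recipe lives in `auto_round/schemes.py`.

```python
import torch

def handle_special_scheme_sketch(scheme_name: str, layer_config: dict, model: torch.nn.Module) -> dict:
    """Hypothetical sketch of the special-scheme pattern; not the real Q2_K_MIXED recipe."""
    if scheme_name.lower() != "gguf:q2_k_mixed":
        return layer_config
    for name, module in model.named_modules():
        if name in layer_config:
            continue  # user-provided per-layer settings always win
        if isinstance(module, torch.nn.Linear) and "down_proj" in name:
            layer_config[name] = {"bits": 4}  # placeholder rule: keep sensitive layers wider
    return layer_config
```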
6 changes: 4 additions & 2 deletions docs/step_by_step.md
@@ -149,12 +149,14 @@ adopted within the community, **only 4-bits quantization is supported**. Please
**LLM-Compressor Format**: **NVFP4, MXFP4(kernel in WIP), MXFP8 are supported**. Please set `--format llm_compressor`

#### Format and scheme support matrix
> Italics indicate that a kernel is absent or that only an inefficient/reference kernel is available.

|export format | supported scheme |
|--------------|------------------|
|**auto_round** | W4A16, W2A16, W3A16, W8A16, MXFP4, MXFP8, NVFP4, FPW8A16, W2A16G64, W2A16G32, FP8_STATIC, BF16|
|**auto_round** | W4A16, W2A16, W3A16, W8A16, *MXFP4*, *MXFP8*, NVFP4, FPW8A16, W2A16G64, W2A16G32, *FP8_STATIC*, BF16|
|**auto_awq / auto_round:auto_awq** | W4A16|
|**auto_gptq / auto_round:auto_gptq / auto_round:gptqmodel**|W4A16, W2A16, W3A16, W8A16, BF16, W2A16G64, W2A16G32|
|**llm_compressor / auto_round:llm_compressor** | MXFP4, MXFP8, NVFP4, FPW8A16, FP8_STATIC |
|**llm_compressor / auto_round:llm_compressor** | *MXFP4*, *MXFP8*, NVFP4, FPW8A16, *FP8_STATIC* |
|**gguf** | GGUF:Q4_0, GGUF:Q4_1, GGUF:Q5_0, GGUF:Q5_1, GGUF:Q2_K_S, GGUF:Q3_K_S, GGUF:Q3_K_M, GGUF:Q3_K_L, GGUF:Q4_K_S, GGUF:Q4_K_M, GGUF:Q5_K_S, GGUF:Q5_K_M, GGUF:Q6_K, GGUF:Q8_0 |
|**fake** | all scheme|
### Hardware Compatibility
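A usage sketch tying the pieces together, mirroring the CLI test added in this PR (`python -m auto_round --model <model> --iters 0 --nsamples 1 --seqlen 16 --format gguf:q2_k_mixed`). The model path and tuning settings are placeholders, and passing a model path string to the constructor is assumed to work the same way it does for `AutoRoundMLLM` in the updated test below.

```python
# Illustrative Python-API equivalent of the CLI test added in this PR;
# the model path and tuning settings are placeholders.
from auto_round import AutoRound

ar = AutoRound("Qwen/Qwen2.5-0.5B-Instruct", iters=0, nsamples=1, seqlen=16)
ar.quantize_and_save(output_dir="./tmp_autoround", format="gguf:q2_k_mixed")
```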
12 changes: 12 additions & 0 deletions test/helpers.py
@@ -70,6 +70,18 @@ def slice_layers(module):

if hasattr(model.config, "num_hidden_layers"):
model.config.num_hidden_layers = num_layers
if hasattr(model.config, "text_config"):
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]
for key in n_block_keys:
if hasattr(model.config.text_config, key):
setattr(model.config.text_config, key, num_layers)
if hasattr(model.config.text_config, "layer_types"):
model.config.text_config.layer_types = model.config.text_config.layer_types[:num_layers]
if hasattr(model.config, "vision_config"):
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]
for key in n_block_keys:
if hasattr(model.config.vision_config, key):
setattr(model.config.vision_config, key, num_layers)
if hasattr(model.config, "layer_types"):
model.config.layer_types = model.config.layer_types[:num_layers]

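A usage sketch of the extended helper, based on the updated `test_vlm_gguf` below (paths are illustrative): truncating a multimodal checkpoint now also updates the depth fields nested under `text_config` and `vision_config`.

```python
# Usage sketch of the extended helper (see the updated test_vlm_gguf below); paths are illustrative.
from helpers import save_tiny_model  # the tests import it as `from ...helpers import save_tiny_model`

tiny_model_path = save_tiny_model(
    "Qwen/Qwen2-VL-2B-Instruct",
    "./tmp/tiny_qwen_vl_model_path",
    num_layers=3,
    is_mllm=True,
)
# The saved config now reflects num_layers in the text_config/vision_config depth fields as well.
```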
24 changes: 16 additions & 8 deletions test/test_cpu/export/test_gguf_format.py
@@ -189,30 +189,38 @@ def test_all_format(self, tiny_qwen_model_path):
assert False, "cmd line test fail, please have a check"
shutil.rmtree("../../tmp_autoround", ignore_errors=True)

res = os.system(
f"PYTHONPATH='AUTO_ROUND_PATH:$PYTHONPATH' {python_path} -m auto_round --model {model_name}"
f" --bs 16 --iters 0 --nsamples 1 --seqlen 16 --format gguf:q2_k_mixed"
)
if res > 0 or res == -1:
assert False, "cmd line test fail, please have a check"
shutil.rmtree("../../tmp_autoround", ignore_errors=True)

def test_vlm_gguf(self):
from ...helpers import save_tiny_model

model_name = get_model_path("Qwen/Qwen2-VL-2B-Instruct")
tiny_model_path = save_tiny_model(model_name, "./tmp/tiny_qwen_vl_model_path", num_layers=3, is_mllm=True)
from auto_round import AutoRoundMLLM
from auto_round.utils import mllm_load_model

model, processor, tokenizer, image_processor = mllm_load_model(model_name)
autoround = AutoRoundMLLM(
model,
tokenizer=tokenizer,
processor=processor,
image_processor=image_processor,
tiny_model_path,
iters=0,
nsamples=8,
disable_opt_rtn=True,
)
quantized_model_path = "./saved"
autoround.quantize_and_save(output_dir=quantized_model_path, format="gguf:q4_0")
assert "mmproj-model.gguf" in os.listdir("./saved")
for file_name in os.listdir(quantized_model_path):
file_size = os.path.getsize(os.path.join(quantized_model_path, file_name)) / 1024**2
if file_name == "mmproj-model.gguf":
assert abs(file_size - 2537) < 5.0
assert abs(file_size - 56) < 5.0
else:
assert abs(file_size - 892) < 5.0
assert abs(file_size - 264) < 5.0
shutil.rmtree("./saved", ignore_errors=True)
shutil.rmtree(tiny_model_path, ignore_errors=True)

def test_qtype_setting(self):
# Qwen2.5-0.5B-Instruct: no output layer, token_embed q6_k falls back to q8_0, 336M