auto_round/export/export_to_gguf/convert.py (75 changes: 33 additions & 42 deletions)
@@ -229,51 +229,42 @@ def _quant_data(cls, data_torch, data_qtype, name, modify_name, new_name, bid, d
"""
suffix = ".weight"
device = data_torch.device if device is None else device
if suffix in name:
layer_name = name[: -len(suffix)]
module = get_module(cls.model, layer_name)
kwargs = {
"scale": None,
"zp": None,
"d_scale": None,
"d_wmin": None,
"wmin": None,
"imatrix": None,
}
if hasattr(module, "scale"):

layer_name = name[: -len(suffix)]
module = get_module(cls.model, layer_name)
kwargs = {
"scale": None,
"zp": None,
"d_scale": None,
"d_wmin": None,
"wmin": None,
"imatrix": None,
}
# support for MOE model with cls expoers not linear
# if hasattr(module, "scale") or ("exps" in new_name and len(data_torch.shape) == 3):
for attr in ["scale", "zp", "w_d_scale", "w_d_wmin", "w_wmin"]:
if hasattr(module, attr) and getattr(module, attr) is not None:
attr_tensor = getattr(module, attr)
if not isinstance(attr_tensor, torch.Tensor):
continue
if hasattr(cls, "permute"):
bs = module.weight.shape[0]
for attr in ["scale", "zp", "w_d_scale", "w_d_wmin", "w_wmin"]:
if hasattr(module, attr) and getattr(module, attr) is not None:
attr_tensor = getattr(module, attr)
if not isinstance(attr_tensor, torch.Tensor):
continue
attr_tensors_dict = dict(cls.modify_tensors(attr_tensor.reshape(bs, -1), modify_name, bid))
attr_tensor = attr_tensors_dict[new_name]
if attr in kwargs:
kwargs[attr] = attr_tensor.to(torch.float32)
else:
kwargs[attr.replace("w_", "")] = attr_tensor.to(torch.float32)
data_torch = data_torch.to(torch.float32)
attr_tensors_dict = dict(cls.modify_tensors(attr_tensor.reshape(bs, -1), modify_name, bid))
attr_tensor = attr_tensors_dict[new_name]
if attr in kwargs:
kwargs[attr] = attr_tensor.to(torch.float32)
else:
kwargs[attr.replace("w_", "")] = attr_tensor.to(torch.float32)
data_torch = data_torch.to(torch.float32)

data = ggml_quant(data_torch, data_qtype.name.lower(), device=device, **kwargs)
else:
# if data_torch.dtype ==torch.float32:
# data_qtype = gguf.GGMLQuantizationType.F32
# else:
# data_qtype = gguf.GGMLQuantizationType.F16
data_qtype = gguf.GGMLQuantizationType.F32 ##FP16 has issues at inference
data = data_torch.to(torch.float32).squeeze().cpu().numpy()
else:
# for Llama-4
# if data_torch.dtype == torch.float32:
# data_qtype = gguf.GGMLQuantizationType.F32
# else:
# data_qtype = gguf.GGMLQuantizationType.F16
# data = data_torch.squeeze().cpu().numpy()
# data_qtype = gguf.GGMLQuantizationType.F32
# data = data_torch.to(torch.float32).squeeze().cpu().numpy()
data = ggml_quant(data_torch, data_qtype.name.lower(), device=device)
data = ggml_quant(data_torch, data_qtype.name.lower(), device=device, **kwargs)
# else:
# # if data_torch.dtype ==torch.float32:
# # data_qtype = gguf.GGMLQuantizationType.F32
# # else:
# # data_qtype = gguf.GGMLQuantizationType.F16
# data_qtype = gguf.GGMLQuantizationType.F32 ##FP16 has issues at inference
# data = data_torch.to(torch.float32).squeeze().cpu().numpy()
return data, data_qtype


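Reviewer note: the net effect of this hunk is that _quant_data no longer gates on hasattr(module, "scale") or on the ".weight" suffix; it always builds a kwargs dict of optional quantization metadata (scale, zp, d_scale, d_wmin, wmin, imatrix) from the layer and hands it to ggml_quant, leaving each slot as None when the layer carries nothing. A minimal sketch of that pattern, assuming collect_quant_kwargs and the toy Linear layer are illustrative names rather than code from this PR:

import torch

def collect_quant_kwargs(module):
    # Every slot starts as None so a ggml_quant-style kernel can fall back to
    # computing its own statistics when the layer carries no AutoRound metadata.
    kwargs = {"scale": None, "zp": None, "d_scale": None, "d_wmin": None, "wmin": None, "imatrix": None}
    for attr in ("scale", "zp", "w_d_scale", "w_d_wmin", "w_wmin"):
        t = getattr(module, attr, None)
        if isinstance(t, torch.Tensor):
            # module attributes carry a "w_" prefix on the double-quant fields;
            # the kwarg names drop it (w_d_scale -> d_scale, w_wmin -> wmin)
            kwargs[attr if attr in kwargs else attr.replace("w_", "")] = t.to(torch.float32)
    return kwargs

layer = torch.nn.Linear(8, 4)
layer.scale = torch.rand(4, 2)    # pretend AutoRound attached per-group scales
layer.w_wmin = torch.rand(4, 2)
kw = collect_quant_kwargs(layer)
print(kw["scale"].shape, kw["wmin"].shape, kw["zp"])  # torch.Size([4, 2]) torch.Size([4, 2]) None

The actual hunk additionally routes each metadata tensor through cls.modify_tensors when the exporter defines permute; see the note at the end of this diff.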
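A second subtlety worth calling out: when the exporter class defines permute (llama-style attention reordering), the attached scale/zero-point tensors are reshaped to (out_features, -1) and pushed through the same cls.modify_tensors path as the weight, so each row of metadata is reordered exactly like its weight row. A toy illustration of that invariant, with a hypothetical permute_rows standing in for whatever reordering modify_tensors applies:

import torch

def permute_rows(t, order):
    # stand-in for the row reordering an exporter's modify_tensors may apply
    return t[order]

out_features, in_features, group_size = 4, 16, 8
weight = torch.randn(out_features, in_features)
scale = torch.rand(out_features, in_features // group_size)  # one scale per group per output row

order = torch.randperm(out_features)
w_perm = permute_rows(weight, order)
# Reshape the metadata to (out_features, -1) and apply the same row order, so row i
# of the permuted weight is still dequantized with row i of the permuted scale.
s_perm = permute_rows(scale.reshape(out_features, -1), order)

row = 2
assert torch.equal(w_perm[row], weight[order[row]])
assert torch.equal(s_perm[row], scale[order[row]])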