Merged
69 changes: 38 additions & 31 deletions auto_round/export/export_to_gguf/convert.py
@@ -80,7 +80,7 @@ def download_convert_file(redownload=False):
f" from https://github.com/ggml-org/llama.cpp manually and move it to {gguf_export_dir}."
)
sys.exit(-1)
with open(os.path.join(gguf_export_dir, FILE_NAME), "w") as f:
with open(os.path.join(gguf_export_dir, FILE_NAME), "w", encoding="utf-8") as f:
f.write(response.text)
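
For context, this one-line fix only makes the write encoding explicit. A minimal sketch of the same pattern follows; the helper name `fetch_and_save` and the `requests` usage are illustrative, not the project's actual download helper:

```python
import os

import requests


def fetch_and_save(url: str, dest_dir: str, file_name: str) -> None:
    """Download a text resource and write it with an explicit encoding.

    Without encoding="utf-8", open() falls back to the platform default
    (e.g. cp1252 on some Windows setups), which can mangle UTF-8 sources.
    """
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    with open(os.path.join(dest_dir, file_name), "w", encoding="utf-8") as f:
        f.write(response.text)
```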


@@ -214,12 +214,32 @@ def _quant_data_with_args(
return data


def _quant_data(cls, data_torch, data_qtype, name, modify_name, bid, device=None):
def _quant_data(cls, data_torch, data_qtype, name, modify_name, new_name, bid, device=None):
"""

Args:
data_torch: original data tensor
data_qtype: quantization type
name: original tensor name, used to look up the auto_round config, e.g. model.language_model.layers.0.input_linear.weight
modify_name: modified tensor name used for the gguf mapping, e.g. model.layers.0.input_linear.weight
new_name: name produced by modify_tensors; gguf saves the tensor under this name, e.g. blk.0.ffn_gate.weight
bid: block id
device: device to perform quantization. Defaults to None.

"""
suffix = ".weight"
device = data_torch.device if device is None else device
if suffix in name:
layer_name = name[: -len(suffix)]
module = get_module(cls.model, layer_name)
kwargs = {
"scale": None,
"zp": None,
"d_scale": None,
"d_wmin": None,
"wmin": None,
"imatrix": None,
}
if hasattr(module, "scale"):
if hasattr(cls, "permute"):
bs = module.weight.shape[0]
@@ -228,34 +248,17 @@ def _quant_data(cls, data_torch, data_qtype, name, modify_name, bid, device=None
attr_tensor = getattr(module, attr)
if not isinstance(attr_tensor, torch.Tensor):
continue
ori_shape = attr_tensor.shape
attr_tensor = cls.modify_tensors(attr_tensor.reshape(bs, -1), modify_name, bid)[0][1]
attr_tensor = attr_tensor.reshape(ori_shape)
setattr(module, attr, attr_tensor)
scale = module.scale if hasattr(module, "scale") else None
zp = module.zp if hasattr(module, "zp") else None
attr_tensors_dict = dict(cls.modify_tensors(attr_tensor.reshape(bs, -1), modify_name, bid))
attr_tensor = attr_tensors_dict[new_name]
if attr in kwargs:
if attr != "imatrix":
attr_tensor = attr_tensor.to(torch.float32)
kwargs[attr] = attr_tensor
else:
kwargs[attr.replace("w_", "")] = attr_tensor
data_torch = data_torch.to(torch.float32)
scale = scale.to(torch.float32) if isinstance(scale, torch.Tensor) else scale
zp = zp.to(torch.float32) if isinstance(zp, torch.Tensor) else zp
if data_qtype.name.lower().endswith("_k"):
d_scale = module.w_d_scale.to(torch.float32) if hasattr(module, "w_d_scale") else None
d_wmin = module.w_d_wmin.to(torch.float32) if hasattr(module, "w_d_wmin") else None
wmin = module.w_wmin.to(torch.float32) if hasattr(module, "w_wmin") else None
imatrix = module.imatrix.to(torch.float32) if hasattr(module, "imatrix") else None

data = ggml_quant(
data_torch,
data_qtype.name.lower(),
scale,
zp,
wmin=wmin,
d_scale=d_scale,
d_wmin=d_wmin,
imatrix=imatrix,
device=device,
)
else:
data = ggml_quant(data_torch, data_qtype.name.lower(), scale, zp, device=device)

data = ggml_quant(data_torch, data_qtype.name.lower(), device=device, **kwargs)
else:
# if data_torch.dtype ==torch.float32:
# data_qtype = gguf.GGMLQuantizationType.F32
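
Taken together, this hunk replaces the old two-branch call into `ggml_quant` (one path for `*_k` quant types with explicit `d_scale`/`d_wmin`/`wmin`/`imatrix` arguments, one without) with a single keyword dictionary that is filled from whatever quantization tensors the module carries and then unpacked via `**kwargs`. Below is a simplified sketch of that collect-then-unpack pattern; it assumes a bare module object, skips the permutation step done via `cls.modify_tensors`, guesses the attribute list, and uses a hypothetical `quantize` stand-in for `ggml_quant`:

```python
import torch


def collect_quant_kwargs(module) -> dict:
    """Gather a module's quantization tensors into ggml_quant-style kwargs.

    Attributes prefixed with "w_" (w_wmin, w_d_scale, w_d_wmin) map onto the
    un-prefixed keyword names, mirroring attr.replace("w_", "") above; every
    tensor except the importance matrix is promoted to float32.
    """
    kwargs = {"scale": None, "zp": None, "d_scale": None,
              "d_wmin": None, "wmin": None, "imatrix": None}
    for attr in ("scale", "zp", "w_d_scale", "w_d_wmin", "w_wmin", "imatrix"):
        tensor = getattr(module, attr, None)
        if not isinstance(tensor, torch.Tensor):
            continue
        key = attr.replace("w_", "")
        kwargs[key] = tensor if attr == "imatrix" else tensor.to(torch.float32)
    return kwargs


def quantize(data, qtype, device=None, **kwargs):
    """Hypothetical stand-in for ggml_quant; it accepts the same keywords."""
    ...  # real quantization happens in auto_round's GGUF export helpers


# Usage: one call site regardless of whether the quant type is a *_k variant.
# data = quantize(weight.to(torch.float32), "q4_k", device="cpu",
#                 **collect_quant_kwargs(module))
```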
@@ -583,12 +586,16 @@ def prepare_tensors(cls):
if arr_name[i].isdecimal() and int(arr_name[i]) == (data_torch.shape[0] - 1):
arr_name[i] = str(idx)
arr_name = ".".join(arr_name)
arr, data_qtype = _quant_data(cls, arr, data_qtype, arr_name, modify_name, bid, device=device)
arr, data_qtype = _quant_data(
cls, arr, data_qtype, arr_name, modify_name, new_name, bid, device=device
)
new_data.append(arr)
data = np.array(new_data)
del new_data
else:
data, data_qtype = _quant_data(cls, data_torch, data_qtype, name, modify_name, bid, device=device)
data, data_qtype = _quant_data(
cls, data_torch, data_qtype, name, modify_name, new_name, bid, device=device
)

shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
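
The caller changes above thread `new_name` through to `_quant_data` so the lookup into the output of `modify_tensors` can be done by name. For stacked (e.g. per-expert) tensors, each slice gets its own name by swapping an index component before quantization; a rough illustration of that renaming, with hypothetical names and under the assumption that the stacked dimension's last index appears in the tensor name:

```python
def per_slice_names(name: str, num_slices: int) -> list[str]:
    """Derive a name for each slice of a stacked tensor (illustrative only)."""
    names = []
    for idx in range(num_slices):
        parts = name.split(".")
        for i, part in enumerate(parts):
            # the shown code compares against data_torch.shape[0] - 1
            if part.isdecimal() and int(part) == num_slices - 1:
                parts[i] = str(idx)
        names.append(".".join(parts))
    return names


print(per_slice_names("model.layers.0.experts.7.weight", 8)[:2])
# ['model.layers.0.experts.0.weight', 'model.layers.0.experts.1.weight']
```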

6 changes: 5 additions & 1 deletion auto_round/export/export_to_gguf/export.py
@@ -99,7 +99,11 @@ def create_model_class(
raise TypeError(f"{output_type} type is not supported")
output_type = FTYPE_MAP.get(output_type.lower())

hparams = convert_hf_to_gguf.ModelBase.load_hparams(Path(tmp_work_dir), "mistral" in model.config.model_type)
if "mistral" in model.config.model_type and "params.json" in os.listdir(tmp_work_dir):
is_mistral_format = True
else:
is_mistral_format = False
hparams = convert_hf_to_gguf.ModelBase.load_hparams(Path(tmp_work_dir), is_mistral_format)
hparams.pop("quantization_config", None)
model_instance = model_class(
dir_model=Path(tmp_work_dir),
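
This change tightens the Mistral detection: the native-Mistral loading path is taken only when the model type mentions "mistral" and the working directory actually contains a `params.json` (Mistral's native config); otherwise the checkpoint is treated as a regular Hugging Face export. The same check, isolated as a small helper purely for illustration (the helper name is hypothetical):

```python
import os


def detect_mistral_format(model_type: str, work_dir: str) -> bool:
    """Return True only when both signals agree: the model type says mistral
    and a native params.json is present. A HF-style export ships config.json
    instead, so it should go through the regular code path."""
    return "mistral" in model_type and "params.json" in os.listdir(work_dir)
```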
5 changes: 4 additions & 1 deletion auto_round/utils/model.py
@@ -810,7 +810,10 @@ def get_gguf_architecture(dir_model, model_type=ModelType.TEXT):
tmp_model_type = hparams.model_type
if "mistral" == tmp_model_type:
is_mistral_format = True
hparams = ModelBase.load_hparams(dir_model, is_mistral_format)
try:
hparams = ModelBase.load_hparams(dir_model, is_mistral_format)
except Exception:
is_mistral_format = False
if not is_mistral_format:
model_class = get_model_architecture(hparams, model_type)
elif model_type == ModelType.MMPROJ:
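
Similarly, `get_gguf_architecture` now guards the Mistral-flavoured `load_hparams` call: if parsing with `is_mistral_format=True` raises (for example because the native Mistral files are missing), the flag is cleared and the Hugging Face path below is used instead. A minimal sketch of that fallback, with `load_hparams` injected as a stand-in for `ModelBase.load_hparams`:

```python
def resolve_mistral_flag(dir_model, looks_like_mistral: bool, load_hparams):
    """Return (hparams, is_mistral_format), falling back to the HF path.

    load_hparams stands in for ModelBase.load_hparams and is passed in only to
    keep this sketch self-contained; in the real code the HF hparams loaded
    earlier remain in use when the Mistral parse fails.
    """
    hparams = None
    is_mistral_format = looks_like_mistral
    if is_mistral_format:
        try:
            hparams = load_hparams(dir_model, is_mistral_format)
        except Exception:
            is_mistral_format = False  # e.g. params.json missing
    return hparams, is_mistral_format
```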