diff --git a/auto_round/export/export_to_gguf/convert.py b/auto_round/export/export_to_gguf/convert.py
index 96847c2ea..1b7e0735d 100644
--- a/auto_round/export/export_to_gguf/convert.py
+++ b/auto_round/export/export_to_gguf/convert.py
@@ -80,7 +80,7 @@ def download_convert_file(redownload=False):
             f" from https://github.com/ggml-org/llama.cpp manually and move it to {gguf_export_dir}."
         )
         sys.exit(-1)
-    with open(os.path.join(gguf_export_dir, FILE_NAME), "w") as f:
+    with open(os.path.join(gguf_export_dir, FILE_NAME), "w", encoding="utf-8") as f:
         f.write(response.text)
 
 
@@ -214,12 +214,32 @@ def _quant_data_with_args(
     return data
 
 
-def _quant_data(cls, data_torch, data_qtype, name, modify_name, bid, device=None):
+def _quant_data(cls, data_torch, data_qtype, name, modify_name, new_name, bid, device=None):
+    """
+
+    Args:
+        data_torch: original data tensor
+        data_qtype: quantization type
+        name: original tensor name, for getting auto_round config, model.language_model.layers.0.input_linear.weight
+        modify_name: modified tensor name, for gguf mapping, model.layers.0.input_linear.weight
+        new_name: name after modify_tensors, gguf using this to save tensor, like blk.0.ffn_gate.weight
+        bid: block id
+        device: device to perform quantization. Defaults to None.
+
+    """
     suffix = ".weight"
     device = data_torch.device if device is None else device
     if suffix in name:
         layer_name = name[: -len(suffix)]
         module = get_module(cls.model, layer_name)
+        kwargs = {
+            "scale": None,
+            "zp": None,
+            "d_scale": None,
+            "d_wmin": None,
+            "wmin": None,
+            "imatrix": None,
+        }
         if hasattr(module, "scale"):
             if hasattr(cls, "permute"):
                 bs = module.weight.shape[0]
@@ -228,34 +248,15 @@ def _quant_data(cls, data_torch, data_qtype, name, modify_name, bid, device=None
                     attr_tensor = getattr(module, attr)
                     if not isinstance(attr_tensor, torch.Tensor):
                         continue
-                    ori_shape = attr_tensor.shape
-                    attr_tensor = cls.modify_tensors(attr_tensor.reshape(bs, -1), modify_name, bid)[0][1]
-                    attr_tensor = attr_tensor.reshape(ori_shape)
-                    setattr(module, attr, attr_tensor)
-        scale = module.scale if hasattr(module, "scale") else None
-        zp = module.zp if hasattr(module, "zp") else None
+                    attr_tensors_dict = dict(cls.modify_tensors(attr_tensor.reshape(bs, -1), modify_name, bid))
+                    attr_tensor = attr_tensors_dict[new_name]
+                    if attr in kwargs:
+                        kwargs[attr] = attr_tensor.to(torch.float32)
+                    else:
+                        kwargs[attr.replace("w_", "")] = attr_tensor.to(torch.float32)
         data_torch = data_torch.to(torch.float32)
-        scale = scale.to(torch.float32) if isinstance(scale, torch.Tensor) else scale
-        zp = zp.to(torch.float32) if isinstance(zp, torch.Tensor) else zp
-        if data_qtype.name.lower().endswith("_k"):
-            d_scale = module.w_d_scale.to(torch.float32) if hasattr(module, "w_d_scale") else None
-            d_wmin = module.w_d_wmin.to(torch.float32) if hasattr(module, "w_d_wmin") else None
-            wmin = module.w_wmin.to(torch.float32) if hasattr(module, "w_wmin") else None
-            imatrix = module.imatrix.to(torch.float32) if hasattr(module, "imatrix") else None
-
-            data = ggml_quant(
-                data_torch,
-                data_qtype.name.lower(),
-                scale,
-                zp,
-                wmin=wmin,
-                d_scale=d_scale,
-                d_wmin=d_wmin,
-                imatrix=imatrix,
-                device=device,
-            )
-        else:
-            data = ggml_quant(data_torch, data_qtype.name.lower(), scale, zp, device=device)
+
+        data = ggml_quant(data_torch, data_qtype.name.lower(), device=device, **kwargs)
     else:
         # if data_torch.dtype ==torch.float32:
         #     data_qtype = gguf.GGMLQuantizationType.F32
@@ -583,12 +584,16 @@ def prepare_tensors(cls):
                         if arr_name[i].isdecimal() and int(arr_name[i]) == (data_torch.shape[0] - 1):
                             arr_name[i] = str(idx)
                     arr_name = ".".join(arr_name)
-                    arr, data_qtype = _quant_data(cls, arr, data_qtype, arr_name, modify_name, bid, device=device)
+                    arr, data_qtype = _quant_data(
+                        cls, arr, data_qtype, arr_name, modify_name, new_name, bid, device=device
+                    )
                     new_data.append(arr)
                 data = np.array(new_data)
                 del new_data
             else:
-                data, data_qtype = _quant_data(cls, data_torch, data_qtype, name, modify_name, bid, device=device)
+                data, data_qtype = _quant_data(
+                    cls, data_torch, data_qtype, name, modify_name, new_name, bid, device=device
+                )
 
             shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape
 
diff --git a/auto_round/export/export_to_gguf/export.py b/auto_round/export/export_to_gguf/export.py
index df7264dd4..c1add0ed5 100644
--- a/auto_round/export/export_to_gguf/export.py
+++ b/auto_round/export/export_to_gguf/export.py
@@ -99,7 +99,11 @@ def create_model_class(
         raise TypeError(f"{output_type} type is not supported")
     output_type = FTYPE_MAP.get(output_type.lower())
 
-    hparams = convert_hf_to_gguf.ModelBase.load_hparams(Path(tmp_work_dir), "mistral" in model.config.model_type)
+    if "mistral" in model.config.model_type and "params.json" in os.listdir(tmp_work_dir):
+        is_mistral_format = True
+    else:
+        is_mistral_format = False
+    hparams = convert_hf_to_gguf.ModelBase.load_hparams(Path(tmp_work_dir), is_mistral_format)
     hparams.pop("quantization_config", None)
     model_instance = model_class(
         dir_model=Path(tmp_work_dir),
diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py
index 4a40e9626..8b85bc26e 100644
--- a/auto_round/utils/model.py
+++ b/auto_round/utils/model.py
@@ -810,7 +810,10 @@ def get_gguf_architecture(dir_model, model_type=ModelType.TEXT):
         tmp_model_type = hparams.model_type
         if "mistral" == tmp_model_type:
             is_mistral_format = True
-    hparams = ModelBase.load_hparams(dir_model, is_mistral_format)
+    try:
+        hparams = ModelBase.load_hparams(dir_model, is_mistral_format)
+    except Exception:
+        is_mistral_format = False
     if not is_mistral_format:
         model_class = get_model_architecture(hparams, model_type)
     elif model_type == ModelType.MMPROJ: