3 changes: 2 additions & 1 deletion auto_round/__main__.py
@@ -138,6 +138,7 @@ def __init__(self, *args, **kwargs):
basic.add_argument("--low_cpu_mem_usage", action="store_true", help="Lower CPU memory mode. Defaults to False.")
basic.add_argument(
"--format",
"--formats",
default="auto_round",
type=str,
help="Output format for the quantized model."
@@ -466,7 +467,7 @@ def list_item():
args = argparse.ArgumentParser()
args.add_argument("item", type=str, help="item to list, e.g., format")
args = args.parse_args()
if args.item == "format":
if args.item == "format" or args.item == "formats":
from auto_round.formats import OutputFormat

print("AutoRound supported output formats and quantization scheme:")
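For context on the alias above: argparse maps every option string passed to add_argument onto the same destination, so --formats becomes a drop-in synonym for --format. A minimal, self-contained sketch (illustrative only, not part of this PR) of that behavior:

import argparse

# Both option strings share the destination "format" (derived from the first
# long option), so either spelling sets the same attribute.
parser = argparse.ArgumentParser()
parser.add_argument("--format", "--formats", default="auto_round", type=str)

print(parser.parse_args(["--format", "auto_round"]).format)   # auto_round
print(parser.parse_args(["--formats", "auto_round"]).format)  # auto_round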
151 changes: 48 additions & 103 deletions auto_round/compressors/base.py
@@ -110,6 +110,39 @@
)
from auto_round.wrapper import WrapperLinear, WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block

SERIALIZATION_KEYS = (
"bits",
"act_bits",
"data_type",
"act_data_type",
"group_size",
"act_group_size",
"sym",
"act_sym",
"act_dynamic",
"amp",
"batch_size",
"enable_minmax_tuning",
"enable_norm_bias_tuning",
"enable_quanted_input",
"gradient_accumulate_steps",
"iters",
"lr",
"low_gpu_mem_usage",
"minmax_lr",
"nsamples",
"quant_block_list",
"regex_config",
"scale_dtype",
"seqlen",
"supported_types",
"static_attention_dtype",
"static_kv_dtype",
"super_bits",
"super_group_size",
"to_quant_block_names",
)


class BaseCompressor(object):
"""Base compressor for LLM quantization
@@ -1105,35 +1138,17 @@ def _quantize_layer_via_rtn(self, name: str, dtype: torch.dtype = None, to_cpu=T
def _immediate_pack(self, name: str):
if not self.immediate_packing:
return
m = get_module(self.model, name)
if not check_to_quantized(m):
return
from auto_round.export import PACKING_LAYER_WITH_FORMAT

target_backend = self.formats[0].output_format
has_gguf = any(fmt.is_gguf() for fmt in self.formats)

if has_gguf:
from auto_round.export.export_to_gguf.export import pack_gguf_layer

output_dir = self._get_save_folder_name(self.formats[0])
model_type = ModelType.MMPROJ if self.mllm else ModelType.TEXT
pack_gguf_layer(
name,
self.model,
self.formats[0].get_backend_name(),
output_dir,
self.layer_config,
self.tokenizer,
processor=self.processor if hasattr(self, "processor") else None,
image_processor=self.image_processor if hasattr(self, "image_processor") else None,
model_type=model_type,
device=self.device,
)
else:
PACKING_LAYER_WITH_FORMAT[target_backend](
name, self.model, self.formats[0].get_backend_name(), device=self.device
)
self.formats[0].immediate_pack(
name=name,
model=self.model,
device=self.device,
output_dir=self._get_save_folder_name(self.formats[0]),
mllm=self.mllm,
layer_config=self.layer_config,
tokenizer=self.tokenizer,
processor=self.processor if hasattr(self, "processor") else None,
image_processor=self.image_processor if hasattr(self, "image_processor") else None,
)

@torch.inference_mode()
def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
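For readers comparing against the removed branching above: the GGUF/non-GGUF dispatch now presumably lives behind the format object. A rough sketch of what an immediate_pack hook could look like, reconstructed from the deleted lines (the actual method on the format class is not shown in this diff, so treat the name and exact signature as assumptions):

# Hypothetical reconstruction, not the code added by this PR.
def immediate_pack(self, name, model, device, output_dir, mllm, layer_config,
                   tokenizer, processor=None, image_processor=None):
    module = get_module(model, name)      # helpers as used in the old compressor code
    if not check_to_quantized(module):    # skip layers kept in original precision
        return
    if self.is_gguf():
        from auto_round.export.export_to_gguf.export import pack_gguf_layer

        model_type = ModelType.MMPROJ if mllm else ModelType.TEXT
        pack_gguf_layer(name, model, self.get_backend_name(), output_dir,
                        layer_config, tokenizer, processor=processor,
                        image_processor=image_processor, model_type=model_type,
                        device=device)
    else:
        from auto_round.export import PACKING_LAYER_WITH_FORMAT

        PACKING_LAYER_WITH_FORMAT[self.output_format](
            name, model, self.get_backend_name(), device=device)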
@@ -2931,98 +2946,28 @@ def save_quantized(
folders = []
for format in formats:
save_folder = self._get_save_folder_name(format)
if format.is_fake(): # TODO fix act quantization later
self.model = self.model.to("cpu")
self.model.save_pretrained(output_dir)
if self.tokenizer is not None and hasattr(self.tokenizer, "save_pretrained"):
self.tokenizer.save_pretrained(output_dir)
processor = kwargs.get("processor", None)
if processor is not None:
processor.save_pretrained(output_dir)
try:
copy_python_files_from_model_cache(self.model, output_dir)
except Exception as e:
logger.warning("Skipping source model Python file copy due to error: %s", e)
compressed_model = self.model
continue
if self.act_bits <= 8 and format.is_fake():
logger.warning(
"Support for exporting activation quantization is limited. "
"Please ensure that your configuration is supported."
)
from auto_round.export import EXPORT_FORMAT

backend = format.get_backend_name()
output_format = format.output_format
if output_format not in EXPORT_FORMAT:
raise ValueError(f"export format only supports {EXPORT_FORMAT.keys()}, but got {output_format}")
save_quantized_as_format = EXPORT_FORMAT.get(output_format)
serialization_keys = [
"bits",
"group_size",
"sym",
"data_type",
"enable_quanted_input",
"enable_minmax_tuning",
"seqlen",
"batch_size",
"scale_dtype",
"lr",
"minmax_lr",
"gradient_accumulate_steps",
"iters",
"amp",
"nsamples",
"low_gpu_mem_usage",
"to_quant_block_names",
"enable_norm_bias_tuning",
"act_bits",
"act_group_size",
"act_sym",
"act_dynamic",
"act_data_type",
"super_bits",
"super_group_size",
"regex_config",
"static_kv_dtype",
"static_attention_dtype",
]
if isinstance(self.dataset, str):
serialization_keys.append("dataset")

serialization_dict = {}
for key in serialization_keys:
for key in SERIALIZATION_KEYS:
serialization_dict[key] = getattr(self, key)
from auto_round.version import __version__

serialization_dict["autoround_version"] = __version__
if "scale_dtype" in serialization_dict.keys():
serialization_dict["scale_dtype"] = str(serialization_dict["scale_dtype"])
compressed_model = save_quantized_as_format( # TODO refine the code
compressed_model = format.save_quantized(
save_folder,
model=self.model,
layer_config=self.layer_config,
inplace=inplace,
bits=self.bits,
act_bits=self.act_bits,
group_size=self.group_size,
sym=self.sym,
iters=self.iters,
lr=self.lr,
minmax_lr=self.minmax_lr,
enable_minmax_tuning=self.enable_minmax_tuning,
enable_quanted_input=self.enable_quanted_input,
scale_dtype=self.scale_dtype,
tokenizer=self.tokenizer,
supported_types=self.supported_types,
data_type=self.data_type,
act_data_type=self.act_data_type,
serialization_dict=serialization_dict,
backend=backend,
to_quant_block_names=self.to_quant_block_names,
quant_block_list=self.quant_block_list,
device=self.device,
static_kv_dtype=self.static_kv_dtype,
static_attention_dtype=self.static_attention_dtype,
serialization_dict=serialization_dict,
**kwargs,
)
folders.append(save_folder)
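The per-format save_quantized entry point replaces the EXPORT_FORMAT lookup that was previously inlined here. A sketch of how that dispatch might look inside the format class, reconstructed from the removed validation and backend plumbing (again an assumption, since the new OutputFormat.save_quantized body is not part of this diff):

# Hypothetical reconstruction of the dispatch the removed code performed inline.
def save_quantized(self, save_folder, model, layer_config, inplace,
                   tokenizer, serialization_dict, device, **kwargs):
    from auto_round.export import EXPORT_FORMAT

    if self.output_format not in EXPORT_FORMAT:
        raise ValueError(
            f"export format only supports {EXPORT_FORMAT.keys()}, but got {self.output_format}")
    export_fn = EXPORT_FORMAT[self.output_format]
    return export_fn(save_folder, model=model, layer_config=layer_config,
                     inplace=inplace, tokenizer=tokenizer,
                     serialization_dict=serialization_dict,
                     backend=self.get_backend_name(), device=device, **kwargs)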