81 changes: 72 additions & 9 deletions auto_round/autoround.py
@@ -38,6 +38,8 @@
SUPPORTED_LAYER_TYPES,
TORCH_VERSION_AT_LEAST_2_6,
CpuInfo,
_generate_block_recipe,
_generate_recipe,
_gguf_args_check,
block_forward,
check_and_mark_fp8_model,
@@ -67,9 +69,9 @@
infer_bits_by_data_type,
init_cache,
is_debug_mode,
is_hpex_available,
is_mx_fp,
is_nv_fp,
is_optimum_habana_available,
is_standard_fp,
llm_load_model,
logger,
@@ -101,6 +103,11 @@ class AutoRound(object):
enable_torch_compile (bool): Whether to enable torch.compile for quant blocks/layers.
"""

# If a function is not strictly required by AutoRound itself, defining it elsewhere
# and attaching it to the class here improves code readability and maintainability.
_generate_recipe = _generate_recipe
_generate_block_recipe = _generate_block_recipe
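
For readers unfamiliar with the pattern, a tiny standalone sketch (illustrative names only, not code from this PR): a plain function defined in another module becomes an ordinary bound method once it is assigned as a class attribute, which is how _generate_recipe and _generate_block_recipe are attached above.

    # Standalone sketch of the "define elsewhere, attach to the class" pattern.
    def _helper(self, x):
        # Lives in a utilities module; it receives the instance as `self`
        # once attached to a class, exactly like a normal method.
        return x * self.factor

    class Widget:
        factor = 3
        helper = _helper  # a plain function becomes a bound method on attribute access

    print(Widget().helper(2))  # prints 6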

def __init__(
self,
model: Union[torch.nn.Module, str],
@@ -364,7 +371,7 @@ def __init__(

torch.set_printoptions(precision=3, sci_mode=True)

if is_optimum_habana_available():
if is_hpex_available():
logger.info("Optimum Habana is available, import htcore explicitly.")
import habana_frameworks.torch.core as htcore # pylint: disable=E0401
import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401
@@ -430,6 +437,9 @@ def _adjust_torch_compile(self, enable_torch_compile: bool) -> None:
self.enable_torch_compile = False
logger.warning("reset enable_torch_compile to `False` as fp8 is enabled")

self.recipe_mode = False
self.recipe_results = {}

def _set_device_map_in_blocks(self, device_map: Union[str, dict, None]) -> None:
"""Sets the device map for specific blocks in the model.

@@ -1394,7 +1404,11 @@ def quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) -
if self.device_map is not None:
accelerate.hooks.remove_hook_from_submodules(block)

if is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats):
if (
hasattr(self, "formats")
and is_nv_fp(self.act_data_type)
and any("nv_fp" in format_ for format_ in self.formats)
):
from auto_round.utils import set_amax_for_all_moe_layers

# enable moe experts act_max automatic generation for linears
@@ -1433,6 +1447,8 @@ def quantize(self):
m.tmp_name = n
self._check_compatibility()
self.has_qlayer_outside_block = self.set_layerwise_config(self.layer_config)
if not self.recipe_mode:
self._dump_average_bits() # leverage updated self.layer_config
if not hasattr(self, "formats"):
logger.warning("this API is deprecated, please use `quantize_and_save` instead")
else:
@@ -1549,6 +1565,8 @@ def quantize(self):
f"Expected exactly one packing format when 'is_packing_immediate' is True, "
f"but got {len(self.formats)} formats."
)
if self.recipe_mode:
return

self.quant_layers(layer_names, all_inputs) ##TODO pack layer immediately

@@ -2439,7 +2457,14 @@ def quantize_block(self, block, input_ids, input_others, q_input=None, device=to

modules = block.modules()
for module in modules:
update_fused_layer_global_scales(module)
try:
update_fused_layer_global_scales(module)
except:
# Mixed precision may trigger an error here, since q, k, v may not share the same dtype.
logger.warning_once(
"Cannot keep the same global scale for fused layers, "
+ "so the model may not work with vLLM when QKV or similar layers are fused."
)
round_params = []
minmax_params = []
for n, m in block.named_modules():
@@ -2561,7 +2586,7 @@ def quantize_block(self, block, input_ids, input_others, q_input=None, device=to
logger.info(f"{unquantized_layer_names} have not been quantized")
with torch.no_grad():
unwrapper_block(block, best_params)
if self.enable_quanted_input:
if self.enable_quanted_input and hasattr(self, "formats"):
if is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats):
from auto_round.utils import set_amax_for_all_moe_layers

@@ -2616,6 +2641,29 @@ def quantize_blocks(
clear_memory()
input_ids = to_device(input_ids, self.cache_device)
input_others = to_device(input_others, self.cache_device)
if self.recipe_mode:
[Inline review comment — Contributor]: It would be better to wrap this new code into a function and call it as early as possible. (One possible shape for such a refactor is sketched after this hunk.)
pbar = tqdm(range(0, len(block_names), nblocks))
for i in range(0, len(block_names), nblocks):
if i != 0:
pbar.update(1)
if nblocks == 1:
n = block_names[i]
pbar.set_description(f"[Recipe Mode] Processing {n}")
block = get_module(model, n)
else:
names = block_names[i : min(i + nblocks, len(block_names))]
pbar.set_description(
f"[Recipe Mode] Processing [{i + 1}-{min(i + nblocks, len(block_names))}]/{len(block_names)}"
)
modules = [get_module(model, n) for n in names]
block = WrapperMultiblock(modules)
if not self.model.device.type == "meta" or self.low_cpu_mem_usage:
block = block.to(device)
block_recipe_results = self._generate_block_recipe(block, input_ids, input_others)
for result in block_recipe_results:
self.recipe_results.update({block_names[i] + "." + result: self.recipe_mp_dtype})
pbar.close()
return
## as in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage
tmp_dtype = self.amp_dtype if self.amp else torch.float32
for i in range(len(input_ids)):
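
One possible shape for the refactor suggested in the inline review comment above (a hedged sketch, not part of the diff): the recipe-mode branch could be grouped into a helper such as a hypothetical _run_recipe_mode and called, then returned from, near the top of quantize_blocks. The body below simply regroups the code shown in this hunk and assumes the module-level imports of tqdm, get_module, and WrapperMultiblock already present in autoround.py.

    def _run_recipe_mode(self, model, block_names, nblocks, input_ids, input_others, device):
        """Generate per-block recipes and record the chosen mixed-precision dtype."""
        pbar = tqdm(range(0, len(block_names), nblocks))
        for i in range(0, len(block_names), nblocks):
            if i != 0:
                pbar.update(1)
            if nblocks == 1:
                name = block_names[i]
                pbar.set_description(f"[Recipe Mode] Processing {name}")
                block = get_module(model, name)
            else:
                names = block_names[i : min(i + nblocks, len(block_names))]
                block = WrapperMultiblock([get_module(model, n) for n in names])
            if not self.model.device.type == "meta" or self.low_cpu_mem_usage:
                block = block.to(device)
            for result in self._generate_block_recipe(block, input_ids, input_others):
                self.recipe_results[block_names[i] + "." + result] = self.recipe_mp_dtype
        pbar.close()

quantize_blocks would then reduce its recipe handling to an early "if self.recipe_mode: self._run_recipe_mode(...); return", which is roughly what the comment asks for.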
@@ -2891,7 +2939,7 @@ def scale_loss_and_backward(self, scaler, loss):
"""
scale_loss = loss * 1000
scale_loss.backward()
if is_optimum_habana_available():
if is_hpex_available():
htcore.mark_step()
return scale_loss

@@ -2908,7 +2956,7 @@ def step(self, scaler, optimizer, lr_schedule):
"""
optimizer.step()
# for hpu
if is_optimum_habana_available():
if is_hpex_available():
htcore.mark_step()
optimizer.zero_grad()
lr_schedule.step()
@@ -2954,6 +3002,21 @@ def sampling_inputs(cls, input_ids, input_others, indices, seqlen, batch_dim=0,

return current_input_ids, current_input_others

def _dump_average_bits(self, layer_config=None):
[Inline review comment — @wenhuach21, Aug 25, 2025]: This function cannot be used by AutoRound, since layers are converted to QuantizedLinear after quantization. If the function can correctly dump average bits in typical scenarios such as INT4, I'd prefer to keep it in the class. Otherwise, it would be better to move it elsewhere for now.
"""Dumps the average bits of the model based on the layer configuration."""
total_numel = 0
total_bits = 0
for n, m in self.model.named_modules():
if isinstance(m, SUPPORTED_LAYER_TYPES):
m_numel = m.weight.numel()
layer_config = self.layer_config if layer_config is None else layer_config
m_bits = layer_config[n]["bits"] if n in layer_config else self.bits
total_numel += m_numel
total_bits += m_numel * m_bits
avg_bits = round(total_bits / total_numel, 3)
logger.info(f"current average bits of model: {avg_bits}")
return avg_bits
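
For reference, the average reported by the helper above is the parameter-weighted mean sum(numel_i * bits_i) / sum(numel_i) over supported layers. A tiny self-contained illustration with made-up layer sizes (not taken from the PR):

    # Toy illustration of the weighted-average-bits formula used above.
    layers = {
        "model.layers.0.mlp.up_proj": {"numel": 1_000_000, "bits": 4},  # quantized to INT4
        "lm_head": {"numel": 1_000_000, "bits": 8},                     # kept at 8 bits
    }
    total_numel = sum(v["numel"] for v in layers.values())
    total_bits = sum(v["numel"] * v["bits"] for v in layers.values())
    print(round(total_bits / total_numel, 3))  # 6.0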


class AutoRoundAdam(AutoRound):
"""Class for automatic rounding-based quantization with optimizers like adamw of a PyTorch model.
@@ -3117,7 +3180,7 @@ def scale_loss_and_backward(self, scaler, loss):
loss = scaler.scale(loss)

loss.backward()
if is_optimum_habana_available():
if is_hpex_available():
htcore.mark_step()
return loss

@@ -3131,5 +3194,5 @@ def step(self, scaler, optimizer, lr_schedule):
optimizer.step()
optimizer.zero_grad()
lr_schedule.step()
if is_optimum_habana_available():
if is_hpex_available():
htcore.mark_step()
8 changes: 6 additions & 2 deletions auto_round/data_type/mxfp.py
@@ -22,6 +22,7 @@
revert_tensor_by_pad,
round_ste,
)
from auto_round.utils import is_hpex_available

MXFP_FORMAT_CACHE = {
# data type: ebits, mbits, emax, max_norm, min_norm
@@ -77,7 +78,6 @@ def quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding="even"):
return tensor


@torch.compile()
def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0, mantissa_rounding="even", data_type="mx_fp", **kwargs):
"""Quantize the given tensor using the specified parameters.

@@ -128,7 +128,6 @@ def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0, mantissa_roundin
return tensor.to(orig_dtype), shared_exp.to(orig_dtype), None


@torch.compile()
def quant_mx_rceil(
tensor, bits=4, group_size=-1, v=0, max_scale=1.0, mantissa_rounding="even", data_type="mx_fp", **kwargs
):
@@ -180,6 +179,11 @@ def quant_mx_rceil(
return tensor.to(orig_dtype), shared_exp.to(orig_dtype), None


# torch.compile raises an error on HPU with Habana software 1.22.0, so skip it there.
if not is_hpex_available():
quant_mx = torch.compile(quant_mx)
quant_mx_rceil = torch.compile(quant_mx_rceil)

for key in MXFP_FORMAT_CACHE.keys():
QUANT_FUNC_WITH_DTYPE[key] = quant_mx
QUANT_FUNC_WITH_DTYPE[key + "_rceil"] = quant_mx_rceil
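
The guard above in isolation (a minimal sketch of the same idea; hpu_available is a stand-in for is_hpex_available and quant_kernel is a placeholder, not the PR's functions):

    import torch

    def hpu_available() -> bool:
        # Stand-in for is_hpex_available(): True when the Habana stack is importable.
        try:
            import habana_frameworks.torch  # noqa: F401
            return True
        except ImportError:
            return False

    def quant_kernel(x: torch.Tensor) -> torch.Tensor:
        # Placeholder standing in for quant_mx / quant_mx_rceil.
        return torch.round(x * 4.0) / 4.0

    # Apply torch.compile only on backends where it is known to work.
    if not hpu_available():
        quant_kernel = torch.compile(quant_kernel)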
4 changes: 4 additions & 0 deletions auto_round/data_type/nvfp.py
@@ -89,6 +89,8 @@ def nv_fp4(tensor, bits=4, group_size=16, v=0, global_scale=None, **kwargs):
tensor_max = tensor.abs().max().to(torch.float32)
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX * get_reciprocal(tensor_max)
global_scale = global_scale.to(tensor.device)
# Ensure global_scale is float32; the tensor may be bf16/fp16
global_scale = global_scale.to(torch.float32)
qdq_res, scale = ref_nvfp4_quant(tensor, global_scale, group_size, v)
qdq_res = revert_tensor_by_pad(qdq_res, orig_shape=orig_shape, pad_len=pad_len)
return qdq_res.to(orig_dtype), scale, None
@@ -109,6 +111,8 @@ def nv_fp4_with_static_gs(tensor, bits=4, group_size=16, v=0, tensor_max=None, *
tensor_max = tensor.abs().max().to(torch.float32)
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX * get_reciprocal(tensor_max)
global_scale = global_scale.to(tensor.device)
# Ensure global_scale is float32; the tensor may be bf16/fp16
global_scale = global_scale.to(torch.float32)
qdq_res, scale = ref_nvfp4_quant(tensor, global_scale, group_size, v)
qdq_res = revert_tensor_by_pad(qdq_res, orig_shape=orig_shape, pad_len=pad_len)
return qdq_res.to(orig_dtype), scale, None
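
A minimal sketch of the dtype pitfall those two casts guard against (illustrative only, not the PR's code path): when the incoming max is bf16, the computed scale silently stays bf16 unless it is pinned to float32.

    import torch

    FLOAT8_E4M3_MAX = 448.0
    FLOAT4_E2M1_MAX = 6.0

    # e.g. an activation max recorded during calibration in bf16
    tensor_max = torch.tensor(3.2, dtype=torch.bfloat16)
    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX * (1.0 / tensor_max)
    print(global_scale.dtype)                      # torch.bfloat16 (Python floats do not promote it)
    global_scale = global_scale.to(torch.float32)  # the added cast pins it to fp32
    print(global_scale.dtype)                      # torch.float32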