170 changes: 168 additions & 2 deletions auto_round/autoround.py
@@ -430,6 +430,9 @@ def _adjust_torch_compile(self, enable_torch_compile: bool) -> None:
self.enable_torch_compile = False
logger.warning("reset enable_torch_compile to `False` as fp8 is enabled")

self.recipe_mode = False
self.recipe_results = {}

def _set_device_map_in_blocks(self, device_map: Union[str, dict, None]) -> None:
"""Sets the device map for specific blocks in the model.

@@ -1433,6 +1436,8 @@ def quantize(self):
m.tmp_name = n
self._check_compatibility()
self.has_qlayer_outside_block = self.set_layerwise_config(self.layer_config)
if not self.recipe_mode:
self._dump_average_bits() # leverage updated self.layer_config
if not hasattr(self, "formats"):
logger.warning("this API is deprecated, please use `quantize_and_save` instead")
else:
@@ -1549,6 +1554,8 @@ def quantize(self):
f"Expected exactly one packing format when 'is_packing_immediate' is True, "
f"but got {len(self.formats)} formats."
)
if self.recipe_mode:
return

self.quant_layers(layer_names, all_inputs) ##TODO pack layer immediately

@@ -2439,7 +2446,10 @@ def quantize_block(self, block, input_ids, input_others, q_input=None, device=to

modules = block.modules()
for module in modules:
update_fused_layer_global_scales(module)
try:
update_fused_layer_global_scales(module)
except Exception:
pass  # mixed precision may raise here, since q, k and v may not share the same dtype
round_params = []
minmax_params = []
for n, m in block.named_modules():
@@ -2561,7 +2571,7 @@ def quantize_block(self, block, input_ids, input_others, q_input=None, device=to
logger.info(f"{unquantized_layer_names} have not been quantized")
with torch.no_grad():
unwrapper_block(block, best_params)
if self.enable_quanted_input:
if self.enable_quanted_input and hasattr(self, "formats"):
if is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats):
from auto_round.utils import set_amax_for_all_moe_layers

@@ -2616,6 +2626,26 @@ def quantize_blocks(
clear_memory()
input_ids = to_device(input_ids, self.cache_device)
input_others = to_device(input_others, self.cache_device)
if self.recipe_mode:
Contributor commented:
It would be better to wrap this new code into a function and call it as early as possible.
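A minimal sketch of the suggested extraction (the helper name `_quantize_blocks_recipe_mode` and its exact signature are assumptions, not part of this PR):

def _quantize_blocks_recipe_mode(self, model, block_names, input_ids, input_others, nblocks=1):
    # mirrors the recipe-mode branch below so quantize_blocks can simply delegate and return early
    pbar = tqdm(range(0, len(block_names), nblocks))
    for i in range(0, len(block_names), nblocks):
        if i != 0:
            pbar.update(1)
        names = block_names[i : min(i + nblocks, len(block_names))]
        pbar.set_description(f"[Recipe Mode] Processing {names[0] if nblocks == 1 else names}")
        block = get_module(model, names[0]) if nblocks == 1 else WrapperMultiblock([get_module(model, n) for n in names])
        for result in self._generate_block_recipe(block, input_ids, input_others):
            self.recipe_results[names[0] + "." + result] = self.recipe_mp_dtype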

pbar = tqdm(range(0, len(block_names), nblocks))
for i in range(0, len(block_names), nblocks):
if i != 0:
pbar.update(1)
if nblocks == 1:
n = block_names[i]
pbar.set_description(f"[Recipe Mode] Processing {n}")
block = get_module(model, n)
else:
names = block_names[i : min(i + nblocks, len(block_names))]
pbar.set_description(
f"[Recipe Mode] Processing [{i + 1}-{min(i + nblocks, len(block_names))}]/{len(block_names)}"
)
modules = [get_module(model, n) for n in names]
block = WrapperMultiblock(modules)
block_recipe_results = self._generate_block_recipe(block, input_ids, input_others)
for result in block_recipe_results:
self.recipe_results.update({block_names[i] + "." + result: self.recipe_mp_dtype})
return
## in the calibration phase, bf16 may be used for calibration due to low_gpu_mem_usage
tmp_dtype = self.amp_dtype if self.amp else torch.float32
for i in range(len(input_ids)):
@@ -2954,6 +2984,142 @@ def sampling_inputs(cls, input_ids, input_others, indices, seqlen, batch_dim=0,

return current_input_ids, current_input_others

def _generate_recipe(
self,
# mixed-precision data type config (same keys as the regular dtype config)
mp_dtype=None,
# mixed-precision search configuration
mp_config=None,
):
# avoid mutable default arguments: recipe_mp_dtype is updated in place later
if mp_dtype is None:
mp_dtype = {"data_type": "mx_fp8", "act_data_type": "mx_fp8"}
if mp_config is None:
mp_config = {"mp_ratio": 1 / 3, "loss_weight": 2.0, "numel_weight": 1.0}
self.recipe_mode = True
self.recipe_mp_dtype = mp_dtype
self.recipe_mp_config = mp_config
self.quantize()
recipe_layer_config = copy.deepcopy(self.layer_config)
recipe_layer_config.update(self.recipe_results)
self._dump_average_bits(layer_config=recipe_layer_config)
self.recipe_mode = False
recipe_layer_config.pop("lm_head", None)  # lm_head is not included in the recipe
return recipe_layer_config
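# [Illustrative usage, not part of this diff] Assuming the public constructor accepts
# `layer_config`, `data_type`, and `act_data_type` keyword arguments, a generated recipe could be
# fed back into a fresh quantization run, e.g.:
#     ar = AutoRound(model, tokenizer, data_type="mx_fp4", act_data_type="mx_fp4")
#     recipe = ar._generate_recipe(mp_dtype={"data_type": "mx_fp8", "act_data_type": "mx_fp8"})
#     ar = AutoRound(model, tokenizer, data_type="mx_fp4", act_data_type="mx_fp4", layer_config=recipe)
#     ar.quantize_and_save("./qmodel")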

def _generate_block_recipe(self, block, input_ids, input_others):
from itertools import combinations

from auto_round.utils import (
DTYPE_INFO_MAPPING,
create_mp_block,
get_best_combination,
get_numel,
recover_mp_block,
)

# fetch mix-precision recipe configuration
sample_num = self.recipe_mp_config.get("sample_num", 8)
mp_ratio = self.recipe_mp_config.get("mp_ratio", 1 / 7)
loss_weight = float(self.recipe_mp_config.get("loss_weight", 2.0))
numel_weight = float(self.recipe_mp_config.get("numel_weight", 1.0))
loss_numel_ratio = loss_weight / numel_weight

# calculate the number of layers to run in mixed precision
quantizable_layers = [n for n, m in block.named_modules() if isinstance(m, SUPPORTED_LAYER_TYPES)]
quantizable_num = int(mp_ratio * len(quantizable_layers))  # note: int() truncates, so this floors rather than ceils
# fetch the block's original low-bit dtype so the mixed-precision block can be recovered later
layer = get_module(block, quantizable_layers[0])
raw_dtype = {
"data_type": layer.data_type,
"bits": layer.bits,
"sym": layer.sym,
"act_data_type": layer.act_data_type,
"act_bits": layer.act_bits,
"act_sym": layer.act_sym,
}
# update self.recipe_mp_dtype
self.recipe_mp_dtype.update(
{
"bits": DTYPE_INFO_MAPPING[self.recipe_mp_dtype["data_type"]]["bits"],
"group_size": DTYPE_INFO_MAPPING[self.recipe_mp_dtype["data_type"]]["group_size"],
"sym": DTYPE_INFO_MAPPING[self.recipe_mp_dtype["data_type"]]["sym"],
"act_bits": DTYPE_INFO_MAPPING[self.recipe_mp_dtype["act_data_type"]]["bits"],
"act_group_size": DTYPE_INFO_MAPPING[self.recipe_mp_dtype["act_data_type"]]["group_size"],
"act_sym": DTYPE_INFO_MAPPING[self.recipe_mp_dtype["act_data_type"]]["sym"],
}
)

# generate reference output of sample input_ids
reference_output = self.get_block_outputs(
block,
input_ids[:sample_num],
input_others,
bs=self.batch_size,
device=self.device,
cache_device=self.cache_device,
save_output=True,
)

# generate q_output of sample input_ids and get loss
def get_loss(q_block):
q_output = self.get_block_outputs(
q_block,
input_ids[:sample_num],
input_others,
bs=self.batch_size,
device=self.device,
cache_device=self.cache_device,
save_output=True,
)
total_loss = 0
mse_loss = torch.nn.MSELoss(reduction="sum").to(self.device)
for i in range(len(q_output)):
loss = mse_loss( # pylint: disable=not-callable
q_output[i].to(torch.float32), reference_output[i].to(torch.float32)
)
total_loss += loss
if is_optimum_habana_available():
htcore.mark_step()
return total_loss

combination_list = []
numel_list = []
loss_list = []
for hp_layers in combinations(quantizable_layers, quantizable_num):
combination_list.append(hp_layers)
# get numel
numel = get_numel(block, hp_layers)
numel_list.append(numel)
# get loss
block = create_mp_block(block, hp_layers, self.recipe_mp_dtype)
loss = get_loss(block)
loss_list.append(loss)
block = recover_mp_block(block, hp_layers, raw_dtype)
if is_optimum_habana_available():
htcore.mark_step()
logger.debug(f"{hp_layers}, {loss}, {numel}")

hp_layers = get_best_combination(combination_list, numel_list, loss_list, loss_numel_ratio)
logger.info(f"final hp layers: {hp_layers}")
return hp_layers
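# [Illustrative note, not part of this diff] The recipe search above enumerates
# C(len(quantizable_layers), quantizable_num) candidate combinations per block, and each candidate
# costs one forward pass over the sample_num calibration samples. For a LLaMA-style block with
# seven quantizable linears (q, k, v, o, gate, up, down):
#     >>> import math
#     >>> math.comb(7, 1), math.comb(7, 2)
#     (7, 21)   # mp_ratio ~ 1/7 promotes one layer; mp_ratio = 1/3 promotes two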

def _dump_average_bits(self, layer_config=None):
Contributor @wenhuach21 commented on Aug 25, 2025:
This function cannot be used by AutoRound, since layers are converted to QuantizedLinear after quantization. If the function can correctly dump average bits in typical scenarios such as INT4, I’d prefer to keep it in the class. Otherwise, it would be better to move it elsewhere for now.
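A minimal sketch of one possible workaround, computing average bits from `layer_config` and each module's `weight` attribute instead of the concrete layer class, so it keeps working after layers are converted (the function name and `default_bits` fallback are assumptions):

def dump_average_bits_from_config(model, layer_config, default_bits=16):
    total_numel, total_bits = 0, 0
    for name, module in model.named_modules():
        weight = getattr(module, "weight", None)
        if weight is None or name not in layer_config:
            continue
        bits = layer_config[name].get("bits", default_bits)
        total_numel += weight.numel()
        total_bits += weight.numel() * bits
    return round(total_bits / total_numel, 3) if total_numel else None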

total_numel = 0
total_bits = 0
layer_config = self.layer_config if layer_config is None else layer_config
for n, m in self.model.named_modules():
if isinstance(m, SUPPORTED_LAYER_TYPES):
m_numel = m.weight.numel()
m_bits = layer_config[n]["bits"] if n in layer_config else self.bits
total_numel += m_numel
total_bits += m_numel * m_bits
if total_numel == 0:
logger.warning("no supported layers found, skip dumping average bits")
return None
avg_bits = round(total_bits / total_numel, 3)
logger.info(f"current average bits of model: {avg_bits}")
return avg_bits


class AutoRoundAdam(AutoRound):
"""Class for automatic rounding-based quantization with optimizers like adamw of a PyTorch model.
2 changes: 0 additions & 2 deletions auto_round/data_type/mxfp.py
@@ -77,7 +77,6 @@ def quant_element(tensor, ebits, mbits, max_norm, mantissa_rounding="even"):
return tensor


@torch.compile()
def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0, mantissa_rounding="even", data_type="mx_fp", **kwargs):
"""Quantize the given tensor using the specified parameters.

@@ -128,7 +127,6 @@ def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0, mantissa_roundin
return tensor.to(orig_dtype), shared_exp.to(orig_dtype), None


@torch.compile()
def quant_mx_rceil(
tensor, bits=4, group_size=-1, v=0, max_scale=1.0, mantissa_rounding="even", data_type="mx_fp", **kwargs
):
2 changes: 2 additions & 0 deletions auto_round/data_type/nvfp.py
@@ -89,6 +89,7 @@ def nv_fp4(tensor, bits=4, group_size=16, v=0, global_scale=None, **kwargs):
tensor_max = tensor.abs().max().to(torch.float32)
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX * get_reciprocal(tensor_max)
global_scale = global_scale.to(tensor.device)
global_scale = global_scale.to(torch.float32)
qdq_res, scale = ref_nvfp4_quant(tensor, global_scale, group_size, v)
qdq_res = revert_tensor_by_pad(qdq_res, orig_shape=orig_shape, pad_len=pad_len)
return qdq_res.to(orig_dtype), scale, None
@@ -109,6 +110,7 @@ def nv_fp4_with_static_gs(tensor, bits=4, group_size=16, v=0, tensor_max=None, *
tensor_max = tensor.abs().max().to(torch.float32)
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX * get_reciprocal(tensor_max)
global_scale = global_scale.to(tensor.device)
global_scale = global_scale.to(torch.float32)
qdq_res, scale = ref_nvfp4_quant(tensor, global_scale, group_size, v)
qdq_res = revert_tensor_by_pad(qdq_res, orig_shape=orig_shape, pad_len=pad_len)
return qdq_res.to(orig_dtype), scale, None
81 changes: 81 additions & 0 deletions auto_round/utils.py
@@ -85,6 +85,12 @@ def __getitem__(self, key):

SUPPORTED_LAYER_TYPES = SUPPORTED_LAYER_TYPES + (LinearLayer, LinearAllreduce)

DTYPE_INFO_MAPPING = {
"nv_fp4": {"bits": 4, "group_size": 16, "sym": True},
"mx_fp4": {"bits": 4, "group_size": 32, "sym": True},
"mx_fp8": {"bits": 8, "group_size": 32, "sym": True},
}


def infer_bits_by_data_type(data_type: str):
if data_type is None:
@@ -2502,3 +2508,78 @@ def is_mx_fp(backend):

def is_nv_fp(backend):
return BackendDataType.NV_FP in backend


######################################### Recipe related codes ####################################
def create_mp_block(block, mp_layers, mp_dtype):
from auto_round.wrapper import WrapperLinear

for layer_name in mp_layers:
layer = get_module(block, layer_name)
layer.data_type, layer.bits, layer.sym = mp_dtype["data_type"], mp_dtype["bits"], mp_dtype["sym"]
layer.act_data_type, layer.act_bits, layer.act_sym = (
mp_dtype["act_data_type"],
mp_dtype["act_bits"],
mp_dtype["act_sym"],
)
for n, m in block.named_modules():
if isinstance(m, SUPPORTED_LAYER_TYPES):
if check_to_quantized(m):
new_m = WrapperLinear(
m,
enable_minmax_tuning=False,
enable_norm_bias_tuning=False,
device=m.weight.device,
)
set_module(block, n, new_m)
if is_optimum_habana_available():
htcore.mark_step()
return block


def recover_mp_block(block, mp_layers, raw_dtype):
from auto_round.wrapper import WrapperLinear

for n, m in block.named_modules():
if isinstance(m, WrapperLinear):
set_module(block, n, m.orig_layer)
for layer_name in mp_layers:
layer = get_module(block, layer_name)
layer.data_type, layer.bits, layer.sym = raw_dtype["data_type"], raw_dtype["bits"], raw_dtype["sym"]
layer.act_data_type, layer.act_bits, layer.act_sym = (
raw_dtype["act_data_type"],
raw_dtype["act_bits"],
raw_dtype["act_sym"],
)
if is_optimum_habana_available():
htcore.mark_step()
return block


def get_numel(block, hp_layers):
numel = 0
for layer_name in hp_layers:
layer = get_module(block, layer_name)
numel += layer.weight.numel()
return numel


def get_best_combination(combination_list, numel_list, loss_list, loss_numel_ratio=2.0):
# Rank each combination by parameter count and by loss (loss ranks weighted by loss_numel_ratio)
numel_ranks = [sorted(numel_list).index(x) for x in numel_list]
loss_ranks = [(sorted(loss_list).index(x)) * loss_numel_ratio for x in loss_list]

# Calculate rank sums
rank_sums = [x + y for x, y in zip(numel_ranks, loss_ranks)]
logger.debug(f"numel_ranks: {numel_ranks}")
logger.debug(f"loss_ranks: {loss_ranks}")
logger.debug(f"rank sum: {rank_sums}")

# Find the index of the smallest rank sum
best_index = rank_sums.index(min(rank_sums))

# Return the best combination
return combination_list[best_index]
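# [Illustrative example, not part of this diff] Toy inputs showing how the rank-sum selection behaves;
# the layer names and numbers below are made up.
combos = [("q_proj",), ("k_proj",), ("gate_proj",)]
numels = [16_000_000, 17_000_000, 44_000_000]
losses = [0.45, 0.31, 0.12]
best = get_best_combination(combos, numels, losses, loss_numel_ratio=2.0)
# numel ranks are [0, 1, 2] and weighted loss ranks are [4, 2, 0], giving rank sums [4, 3, 2],
# so best == ("gate_proj",): its low loss rank outweighs its larger parameter count.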


###############################################################################################
4 changes: 4 additions & 0 deletions auto_round/wrapper.py
@@ -91,6 +91,7 @@ def __init__(
self.orig_layer = orig_layer
self.output_device = device
self.device = self.orig_layer.tuning_device if hasattr(self.orig_layer, "tuning_device") else device
self.extra_repr_org = orig_layer.extra_repr
self.enable_minmax_tuning = enable_minmax_tuning
self.enable_round_tuning = enable_round_tuning
self.enable_norm_bias_tuning = enable_norm_bias_tuning and (orig_layer.bias is not None)
@@ -451,6 +452,9 @@ def forward(self, x):
output = self.orig_forward(x, weight_q, bias).to(self.output_device)
return output

def extra_repr(self):
return f"{self.extra_repr_org()}, weight_type={self.data_type}, act_data_type={self.act_data_type}"


class WrapperWALayer(torch.nn.Module):
def __init__(self, orig_layer):