-
Notifications
You must be signed in to change notification settings - Fork 52
add autoround._generate_recipe() #758
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 6 commits
d8b831e
a78f99c
62f81e5
79323d6
086eae2
b67c79a
715e2a1
5a631f1
2a0e0b8
f02331d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,6 +38,8 @@ | |
SUPPORTED_LAYER_TYPES, | ||
TORCH_VERSION_AT_LEAST_2_6, | ||
CpuInfo, | ||
_generate_block_recipe, | ||
_generate_recipe, | ||
_gguf_args_check, | ||
block_forward, | ||
check_and_mark_fp8_model, | ||
|
@@ -67,9 +69,9 @@ | |
infer_bits_by_data_type, | ||
init_cache, | ||
is_debug_mode, | ||
is_hpex_available, | ||
is_mx_fp, | ||
is_nv_fp, | ||
is_optimum_habana_available, | ||
is_standard_fp, | ||
llm_load_model, | ||
logger, | ||
|
@@ -101,6 +103,11 @@ class AutoRound(object): | |
enable_torch_compile (bool): Whether to enable torch.compile for quant blocks/layers. | ||
""" | ||
|
||
# If function is not necessary for AutoRound, putting it in other place and | ||
# assembling the function here, can improve code readability and maintainability | ||
_generate_recipe = _generate_recipe | ||
_generate_block_recipe = _generate_block_recipe | ||
|
||
def __init__( | ||
self, | ||
model: Union[torch.nn.Module, str], | ||
|
@@ -364,7 +371,7 @@ def __init__( | |
|
||
torch.set_printoptions(precision=3, sci_mode=True) | ||
|
||
if is_optimum_habana_available(): | ||
if is_hpex_available(): | ||
logger.info("Optimum Habana is available, import htcore explicitly.") | ||
import habana_frameworks.torch.core as htcore # pylint: disable=E0401 | ||
import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401] | ||
|
@@ -430,6 +437,9 @@ def _adjust_torch_compile(self, enable_torch_compile: bool) -> None: | |
self.enable_torch_compile = False | ||
logger.warning("reset enable_torch_compile to `False` as fp8 is enabled") | ||
|
||
self.recipe_mode = False | ||
self.recipe_results = {} | ||
|
||
def _set_device_map_in_blocks(self, device_map: Union[str, dict, None]) -> None: | ||
"""Sets the device map for specific blocks in the model. | ||
|
||
|
@@ -1394,7 +1404,11 @@ def quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) - | |
if self.device_map is not None: | ||
accelerate.hooks.remove_hook_from_submodules(block) | ||
|
||
if is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats): | ||
if ( | ||
hasattr(self, "formats") | ||
and is_nv_fp(self.act_data_type) | ||
and any("nv_fp" in format_ for format_ in self.formats) | ||
): | ||
from auto_round.utils import set_amax_for_all_moe_layers | ||
|
||
# enable moe experts act_max automatic generation for linears | ||
|
@@ -1433,6 +1447,8 @@ def quantize(self): | |
m.tmp_name = n | ||
self._check_compatibility() | ||
self.has_qlayer_outside_block = self.set_layerwise_config(self.layer_config) | ||
if not self.recipe_mode: | ||
self._dump_average_bits() # leverage updated self.layer_config | ||
if not hasattr(self, "formats"): | ||
logger.warning("this API is deprecated, please use `quantize_and_save` instead") | ||
else: | ||
|
@@ -1549,6 +1565,8 @@ def quantize(self): | |
f"Expected exactly one packing format when 'is_packing_immediate' is True, " | ||
f"but got {len(self.formats)} formats." | ||
) | ||
if self.recipe_mode: | ||
return | ||
|
||
self.quant_layers(layer_names, all_inputs) ##TODO pack layer immediately | ||
|
||
|
@@ -2439,7 +2457,14 @@ def quantize_block(self, block, input_ids, input_others, q_input=None, device=to | |
|
||
modules = block.modules() | ||
for module in modules: | ||
update_fused_layer_global_scales(module) | ||
try: | ||
update_fused_layer_global_scales(module) | ||
except: | ||
# mix-precision may cause error, since q,k,v are not the same dtype. | ||
logger.warning_once( | ||
"Cannot keep the same global scale for fused layers, " | ||
+ "so the model may not work with vLLM with fused QKV or else." | ||
) | ||
round_params = [] | ||
minmax_params = [] | ||
for n, m in block.named_modules(): | ||
|
@@ -2561,7 +2586,7 @@ def quantize_block(self, block, input_ids, input_others, q_input=None, device=to | |
logger.info(f"{unquantized_layer_names} have not been quantized") | ||
with torch.no_grad(): | ||
unwrapper_block(block, best_params) | ||
if self.enable_quanted_input: | ||
if self.enable_quanted_input and hasattr(self, "formats"): | ||
if is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats): | ||
from auto_round.utils import set_amax_for_all_moe_layers | ||
|
||
|
@@ -2616,6 +2641,29 @@ def quantize_blocks( | |
clear_memory() | ||
input_ids = to_device(input_ids, self.cache_device) | ||
input_others = to_device(input_others, self.cache_device) | ||
if self.recipe_mode: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be better to wrap this new code into a function and call it as early as possible. |
||
pbar = tqdm(range(0, len(block_names), nblocks)) | ||
for i in range(0, len(block_names), nblocks): | ||
if i != 0: | ||
pbar.update(1) | ||
if nblocks == 1: | ||
n = block_names[i] | ||
pbar.set_description(f"[Recipe Mode] Processing {n}") | ||
block = get_module(model, n) | ||
else: | ||
names = block_names[i : min(i + nblocks, len(block_names))] | ||
pbar.set_description( | ||
f"[Recipe Mode] Processing [{i + 1}-{min(i + nblocks, len(block_names))}]/{len(block_names)}" | ||
) | ||
modules = [get_module(model, n) for n in names] | ||
block = WrapperMultiblock(modules) | ||
if not self.model.device.type == "meta" or self.low_cpu_mem_usage: | ||
block = block.to(device) | ||
block_recipe_results = self._generate_block_recipe(block, input_ids, input_others) | ||
for result in block_recipe_results: | ||
self.recipe_results.update({block_names[i] + "." + result: self.recipe_mp_dtype}) | ||
pbar.close() | ||
return | ||
## as in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage | ||
tmp_dtype = self.amp_dtype if self.amp else torch.float32 | ||
for i in range(len(input_ids)): | ||
|
@@ -2891,7 +2939,7 @@ def scale_loss_and_backward(self, scaler, loss): | |
""" | ||
scale_loss = loss * 1000 | ||
scale_loss.backward() | ||
if is_optimum_habana_available(): | ||
if is_hpex_available(): | ||
htcore.mark_step() | ||
return scale_loss | ||
|
||
|
@@ -2908,7 +2956,7 @@ def step(self, scaler, optimizer, lr_schedule): | |
""" | ||
optimizer.step() | ||
# for hpu | ||
if is_optimum_habana_available(): | ||
if is_hpex_available(): | ||
htcore.mark_step() | ||
optimizer.zero_grad() | ||
lr_schedule.step() | ||
|
@@ -2954,6 +3002,21 @@ def sampling_inputs(cls, input_ids, input_others, indices, seqlen, batch_dim=0, | |
|
||
return current_input_ids, current_input_others | ||
|
||
def _dump_average_bits(self, layer_config=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function cannot be used by AutoRound, since layers are converted to QuantizedLinear after quantization. If the function can correctly dump average bits in typical scenarios such as INT4, I’d prefer to keep it in the class. Otherwise, it would be better to move it elsewhere for now. |
||
"""Dumps the average bits of the model based on the layer configuration.""" | ||
total_numel = 0 | ||
total_bits = 0 | ||
for n, m in self.model.named_modules(): | ||
if isinstance(m, SUPPORTED_LAYER_TYPES): | ||
m_numel = m.weight.numel() | ||
layer_config = self.layer_config if layer_config is None else layer_config | ||
m_bits = layer_config[n]["bits"] if n in layer_config else self.bits | ||
total_numel += m_numel | ||
total_bits += m_numel * m_bits | ||
avg_bits = round(total_bits / total_numel, 3) | ||
logger.info(f"current average bits of model: {avg_bits}") | ||
return avg_bits | ||
|
||
|
||
class AutoRoundAdam(AutoRound): | ||
"""Class for automatic rounding-based quantization with optimizers like adamw of a PyTorch model. | ||
|
@@ -3117,7 +3180,7 @@ def scale_loss_and_backward(self, scaler, loss): | |
loss = scaler.scale(loss) | ||
|
||
loss.backward() | ||
if is_optimum_habana_available(): | ||
if is_hpex_available(): | ||
htcore.mark_step() | ||
return loss | ||
|
||
|
@@ -3131,5 +3194,5 @@ def step(self, scaler, optimizer, lr_schedule): | |
optimizer.step() | ||
optimizer.zero_grad() | ||
lr_schedule.step() | ||
if is_optimum_habana_available(): | ||
if is_hpex_available(): | ||
htcore.mark_step() |
Uh oh!
There was an error while loading. Please reload this page.