36 changes: 30 additions & 6 deletions auto_round/compressors/base.py
@@ -93,6 +93,7 @@
is_fp8_linear,
is_fp8_model,
is_hpex_available,
is_moe_model,
llm_load_model,
memory_monitor,
mv_module_from_gpu,
@@ -390,19 +391,28 @@ def __init__(
)

# Automatically adjust the disable_opt_rtn option if the user does not explicitly set it.
# Keep a copy of the user-provided value (which may be None) so later code can tell whether it was set explicitly; a little ugly, but it avoids None-handling issues.
self.orig_disable_opt_rtn = disable_opt_rtn
if self.iters != 0 and self.orig_disable_opt_rtn is not None:
logger.warning("`disable_opt_rtn` only works when `iters` is set to 0, ignore it now.")
disable_opt_rtn = True
if (
self.bits >= 8
and self.act_bits >= 16
and self.iters == 0
and self.data_type == "int"
and disable_opt_rtn is None
):
logger.warning("For INT8 RTN quantization, set `--disable_opt_rtn` as default.")
logger.warning("`disable_opt_rtn` is turned on for W8A16 quantization to improve efficiency.")
disable_opt_rtn = True
if disable_opt_rtn is None:
if self.iters == 0:
logger.info("For the most RTN cases, set `--disable_opt_rtn` to False as default.")
disable_otp_rtn = False
if disable_opt_rtn is None and self.iters == 0:
logger.info(
"`enable_opt_rtn` is turned on, set `--disable_opt_rtn` for higher speed at the cost of accuracy."
)
disable_opt_rtn = False

# Important note: this detection is not very robust; do NOT rely on it for anything high-risk.
self.is_moe_model = is_moe_model(self.model)

self.minmax_lr = minmax_lr or self.lr
self.enable_alg_ext = enable_alg_ext
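
Note (not part of the diff): a minimal standalone sketch of the tri-state resolution above, where `None` means "let auto-round decide". The helper name `resolve_disable_opt_rtn` and its signature are hypothetical; only the decision order mirrors the new `__init__` logic.

```python
def resolve_disable_opt_rtn(
    disable_opt_rtn: bool | None,  # user-provided value; None means "not set"
    bits: int,
    act_bits: int,
    iters: int,
    data_type: str,
) -> bool:
    if iters != 0 and disable_opt_rtn is not None:
        # The flag only matters on the RTN path (iters == 0); an explicit value is ignored otherwise.
        return True
    if disable_opt_rtn is None and bits >= 8 and act_bits >= 16 and iters == 0 and data_type == "int":
        # W8A16 integer RTN: optimized RTN adds little accuracy, so default it off for speed.
        return True
    if disable_opt_rtn is None and iters == 0:
        # Default RTN case: keep optimized RTN enabled for accuracy.
        return False
    return bool(disable_opt_rtn)
```
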
@@ -1105,6 +1115,20 @@ def _quantize_layer_via_rtn(self, name: str, dtype: torch.dtype = None, to_cpu=T
m.zp = None
else:
try:
disable_opt_rtn = self.disable_opt_rtn
if (
not disable_opt_rtn
and self.orig_disable_opt_rtn is None
and self.is_moe_model
and "expert" in m.tmp_name
and "shared_expert" not in m.tmp_name
and self.super_bits is None # GGUF still uses the optimized RTN for MoE layers
):
disable_opt_rtn = True
logger.warning_once(
"MoE layer detected: optimized RTN is disabled for efficiency. "
"Use `--enable_opt_rtn` to force-enable it for MoE layers."
)
m = m.to(tuning_device)
m = WrapperLinear(
m,
@@ -1113,7 +1137,7 @@ def _quantize_layer_via_rtn(self, name: str, dtype: torch.dtype = None, to_cpu=T
enable_norm_bias_tuning=False,
enable_round_tuning=False,
enable_torch_compile=self.enable_torch_compile,
disable_opt_rtn=self.disable_opt_rtn,
disable_opt_rtn=disable_opt_rtn,
)
m = m.unwrapper({})
except torch.OutOfMemoryError:
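
Note (not part of the diff): the per-layer override added in `_quantize_layer_via_rtn` can be summarized by the sketch below. The function name and arguments are hypothetical; the condition mirrors the added code: optimized RTN is skipped for (non-shared) expert layers of MoE models unless the user set the flag explicitly, and GGUF (`super_bits` set) is exempt.

```python
def disable_opt_rtn_for_layer(
    disable_opt_rtn: bool,       # value resolved in __init__
    user_set_flag: bool | None,  # original user-provided value (None = auto)
    is_moe_model: bool,
    layer_name: str,             # e.g. a hypothetical "model.layers.0.mlp.experts.3.down_proj"
    super_bits: int | None,      # set for GGUF-style double quantization
) -> bool:
    if (
        not disable_opt_rtn
        and user_set_flag is None
        and is_moe_model
        and "expert" in layer_name
        and "shared_expert" not in layer_name
        and super_bits is None  # GGUF keeps the optimized RTN even for MoE layers
    ):
        return True
    return disable_opt_rtn
```
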
2 changes: 1 addition & 1 deletion auto_round/compressors/config.py
@@ -247,7 +247,7 @@ def is_default(self):
@dataclass
class TuningExtraConfig(BaseExtraConfig):
amp: bool = True
disable_opt_rtn: bool | None = True
disable_opt_rtn: bool | None = None
enable_alg_ext: bool = False
enable_minmax_tuning: bool = True
enable_norm_bias_tuning: bool = False
11 changes: 11 additions & 0 deletions auto_round/utils/model.py
@@ -1191,6 +1191,17 @@ def mv_module_from_gpu(module):
return module.to("cpu")


def is_moe_model(model: torch.nn.Module) -> bool:
if hasattr(model, "config"):
for key in model.config.to_dict().keys():
if "moe" in key or "expert" in key:
return True
for n, m in model.named_modules():
if "expert" in n:
return True
return False


def to_dtype(input, dtype=torch.float32):
"""Moves input data to the specified data type.

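
Note (not part of the diff): a rough usage sketch for the new `is_moe_model` heuristic, assuming it is importable from its module path as shown above. `TinyMoE` is a made-up model with no `config` attribute, so detection falls through to the module-name check; real Hugging Face MoE models are typically caught earlier via config keys such as `num_experts` or `moe_intermediate_size`.

```python
import torch

from auto_round.utils.model import is_moe_model


class TinyMoE(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # The module path "experts.<i>" contains "expert", so the name-based fallback matches.
        self.experts = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(2)])


print(is_moe_model(TinyMoE()))              # True  (module name contains "expert")
print(is_moe_model(torch.nn.Linear(8, 8)))  # False (no config, no expert submodules)
```
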