diff --git a/README.md b/README.md index 335a8f33f..b2309a975 100644 --- a/README.md +++ b/README.md @@ -195,7 +195,8 @@ ar.quantize_and_save(output_dir="./qmodel", format="auto_round") ##### Algorithm Settings - **`enable_alg_ext` (bool)**: [Experimental Feature] Only for `iters>0`. Enable algorithm variants for specific schemes (e.g., MXFP4/W2A16) that could bring notable improvements. Default is `False`. -- **`disable_opt_rtn` (bool)**: Use pure RTN mode for specific schemes (e.g., GGUF and WOQ). Default is `None` (improved RTN enabled). + +- **`disable_opt_rtn` (bool|None)**: Use pure RTN mode for specific schemes (e.g., GGUF and WOQ). Default is `None`. When left as `None`, it resolves to `False` in most cases to improve accuracy, but is forced to `True` for certain configurations (e.g., INT8 RTN) that have known issues. ##### Tuning Process Parameters - **`iters` (int)**: Number of tuning iterations (default is `200`). Common values: 0 (RTN mode), 50 (with lr=5e-3 recommended), 1000. Higher values increase accuracy but slow down tuning. @@ -355,3 +356,4 @@ Special thanks to open-source low precision libraries such as AutoGPTQ, AutoAWQ, ## 🌟 Support Us If you find AutoRound helpful, please ⭐ star the repo and share it with your community! + diff --git a/auto_round/__main__.py b/auto_round/__main__.py index 6d95b4864..0a17a05d0 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -253,14 +253,23 @@ def __init__(self, *args, **kwargs): action="store_true", help="Enable PyTorch deterministic algorithms for reproducible results. ", ) - tuning.add_argument( + group = tuning.add_mutually_exclusive_group() + group.add_argument( "--disable_opt_rtn", - "--disable-opt-rtn", - action=argparse.BooleanOptionalAction, + action="store_const", + const=True, + dest="disable_opt_rtn", default=None, help="Disable optimization for RTN (Round-To-Nearest) mode when iters=0. 
" "RTN is fast but less accurate; keeping optimization enabled is recommended.", ) + group.add_argument( + "--enable_opt_rtn", + action="store_const", + const=False, + dest="disable_opt_rtn", + help="Enable optimization for RTN mode when iters=0.", + ) scheme = self.add_argument_group("Scheme Arguments") scheme.add_argument("--bits", default=None, type=int, help="Number of bits for weight quantization. ") diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 51c850c16..6d69c01c3 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -85,7 +85,7 @@ def __new__( enable_adam: bool = False, extra_config: ExtraConfig = None, enable_alg_ext: bool = None, - disable_opt_rtn: Optional[bool] = None, + disable_opt_rtn: bool | None = None, low_cpu_mem_usage: bool = False, **kwargs, ) -> BaseCompressor: diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index cd9fc516c..a2af9e011 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -189,7 +189,7 @@ def __init__( device_map: Union[str, torch.device, int, dict] = 0, enable_torch_compile: bool = False, enable_alg_ext: bool = False, - disable_opt_rtn: Optional[bool] = None, + disable_opt_rtn: bool | None = None, seed: int = 42, low_cpu_mem_usage: bool = False, **kwargs, @@ -397,9 +397,11 @@ def __init__( and self.data_type == "int" and disable_opt_rtn is None ): - logger.warning("for INT8 RTN quantization, set `--disable_opt_rtn` as default.") + logger.warning("For INT8 RTN quantization, set `--disable_opt_rtn` as default.") disable_opt_rtn = True if disable_opt_rtn is None: + if self.iters == 0: + logger.info("For the most RTN cases, set `--disable_opt_rtn` to False as default.") disable_otp_rtn = False self.minmax_lr = minmax_lr or self.lr diff --git a/auto_round/compressors/config.py b/auto_round/compressors/config.py index 4bab246c3..f7e0d92f2 100644 --- a/auto_round/compressors/config.py +++ b/auto_round/compressors/config.py @@ -32,7 +32,7 
@@ def __init__( self, # tuning amp: bool = True, - disable_opt_rtn: Optional[bool] = True, + disable_opt_rtn: bool | None = None, enable_alg_ext: bool = False, enable_minmax_tuning: bool = True, enable_norm_bias_tuning: bool = False, @@ -247,7 +247,7 @@ def is_default(self): @dataclass class TuningExtraConfig(BaseExtraConfig): amp: bool = True - disable_opt_rtn: Optional[bool] = True + disable_opt_rtn: bool | None = None enable_alg_ext: bool = False enable_minmax_tuning: bool = True enable_norm_bias_tuning: bool = False