diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index 6998b76e12..cd2ff5bc29 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -81,7 +81,11 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
 
         # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
         if param.__class__.__name__ == "Params4bit":
-            num_bytes = param.quant_storage.itemsize if hasattr(param, "quant_storage") else 1
+            if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
+                num_bytes = param.quant_storage.itemsize
+            else:
+                num_bytes = 1
+
             num_params = num_params * 2 * num_bytes
 
         all_param += num_params
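
For reviewers reading the hunk out of context, here is a minimal, self-contained sketch of how `count_parameters` behaves with the patched branch. The `Params4bit` handling mirrors the diff exactly; the surrounding loop and the (trainable, total) split are reconstructed from the `Tuple[int, int]` signature in the hunk header rather than copied from the file, so treat it as an illustration of the fallback logic, not the exact upstream code.

```python
from typing import Tuple

import torch


def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    """Sketch: return (trainable_params, all_params), expanding 4-bit packed weights.

    The loop structure here is an assumption based on the hunk; only the
    Params4bit branch is taken verbatim from the diff.
    """
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()

        # bitsandbytes packs two 4-bit values per stored element, so the real
        # parameter count is numel * 2 * bytes-per-stored-element.
        if param.__class__.__name__ == "Params4bit":
            if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
                num_bytes = param.quant_storage.itemsize
            else:
                num_bytes = 1  # fall back when quant_storage metadata is unavailable

            num_params = num_params * 2 * num_bytes

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    return trainable_params, all_param


if __name__ == "__main__":
    # Plain fp32 module: no Params4bit branch is taken, both counts equal numel.
    model = torch.nn.Linear(4, 4)
    print(count_parameters(model))  # (20, 20)
```

The extra `hasattr(param.quant_storage, "itemsize")` check matters because older bitsandbytes/torch combinations expose `quant_storage` without an `itemsize` attribute; guarding both lookups avoids an AttributeError and falls back to a 1-byte storage assumption instead.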