14 changes: 10 additions & 4 deletions auto_round/compressors/base.py
@@ -26,7 +26,7 @@
 import accelerate
 import torch
 from accelerate.big_modeling import dispatch_model, infer_auto_device_map
-from accelerate.utils import get_max_memory
+from accelerate.utils import get_balanced_memory, get_max_memory
 from torch import autocast
 from tqdm import tqdm
 from transformers import set_seed
@@ -1770,7 +1770,7 @@ def calib(self, nsamples, bs):
                     data_new[key] = data[key].to(self.model.device)
                 input_ids = data_new["input_ids"]
             elif isinstance(data, tuple) or isinstance(data, list):
-                data_new = to_device(data)
+                data_new = to_device(data, self.model.device)
                 input_ids = data_new[0]
             else:
                 data_new = {}
@@ -1904,6 +1904,7 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
         if str(self.model.device) == "cpu" and (not self.device.startswith("hpu")):
             no_split_modules = getattr(self.model, "_no_split_modules", [])
             devices = parse_available_devices(self.device_map)
+
             max_memory = get_max_memory()
             new_max_memory = {}
             if "cpu" not in devices:
@@ -1915,13 +1916,18 @@
                     device = "cpu"
                 else:
                     raise ValueError(f"Unsupported device {device} in device_map: {self.device_map}")
-                new_max_memory[device] = max_memory[device]
+                new_max_memory[device] = max_memory[device] * 0.9
+            new_max_memory = get_balanced_memory(
+                self.model,
+                max_memory=new_max_memory,
+                no_split_module_classes=no_split_modules,
+            )
             device_map = infer_auto_device_map(
                 self.model, max_memory=new_max_memory, no_split_module_classes=no_split_modules
             )
             if len(devices) > 1 and "cpu" in device_map.values():
                 logger.warning(
-                    "Not enough vram cause the ram to be used, which may severely impact speed."
+                    "Some layers are offloaded to cpu, which may severely impact calibration speed."
                     " Please consider using more cards."
                 )

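Reviewer note (not part of the PR): the hunk above relies on accelerate's get_balanced_memory to spread blocks evenly across GPUs before infer_auto_device_map assigns modules, instead of filling device 0 and spilling to CPU. A minimal, hedged sketch of how these two calls fit together; the checkpoint name and the 0.9 memory cap below are illustrative assumptions, not this repository's code.

# Hedged sketch: balanced device budgets feeding infer_auto_device_map.
import torch
from accelerate import infer_auto_device_map
from accelerate.utils import get_balanced_memory, get_max_memory
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype=torch.float16)
no_split = getattr(model, "_no_split_modules", []) or []

# Leave ~10% headroom on every visible device (GPUs by index, plus "cpu").
max_memory = {dev: int(limit * 0.9) for dev, limit in get_max_memory().items()}

# Rebalance the per-device budgets so blocks are spread across GPUs
# rather than packed onto device 0 until it is full.
balanced = get_balanced_memory(model, max_memory=max_memory, no_split_module_classes=no_split)

device_map = infer_auto_device_map(model, max_memory=balanced, no_split_module_classes=no_split)
print(device_map)  # module name -> device, e.g. {"model.decoder.layers.0": 0, ..., "lm_head": 1}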
6 changes: 3 additions & 3 deletions auto_round/compressors/mllm/compressor.py
@@ -345,7 +345,7 @@ def calib(self, nsamples, bs):
                 pbar.update(1)
                 continue
             if isinstance(data, torch.Tensor):
-                input_ids = data.to(self.device)
+                input_ids = data.to(self.model.device)
                 data_new = input_ids
             elif isinstance(data, str):
                 if self.tokenizer is None:
@@ -360,7 +360,7 @@
                 )
                 data_new = {}
                 for key in data.keys():
-                    data_new[key] = data[key].to(self.device)
+                    data_new[key] = data[key].to(self.model.device)
                 input_ids = data_new["input_ids"]
             elif isinstance(data, dict) and "text" in data.keys():
                 text = data["text"]
@@ -381,7 +381,7 @@
                     data_new[key] = to_dtype(data_new[key], self.model.dtype)
                 input_ids = data_new["input_ids"]
             elif isinstance(data, tuple) or isinstance(data, list):
-                data_new = data
+                data_new = to_device(data, self.model.device)
                 input_ids = data_new[0]
             else:
                 data_new = {}
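Reviewer note (not part of the PR): to_device is auto_round's own utility, and the change above simply passes it an explicit target device. A plausible sketch of what such a recursive move-to-device helper typically looks like is below; the actual implementation in auto_round may differ.

# Hedged sketch of a recursive to_device helper; auto_round's real utility may differ.
import torch

def to_device(data, device="cpu"):
    """Move tensors nested inside dicts/lists/tuples to the target device."""
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {key: to_device(value, device) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(to_device(item, device) for item in data)
    return data  # non-tensor leaves (ints, strings, None) are returned unchanged

# Usage matching the calib() call sites above:
# data_new = to_device(data, self.model.device)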
4 changes: 2 additions & 2 deletions auto_round/modelling/qwen3_vl_moe.py
@@ -67,8 +67,8 @@ def __init__(
         self.gate = original.gate
         self.calibrate_all_experts = calibrate_all_experts
         self.experts = SequentialQwen3VLMoeTextExperts(text_config, original.experts)
-        if not transformers_version <= version.parse(
-            "4.57.3"
+        if not transformers_version < version.parse(
+            "5.0"
         ):  # remove conversion_mapping for qwen3_vl_moe when transformers>=5.0
             from transformers.conversion_mapping import register_checkpoint_conversion_mapping

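Reviewer note (not part of the PR): with the change above, the transformers>=5.0 branch is taken only when the installed version really is 5.0 or newer, whereas the old check also fired for any 4.58.x release. A small packaging.version illustration; the version strings are examples only.

# Hedged illustration of the version gate above.
from packaging import version

for installed in ("4.57.3", "4.58.0", "5.0.0"):
    parsed = version.parse(installed)
    old_gate = not parsed <= version.parse("4.57.3")  # taken for 4.58.0 and 5.0.0
    new_gate = not parsed < version.parse("5.0")      # taken only for 5.0.0
    print(installed, old_gate, new_gate)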