15 | 15 | import logging |
16 | 16 | from collections import OrderedDict |
17 | 17 | from copy import deepcopy |
18 | | -from typing import Dict, Iterable, List, Optional |
| 18 | +from typing import Iterable, List |
19 | 19 | from typing import OrderedDict as OrderedDictType |
20 | 20 | from typing import Union |
21 | 21 | |
40 | 40 | ) |
41 | 41 | from compressed_tensors.utils.helpers import deprecated, replace_module |
42 | 42 | from compressed_tensors.utils.match import match_named_modules, match_targets |
43 | | -from compressed_tensors.utils.offload import update_parameter_data |
44 | | -from safetensors import safe_open |
45 | 43 | from torch.nn import Module |
46 | 44 | |
47 | 45 | |
@@ -196,58 +194,6 @@ def find_name_or_class_matches( |
196 | 194 | return match_targets(name, module, targets) |
197 | 195 | |
198 | 196 | |
199 | | -def _infer_status(model: Module) -> Optional[QuantizationStatus]: |
200 | | - for module in model.modules(): |
201 | | - status = getattr(module, "quantization_status", None) |
202 | | - if status is not None: |
203 | | - return status |
204 | | - return None |
205 | | - |
206 | | - |
207 | | -def _load_quant_args_from_mapping( |
208 | | - base_name: str, module_name: str, module: Module, mapping: Dict |
209 | | -): |
210 | | - # TODO: skip update and just register here, don't do it in initialize |
211 | | - """ |
212 | | - Loads scale and zero point from a state_dict into the specified module |
213 | | - |
214 | | - :param base_name: quantization target, one of: weights, input_activations or |
215 | | - output_activations |
216 | | - :param module_name: pytorch module name to look up in state_dict |
217 | | - :module: pytorch module associated with module_name |
218 | | - :mapping: mapping to search fetch paths on disk for a given parameter |
219 | | - """ |
220 | | - scale_name = f"{base_name}_scale" |
221 | | - zp_name = f"{base_name}_zero_point" |
222 | | - g_idx_name = f"{base_name}_g_idx" |
223 | | - |
224 | | - state_dict_scale_path = mapping.get(f"{module_name}.{scale_name}", None) |
225 | | - state_dict_zp_path = mapping.get(f"{module_name}.{zp_name}", None) |
226 | | - state_dict_g_idx_path = mapping.get(f"{module_name}.{g_idx_name}", None) |
227 | | - |
228 | | - if state_dict_g_idx_path is not None: |
229 | | - with safe_open(state_dict_g_idx_path, framework="pt", device="cpu") as f: |
230 | | - state_dict_g_idx = f.get_tensor(f"{module_name}.{g_idx_name}") |
231 | | - |
232 | | - update_parameter_data(module, state_dict_g_idx, g_idx_name) |
233 | | - |
234 | | - if state_dict_scale_path is not None: |
235 | | - # module is quantized |
236 | | - with safe_open(state_dict_scale_path, framework="pt", device="cpu") as f: |
237 | | - state_dict_scale = f.get_tensor(f"{module_name}.{scale_name}") |
238 | | - |
239 | | - update_parameter_data(module, state_dict_scale, scale_name) |
240 | | - |
241 | | - if state_dict_zp_path is None: |
242 | | - # fill in zero point for symmetric quantization |
243 | | - state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu") |
244 | | - else: |
245 | | - with safe_open(state_dict_zp_path, framework="pt", device="cpu") as f: |
246 | | - state_dict_zp = f.get_tensor(f"{module_name}.{zp_name}") |
247 | | - |
248 | | - update_parameter_data(module, state_dict_zp, zp_name) |
249 | | - |
250 | | - |
251 | 197 | def _scheme_from_targets( |
252 | 198 | target_to_scheme: OrderedDictType[str, QuantizationScheme], |
253 | 199 | targets: List[str], |
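For reference, the removed `_load_quant_args_from_mapping` helper implemented the pattern below: resolve each quantization parameter (`*_scale`, `*_zero_point`, `*_g_idx`) for a module to the safetensors shard that stores it, load the tensor on CPU, and write it into the module, synthesizing a zero tensor for the zero point when the checkpoint is symmetric and stores none. This is a minimal sketch reconstructed from the deleted hunk, not part of the library after this change; the standalone name `load_quant_args_from_mapping` is illustrative, and it assumes `update_parameter_data(module, tensor, param_name)` from `compressed_tensors.utils.offload` keeps that signature.

```python
from typing import Dict

import torch
from safetensors import safe_open
from torch.nn import Module

from compressed_tensors.utils.offload import update_parameter_data


def load_quant_args_from_mapping(
    base_name: str, module_name: str, module: Module, mapping: Dict[str, str]
) -> None:
    """
    Load scale, zero point, and g_idx for one module from safetensors shards.

    :param base_name: quantization target, one of: weights, input_activations,
        or output_activations
    :param module_name: pytorch module name used as the key prefix on disk
    :param module: pytorch module whose parameters are updated in place
    :param mapping: maps fully qualified parameter names to safetensors files
    """
    scale_name = f"{base_name}_scale"
    zp_name = f"{base_name}_zero_point"
    g_idx_name = f"{base_name}_g_idx"

    scale_path = mapping.get(f"{module_name}.{scale_name}")
    zp_path = mapping.get(f"{module_name}.{zp_name}")
    g_idx_path = mapping.get(f"{module_name}.{g_idx_name}")

    # g_idx, when present (group/activation-ordering schemes), is applied first
    if g_idx_path is not None:
        with safe_open(g_idx_path, framework="pt", device="cpu") as f:
            g_idx = f.get_tensor(f"{module_name}.{g_idx_name}")
        update_parameter_data(module, g_idx, g_idx_name)

    if scale_path is None:
        # no scale on disk means this target of the module is not quantized
        return

    with safe_open(scale_path, framework="pt", device="cpu") as f:
        scale = f.get_tensor(f"{module_name}.{scale_name}")
    update_parameter_data(module, scale, scale_name)

    if zp_path is None:
        # symmetric quantization stores no zero point; fill in zeros
        zero_point = torch.zeros_like(scale, device="cpu")
    else:
        with safe_open(zp_path, framework="pt", device="cpu") as f:
            zero_point = f.get_tensor(f"{module_name}.{zp_name}")
    update_parameter_data(module, zero_point, zp_name)
```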