forked from casper-hansen/AutoAWQ
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathquantizer_patch.py
More file actions
152 lines (125 loc) · 6.01 KB
/
quantizer_patch.py
File metadata and controls
152 lines (125 loc) · 6.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
Patch for AutoAWQ quantizer to support DeepSeek-V3/R1 models on ROCm
This addresses the rotary_emb attribute error
"""
import torch
import transformers
from awq.quantize.quantizer import AwqQuantizer
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
class PatchedAwqQuantizer(AwqQuantizer):
    """AwqQuantizer with DeepSeek-V3/R1-aware position-embedding handling.

    DeepSeek models compute rotary embeddings inside their decoder layers,
    so the stock quantizer's call to ``model.model.rotary_emb`` fails with
    an attribute error. This override skips that computation for DeepSeek
    models and otherwise mirrors the upstream ``quantize`` loop.
    """

    @staticmethod
    def _transformers_at_least(minimum):
        """Return True if the installed transformers version >= ``minimum``.

        ``minimum`` is a tuple of ints, e.g. ``(4, 48, 0)``. Components are
        compared numerically; the original code compared version *strings*,
        which is wrong (lexicographically ``"4.9.0" >= "4.48.0"`` is True
        and ``"4.100.0" >= "4.48.0"`` is False).
        """
        import re
        parts = re.findall(r"\d+", transformers.__version__)
        current = tuple(int(p) for p in parts[:len(minimum)])
        return current >= tuple(minimum)

    def quantize(self):
        """Quantize each captured module, handling DeepSeek rotary embeddings.

        Mirrors AutoAWQ's ``AwqQuantizer.quantize`` but skips the
        ``rotary_emb`` position-embedding computation for DeepSeek models,
        which handle rotary embeddings inside their attention layers.
        """
        from tqdm import tqdm
        from awq.utils.utils import clear_memory, get_best_device
        from awq.utils.module import (
            append_str_prefix,
            get_op_name,
            get_named_linears,
            exclude_layers_to_not_quantize,
        )
        from awq.quantize.scale import apply_scale, apply_clip

        # transformers >= 4.48.0 expects `position_embeddings` to be passed
        # into decoder layers. The installed version is invariant across the
        # loop, so compute the check once (numerically, not as strings).
        new_transformers = self._transformers_at_least((4, 48, 0))

        for i in tqdm(range(len(self.modules)), desc="AWQ"):
            # Move the module onto a real device if it is still on CPU;
            # round-robin layers across the visible CUDA devices.
            common_device = next(self.modules[i].parameters()).device
            if common_device is None or str(common_device) == "cpu":
                if torch.cuda.is_available():
                    best_device = "cuda:" + str(i % torch.cuda.device_count())
                else:
                    best_device = get_best_device()
                self.modules[i] = self.modules[i].to(best_device)
                common_device = next(self.modules[i].parameters()).device

            # Keep the cached forward kwargs on the same device as the module.
            if self.module_kwargs.get("position_ids") is not None:
                self.module_kwargs["position_ids"] = self.module_kwargs[
                    "position_ids"
                ].to(common_device)

            if self.module_kwargs.get("attention_mask") is not None:
                self.module_kwargs["attention_mask"] = self.module_kwargs[
                    "attention_mask"
                ].to(common_device)

            self.inps = self.inps.to(common_device)

            # Embedding layers must live on the same device as the inputs.
            self.awq_model.move_embed(self.model, common_device)

            # Handle position embeddings for transformers >= 4.48.0.
            if new_transformers and self.module_kwargs.get("position_embeddings") is None:
                model_type = getattr(self.awq_model, 'model_type', '')
                if 'deepseek' in model_type.lower():
                    # DeepSeek models compute rotary embeddings internally and
                    # model.model has no `rotary_emb`; skip the computation.
                    pass
                else:
                    # For other models, compute position embeddings up front.
                    if hasattr(self.model.model, 'rotary_emb'):
                        self.module_kwargs["position_embeddings"] = self.model.model.rotary_emb(
                            self.inps, self.module_kwargs["position_ids"]
                        )

            if new_transformers and self.module_kwargs.get('attention_mask') is None:
                self.module_kwargs['attention_mask'] = None

            # Tuples (e.g. precomputed position embeddings) may hold tensors
            # that also need to be moved to the module's device.
            for k, v in self.module_kwargs.items():
                if isinstance(v, tuple):
                    self.module_kwargs[k] = tuple(
                        item.to(common_device) if isinstance(item, (torch.Tensor, torch.nn.Module))
                        else item for item in v
                    )

            # [STEP 1]: extract linear modules and capture their input features.
            named_linears = get_named_linears(self.modules[i])
            named_linears = exclude_layers_to_not_quantize(
                named_linears, self.modules_to_not_convert
            )
            input_feat = self._get_input_feat(self.modules[i], named_linears)
            clear_memory()

            # [STEP 2]: search and apply the per-channel scale list.
            module_config = self.awq_model.get_layers_for_scaling(
                self.modules[i], input_feat, self.module_kwargs
            )
            scales_list = [
                self._search_best_scale(self.modules[i], **layer)
                for layer in module_config
            ]
            apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
            scales_list = append_str_prefix(
                scales_list, get_op_name(self.model, self.modules[i]) + "."
            )

            # [STEP 3]: optionally search and apply weight clipping.
            if self.apply_clip:
                clip_list = self._search_best_clip(
                    self.modules[i], named_linears, input_feat
                )
                apply_clip(self.modules[i], clip_list)
                clip_list = append_str_prefix(
                    clip_list, get_op_name(self.model, self.modules[i]) + "."
                )

            # [STEP 4]: quantize weights in place (skipped when exporting an
            # export-compatible model that only carries scales/clips).
            if not self.export_compatible:
                self._apply_quant(self.modules[i], named_linears)

            clear_memory()
# Monkey patch the original quantizer
def patch_autoawq_for_deepseek():
    """Apply the DeepSeek compatibility patch to AutoAWQ.

    Rebinds ``AwqQuantizer`` to :class:`PatchedAwqQuantizer` on its defining
    module *and* on ``awq.models.base``. The latter is required because
    ``awq.models.base`` imports the class directly
    (``from awq.quantize.quantizer import AwqQuantizer``) at import time, so
    patching only the defining module leaves that already-bound reference —
    the one ``model.quantize()`` actually uses — pointing at the unpatched
    class.
    """
    import awq.quantize.quantizer
    awq.quantize.quantizer.AwqQuantizer = PatchedAwqQuantizer
    try:
        import awq.models.base
        if getattr(awq.models.base, "AwqQuantizer", None) is not None:
            awq.models.base.AwqQuantizer = PatchedAwqQuantizer
    except ImportError:
        # Some AutoAWQ layouts may not expose awq.models.base; the
        # module-level patch above is still applied (best effort).
        pass
    print("Applied DeepSeek compatibility patch to AutoAWQ")
if __name__ == "__main__":
    # Install the DeepSeek-compatible quantizer before any quantization runs.
    patch_autoawq_for_deepseek()

    # Source checkpoint and destination for the quantized artifacts.
    model_path = '/home/hotaisle/workspace/models/DeepSeek-R1-0528-bf16'
    quant_path = '/home/hotaisle/workspace/models/DeepSeek-R1-0528-awq'

    # 4-bit GEMM AWQ with zero-points and a group size of 64.
    quant_config = {
        "zero_point": True,
        "q_group_size": 64,
        "w_bit": 4,
        "version": "GEMM",
    }

    print(f"Loading model from {model_path}...")
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    print("Starting quantization...")
    model.quantize(tokenizer, quant_config=quant_config)

    print("Saving quantized model...")
    model.save_quantized(quant_path)
    tokenizer.save_pretrained(quant_path)
    print(f'Model is quantized and saved at "{quant_path}"')