
Commit fe741f4

refactor
1 parent f7f8298 commit fe741f4

File tree

1 file changed: +13 -9


generate.py (+13 -9)
@@ -216,6 +216,14 @@ def encode_tokens(tokenizer, string, bos=True, device=default_device):
         tokens = [tokenizer.bos_id()] + tokens
     return torch.tensor(tokens, dtype=torch.int, device=device)
 
+def _convert_weight(model):
+    from quantize import WeightOnlyInt4Linear
+    for fqn, mod in model.named_modules():
+        if isinstance(mod, WeightOnlyInt4Linear):
+            weight = mod.weight.data
+            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight, mod.inner_k_tiles)
+            mod.weight = weight_int4pack
+
 def _load_model(checkpoint_path, device, precision, use_tp):
     use_cuda = 'cuda' in device
     with torch.device('meta'):
@@ -240,19 +248,15 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         checkpoint = checkpoint["model"]
     model.load_state_dict(checkpoint, assign=True)
 
+    model = model.to(device=device, dtype=precision)
+    # int4 packed weight needs to be converted after model loading to the specific device
+    if "int4" in str(checkpoint_path):
+        _convert_weight(model)
+
     if use_tp:
         from tp import apply_tp
         print("Applying tensor parallel to model ...")
         apply_tp(model)
-
-    model = model.to(device=device, dtype=precision)
-    if "int4" in str(checkpoint_path):
-        from quantize import WeightOnlyInt4Linear
-        for fqn, mod in model.named_modules():
-            if isinstance(mod, WeightOnlyInt4Linear):
-                weight = mod.weight.data
-                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight, mod.inner_k_tiles)
-                mod.weight = weight_int4pack
     return model.eval()
 
 def _get_model_size(model):
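
Why the reorder matters: as the commit's own comment notes, the packed layout produced by torch.ops.aten._convert_weight_to_int4pack is specific to the device the weight lives on, so the new code moves the model to its target device and dtype first and only then packs the int4 weights; the packing also now happens before apply_tp instead of after it. Below is a minimal sketch of the resulting load order; the PackedLinear class and pack function are stand-ins for WeightOnlyInt4Linear and the private aten op, whose input requirements are device- and version-specific.

import torch
import torch.nn as nn

class PackedLinear(nn.Linear):
    """Stand-in for WeightOnlyInt4Linear: a linear layer whose weight
    must be re-packed once it sits on its final device."""

def pack(weight: torch.Tensor) -> torch.Tensor:
    # Stand-in for torch.ops.aten._convert_weight_to_int4pack(weight, inner_k_tiles);
    # the real op returns a device-specific packed int4 layout.
    return weight.contiguous()

def _convert_weight(model: nn.Module) -> None:
    # Mirrors the helper added in this commit: walk every submodule and
    # re-pack the weight of each quantized linear layer in place.
    for fqn, mod in model.named_modules():
        if isinstance(mod, PackedLinear):
            # The real WeightOnlyInt4Linear stores `weight` as a buffer, so the
            # commit assigns a plain tensor; nn.Linear needs a Parameter here.
            mod.weight = nn.Parameter(pack(mod.weight.data), requires_grad=False)

model = nn.Sequential(PackedLinear(8, 8), nn.ReLU(), PackedLinear(8, 4))
model = model.to(device="cpu", dtype=torch.float32)  # move to the target device first ...
_convert_weight(model)                               # ... then pack on that device
print(model(torch.randn(2, 8)).shape)  # torch.Size([2, 4])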
