@@ -216,6 +216,14 @@ def encode_tokens(tokenizer, string, bos=True, device=default_device):
        tokens = [tokenizer.bos_id()] + tokens
    return torch.tensor(tokens, dtype=torch.int, device=device)

+def _convert_weight(model):
+    from quantize import WeightOnlyInt4Linear
+    for fqn, mod in model.named_modules():
+        if isinstance(mod, WeightOnlyInt4Linear):
+            weight = mod.weight.data
+            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight, mod.inner_k_tiles)
+            mod.weight = weight_int4pack
+
 def _load_model(checkpoint_path, device, precision, use_tp):
     use_cuda = 'cuda' in device
     with torch.device('meta'):
@@ -240,19 +248,15 @@ def _load_model(checkpoint_path, device, precision, use_tp):
        checkpoint = checkpoint["model"]
    model.load_state_dict(checkpoint, assign=True)

+    model = model.to(device=device, dtype=precision)
+    # the int4 packed weight must be converted after the model has been loaded onto the target device
+    if "int4" in str(checkpoint_path):
+        _convert_weight(model)
+
     if use_tp:
         from tp import apply_tp
         print("Applying tensor parallel to model ...")
         apply_tp(model)
-
-    model = model.to(device=device, dtype=precision)
-    if "int4" in str(checkpoint_path):
-        from quantize import WeightOnlyInt4Linear
-        for fqn, mod in model.named_modules():
-            if isinstance(mod, WeightOnlyInt4Linear):
-                weight = mod.weight.data
-                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight, mod.inner_k_tiles)
-                mod.weight = weight_int4pack

     return model.eval()

 def _get_model_size(model):
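For readers skimming the patch: `_convert_weight` uses the standard pattern of walking `model.named_modules()`, matching a module type, and rewriting that module's weight in place, and the reordering matters because `_convert_weight_to_int4pack` dispatches to a device-specific packing kernel, so it has to run after `.to(device)`. Below is a minimal self-contained sketch of that traversal pattern; `_repack_weights`, its `transform` argument, and the toy model are illustrative stand-ins (a dtype cast in place of the private int4 op), not part of the patch.

import torch
import torch.nn as nn

def _repack_weights(model: nn.Module, transform) -> None:
    # Walk every submodule; fqn is the dotted module path, mod the module itself
    # (fqn is unused here, mirroring the patch's own loop).
    for fqn, mod in model.named_modules():
        if isinstance(mod, nn.Linear):
            # nn.Linear stores weight as a Parameter, so wrap the result;
            # WeightOnlyInt4Linear registers its weight as a buffer, which is
            # why the patch can assign a plain tensor directly.
            mod.weight = nn.Parameter(transform(mod.weight.data), requires_grad=False)

model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
model = model.to(device="cpu", dtype=torch.float32)  # move first, then repack
_repack_weights(model, lambda w: w.half())  # stand-in for _convert_weight_to_int4pack
print(model[0].weight.dtype)  # torch.float16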