@@ -216,6 +216,14 @@ def encode_tokens(tokenizer, string, bos=True, device=default_device):
216 216     tokens = [tokenizer.bos_id()] + tokens
217 217     return torch.tensor(tokens, dtype=torch.int, device=device)
218 218
    219+ def _convert_weight(model):
    220+     from quantize import WeightOnlyInt4Linear
    221+     for fqn, mod in model.named_modules():
    222+         if isinstance(mod, WeightOnlyInt4Linear):
    223+             weight = mod.weight.data
    224+             weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight, mod.inner_k_tiles)
    225+             mod.weight = weight_int4pack
    226+
219 227 def _load_model(checkpoint_path, device, precision, use_tp):
220 228     use_cuda = 'cuda' in device
221 229     with torch.device('meta'):
@@ -240,19 +248,15 @@ def _load_model(checkpoint_path, device, precision, use_tp):
240 248         checkpoint = checkpoint["model"]
241 249     model.load_state_dict(checkpoint, assign=True)
242 250
    251+     model = model.to(device=device, dtype=precision)
    252+     # int4 packed weight needs to be converted after model loading to the specific device
    253+     if "int4" in str(checkpoint_path):
    254+         _convert_weight(model)
    255+
243 256     if use_tp:
244 257         from tp import apply_tp
245 258         print("Applying tensor parallel to model ...")
246 259         apply_tp(model)
247     -
248     -     model = model.to(device=device, dtype=precision)
249     -     if "int4" in str(checkpoint_path):
250     -         from quantize import WeightOnlyInt4Linear
251     -         for fqn, mod in model.named_modules():
252     -             if isinstance(mod, WeightOnlyInt4Linear):
253     -                 weight = mod.weight.data
254     -                 weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight, mod.inner_k_tiles)
255     -                 mod.weight = weight_int4pack
256 260     return model.eval()
257 261
258 262 def _get_model_size(model):
0 commit comments