@@ -11,7 +11,7 @@
 import argparse
 import conversation as convo
 import retrieval.wikipedia as wp
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, StoppingCriteria, StoppingCriteriaList
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, StoppingCriteria, StoppingCriteriaList, BitsAndBytesConfig
 from accelerate import infer_auto_device_map, init_empty_weights
 
 
@@ -48,16 +48,19 @@ class ChatModel:
     human_id = "<human>"
     bot_id = "<bot>"
 
-    def __init__(self, model_name, gpu_id, max_memory):
+    def __init__(self, model_name, gpu_id, max_memory, load_in_8bit):
         device = torch.device('cuda', gpu_id)  # TODO: allow sending to cpu
 
+        quantization_config = BitsAndBytesConfig(
+            load_in_8bit=load_in_8bit,
+            llm_int8_enable_fp32_cpu_offload=True,
+        )  # config to load in 8-bit if load_in_8bit
+
         # recommended default for devices with > 40 GB VRAM
         # load model onto one device
         if max_memory is None:
-            self._model = AutoModelForCausalLM.from_pretrained(
-                model_name, torch_dtype=torch.float16, device_map="auto")
-            self._model.to(device)
-        # load the model with the given max_memory config (for devices with insufficient VRAM or multi-gpu)
+            device_map = "auto"
+
         else:
             config = AutoConfig.from_pretrained(model_name)
             # load empty weights
@@ -66,21 +69,24 @@ def __init__(self, model_name, gpu_id, max_memory):
 
             model_from_conf.tie_weights()
 
-            # create a device_map from max_memory
+            #create a device_map from max_memory
             device_map = infer_auto_device_map(
                 model_from_conf,
                 max_memory=max_memory,
                 no_split_module_classes=["GPTNeoXLayer"],
-                dtype="float16"
-            )
-            # load the model with the above device_map
-            self._model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                device_map=device_map,
-                offload_folder="offload",  # optional offload-to-disk overflow directory (auto-created)
-                offload_state_dict=True,
-                torch_dtype=torch.float16
+                dtype="float16",
             )
+
+        self._model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,
+            device_map=device_map,
+            offload_folder="offload",
+            quantization_config=quantization_config,
+        )
+        if not load_in_8bit:
+            self._model.to(device)  # not supported by load_in_8bit
+
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)
 
     def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k, stream_callback=None):
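
For reference, the 8-bit path introduced above boils down to roughly the following standalone snippet. This is a minimal sketch, assuming bitsandbytes is installed and a CUDA device is available; the checkpoint name is only an illustrative example, not something the patch prescribes.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "togethercomputer/Pythia-Chat-Base-7B"  # illustrative checkpoint only

# 8-bit weights via bitsandbytes; fp32 CPU offload keeps modules that do not
# fit on the GPU on the CPU in fp32 instead of failing.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,             # non-quantized tensors stay in fp16
    device_map="auto",                     # accelerate decides layer placement
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# No model.to(device) here: moving an 8-bit quantized model is not supported,
# which is why the patch guards the .to(device) call with `if not load_in_8bit`.

Because device_map already places the quantized model, the explicit move to the target GPU only happens in the non-8-bit case.
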
@@ -110,7 +116,7 @@ class OpenChatKitShell(cmd.Cmd):
     intro = "Welcome to OpenChatKit shell. Type /help or /? to list commands.\n"
     prompt = ">>> "
 
-    def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature, top_k, retrieval, max_memory, do_stream):
+    def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature, top_k, retrieval, max_memory, do_stream, load_in_8bit):
         super().__init__()
         self._gpu_id = int(gpu_id)
         self._model_name_or_path = model_name_or_path
@@ -121,10 +127,11 @@ def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature,
         self._retrieval = retrieval
         self._max_memory = max_memory
         self._do_stream = do_stream
+        self._load_in_8bit = load_in_8bit
 
     def preloop(self):
         print(f"Loading {self._model_name_or_path} to cuda:{self._gpu_id}...")
-        self._model = ChatModel(self._model_name_or_path, self._gpu_id, self._max_memory)
+        self._model = ChatModel(self._model_name_or_path, self._gpu_id, self._max_memory, self._load_in_8bit)
 
         if self._retrieval:
             print(f"Loading retrieval index...")
@@ -253,6 +260,13 @@ def main():
         help='max CPU RAM to allocate',
         required=False
     )
+    # `pip install bitsandbytes` to use. No effect when used with -g or -r.
+    parser.add_argument(
+        '--load-in-8bit',
+        default=False,
+        action='store_true',
+        help='whether to load the model in 8-bit'
+    )
     args = parser.parse_args()
 
     # set max_memory dictionary if given
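
A quick sanity check of the flag's behavior: with store_true the option defaults to False, flips to True only when passed, and argparse exposes it as args.load_in_8bit. A minimal, self-contained sketch with a throwaway parser:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--load-in-8bit', default=False, action='store_true',
                    help='whether to load the model in 8-bit')

assert parser.parse_args([]).load_in_8bit is False                  # flag absent
assert parser.parse_args(['--load-in-8bit']).load_in_8bit is True   # flag present
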
@@ -278,6 +292,7 @@ def main():
         args.retrieval,
         max_memory,
         not args.no_stream,
+        args.load_in_8bit,
     ).cmdloop()
 
 