@@ -48,15 +48,16 @@ class ChatModel:
     human_id = "<human>"
     bot_id = "<bot>"
 
-    def __init__(self, model_name, gpu_id, max_memory):
+    def __init__(self, model_name, gpu_id, max_memory, load_in_8bit):
         device = torch.device('cuda', gpu_id)  # TODO: allow sending to cpu
 
         # recommended default for devices with > 40 GB VRAM
         # load model onto one device
         if max_memory is None:
             self._model = AutoModelForCausalLM.from_pretrained(
-                model_name, torch_dtype=torch.float16, device_map="auto")
-            self._model.to(device)
+                model_name, torch_dtype=torch.float16, device_map="auto", load_in_8bit=load_in_8bit)
+            if not load_in_8bit:
+                self._model.to(device)  # not supported by load_in_8bit
         # load the model with the given max_memory config (for devices with insufficient VRAM or multi-gpu)
         else:
             config = AutoConfig.from_pretrained(model_name)
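For context, a minimal standalone sketch of the 8-bit path this patch enables (the checkpoint name is illustrative; assumes `pip install bitsandbytes accelerate` and a transformers version that, like the patch, still accepts `load_in_8bit` directly as a `from_pretrained` kwarg):

import torch
from transformers import AutoModelForCausalLM

# Illustrative checkpoint; any causal LM the bot supports would do.
model_name = "togethercomputer/GPT-NeoXT-Chat-Base-20B"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",   # accelerate dispatches weights across devices at load time
    load_in_8bit=True,   # bitsandbytes quantizes linear layers to int8
)

# device_map has already placed the weights; calling .to(device) on an
# 8-bit model raises an error, hence the `if not load_in_8bit:` guard above.
print(f"footprint: {model.get_memory_footprint() / 1e9:.1f} GB")

Roughly halving the fp16 footprint is the point of the flag: int8 weights take one byte per parameter instead of two.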
@@ -110,7 +111,7 @@ class OpenChatKitShell(cmd.Cmd):
     intro = "Welcome to OpenChatKit shell. Type /help or /? to list commands.\n"
     prompt = ">>> "
 
-    def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature, top_k, retrieval, max_memory, do_stream):
+    def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature, top_k, retrieval, max_memory, do_stream, load_in_8bit):
         super().__init__()
         self._gpu_id = int(gpu_id)
         self._model_name_or_path = model_name_or_path
@@ -121,10 +122,11 @@ def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature,
         self._retrieval = retrieval
         self._max_memory = max_memory
         self._do_stream = do_stream
+        self._load_in_8bit = load_in_8bit
 
     def preloop(self):
         print(f"Loading {self._model_name_or_path} to cuda:{self._gpu_id}...")
-        self._model = ChatModel(self._model_name_or_path, self._gpu_id, self._max_memory)
+        self._model = ChatModel(self._model_name_or_path, self._gpu_id, self._max_memory, self._load_in_8bit)
 
         if self._retrieval:
             print(f"Loading retrieval index...")
@@ -253,6 +255,13 @@ def main():
         help='max CPU RAM to allocate',
         required=False
     )
+    # `pip install bitsandbytes` to use. No effect when used with -g or -r.
+    parser.add_argument(
+        '--load-in-8bit',
+        default=False,
+        action='store_true',
+        help='indicates whether to load model in 8 bit'
+    )
     args = parser.parse_args()
 
     # set max_memory dictionary if given
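The flag follows standard argparse semantics: `action='store_true'` makes it default to False and flip to True when present, and argparse converts the dashes so it surfaces as `args.load_in_8bit`. A self-contained check, mirroring the patch's definition:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--load-in-8bit',
    default=False,
    action='store_true',
    help='indicates whether to load model in 8 bit'
)

print(parser.parse_args([]).load_in_8bit)                  # False
print(parser.parse_args(['--load-in-8bit']).load_in_8bit)  # True

The bot would then be launched with the new switch appended to its usual invocation, e.g. `python inference/bot.py --load-in-8bit` (exact script path per the repo layout).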
@@ -278,6 +287,7 @@ def main():
         args.retrieval,
         max_memory,
         not args.no_stream,
+        args.load_in_8bit,
     ).cmdloop()
 
 
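One caveat carried over from the comment above the new argument: the flag is only honored on the single-device path. The patch threads `load_in_8bit` only into the `max_memory is None` branch of `ChatModel.__init__`, so runs that set `-g` or `-r` (which populate `max_memory`) silently ignore it. A defensive warning could make that explicit; the helper below is a hypothetical addition, not part of the patch:

import warnings

def check_8bit_flags(max_memory, load_in_8bit):
    """Hypothetical helper (not in the patch): warn when --load-in-8bit is
    silently ignored because a max_memory config (-g/-r) takes the other path."""
    if max_memory is not None and load_in_8bit:
        warnings.warn("--load-in-8bit has no effect when max_memory is set (-g/-r)")

check_8bit_flags(max_memory={0: "20GiB"}, load_in_8bit=True)  # warns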