@@ -11,7 +11,7 @@
 import argparse
 import conversation as convo
 import retrieval.wikipedia as wp
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, StoppingCriteria, StoppingCriteriaList
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, StoppingCriteria, StoppingCriteriaList, BitsAndBytesConfig
 from accelerate import infer_auto_device_map, init_empty_weights

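The import change is the heart of this commit: instead of passing load_in_8bit straight to from_pretrained, the loader now builds a transformers BitsAndBytesConfig and hands it in as quantization_config, which is the style newer transformers releases expect. A minimal sketch of the new-style call, using a placeholder checkpoint name rather than anything this repo ships:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Quantization settings mirroring the ones this commit adds.
quant_config = BitsAndBytesConfig(
    load_in_8bit=True,                      # store Linear weights as int8 via bitsandbytes
    llm_int8_enable_fp32_cpu_offload=True,  # modules offloaded to CPU stay in fp32
)
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-160m",  # placeholder checkpoint, for illustration only
    device_map="auto",
    quantization_config=quant_config,
)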
@@ -51,14 +51,16 @@ class ChatModel:
     def __init__(self, model_name, gpu_id, max_memory, load_in_8bit):
         device = torch.device('cuda', gpu_id)  # TODO: allow sending to cpu

+        quantization_config = BitsAndBytesConfig(
+            load_in_8bit=load_in_8bit,
+            llm_int8_enable_fp32_cpu_offload=True,
+        )  # config to load in 8-bit if load_in_8bit
+
         # recommended default for devices with > 40 GB VRAM
         # load model onto one device
         if max_memory is None:
-            self._model = AutoModelForCausalLM.from_pretrained(
-                model_name, torch_dtype=torch.float16, device_map="auto", load_in_8bit=load_in_8bit)
-            if not load_in_8bit:
-                self._model.to(device)  # not supported by load_in_8bit
-        # load the model with the given max_memory config (for devices with insufficient VRAM or multi-gpu)
+            device_map = "auto"
+
         else:
             config = AutoConfig.from_pretrained(model_name)
             # load empty weights
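Between the two branches, max_memory is the knob that matters: accelerate expects it as a mapping from device identifier to a memory budget, and infer_auto_device_map fills devices in order until each budget is exhausted. The budgets below are made-up illustrations, not values taken from this repo:

# hypothetical per-device budgets, in the format accelerate accepts
max_memory = {0: "10GiB", "cpu": "30GiB"}              # one small GPU, spill to CPU RAM
max_memory = {0: "20GiB", 1: "20GiB", "cpu": "60GiB"}  # two GPUs plus CPU overflow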
@@ -67,21 +69,24 @@ def __init__(self, model_name, gpu_id, max_memory, load_in_8bit):

             model_from_conf.tie_weights()

-            # create a device_map from max_memory
+            #create a device_map from max_memory
             device_map = infer_auto_device_map(
                 model_from_conf,
                 max_memory=max_memory,
                 no_split_module_classes=["GPTNeoXLayer"],
-                dtype="float16"
-            )
-            # load the model with the above device_map
-            self._model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                device_map=device_map,
-                offload_folder="offload",  # optional offload-to-disk overflow directory (auto-created)
-                offload_state_dict=True,
-                torch_dtype=torch.float16
+                dtype="float16",
             )
+
+        self._model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,
+            device_map=device_map,
+            offload_folder="offload",
+            quantization_config=quantization_config,
+        )
+        if not load_in_8bit:
+            self._model.to(device)  # not supported by load_in_8bit
+
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)

     def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k, stream_callback=None):
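With both branches now producing only a device_map, a single from_pretrained call covers every configuration, and 8-bit loading is controlled entirely by the quantization_config (which stays inert when load_in_8bit is False, at least on the transformers versions this commit targets). A hedged sketch of calling the refactored constructor; the checkpoint name and memory figures are placeholders:

# single-device default: max_memory=None routes through device_map="auto"
bot = ChatModel("EleutherAI/pythia-160m", gpu_id=0, max_memory=None, load_in_8bit=True)

# constrained or multi-GPU setup: an explicit budget drives infer_auto_device_map
bot = ChatModel(
    "EleutherAI/pythia-160m",                 # placeholder checkpoint
    gpu_id=0,
    max_memory={0: "12GiB", "cpu": "30GiB"},  # hypothetical budget
    load_in_8bit=False,
)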