
Commit 5cfefc1

Add argument to load model in 8bit
1 parent 71dd823 · commit 5cfefc1

2 files changed: +34 −19 lines

environment.yml

Lines changed: 1 addition & 1 deletion

@@ -24,6 +24,6 @@ dependencies:
   - datasets==2.10.1
   - loguru==0.6.0
   - netifaces==0.11.0
-  - transformers==4.21.1
+  - transformers==4.27.4
   - wandb==0.13.10
   - zstandard==0.20.0
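The transformers pin moves from 4.21.1 to 4.27.4, presumably because the BitsAndBytesConfig import added in inference/bot.py below is not available in the older release. A minimal sanity check under that assumption (not part of the commit):

    import transformers
    from transformers import BitsAndBytesConfig  # assumed to raise ImportError on the old 4.21.1 pin

    print(transformers.__version__)  # expect 4.27.4 with the updated environment.yml
    print(BitsAndBytesConfig(load_in_8bit=True).load_in_8bit)  # True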

inference/bot.py

Lines changed: 33 additions & 18 deletions

@@ -11,7 +11,7 @@
 import argparse
 import conversation as convo
 import retrieval.wikipedia as wp
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, StoppingCriteria, StoppingCriteriaList
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, StoppingCriteria, StoppingCriteriaList, BitsAndBytesConfig
 from accelerate import infer_auto_device_map, init_empty_weights


@@ -48,16 +48,19 @@ class ChatModel:
     human_id = "<human>"
     bot_id = "<bot>"

-    def __init__(self, model_name, gpu_id, max_memory):
+    def __init__(self, model_name, gpu_id, max_memory, load_in_8bit):
         device = torch.device('cuda', gpu_id)  # TODO: allow sending to cpu

+        quantization_config = BitsAndBytesConfig(
+            load_in_8bit=load_in_8bit,
+            llm_int8_enable_fp32_cpu_offload=True,
+        )  # config to load in 8-bit if load_in_8bit
+
         # recommended default for devices with > 40 GB VRAM
         # load model onto one device
         if max_memory is None:
-            self._model = AutoModelForCausalLM.from_pretrained(
-                model_name, torch_dtype=torch.float16, device_map="auto")
-            self._model.to(device)
-        # load the model with the given max_memory config (for devices with insufficient VRAM or multi-gpu)
+            device_map = "auto"
+
         else:
             config = AutoConfig.from_pretrained(model_name)
             # load empty weights
@@ -66,21 +69,24 @@ def __init__(self, model_name, gpu_id, max_memory):

             model_from_conf.tie_weights()

-            # create a device_map from max_memory
+            #create a device_map from max_memory
             device_map = infer_auto_device_map(
                 model_from_conf,
                 max_memory=max_memory,
                 no_split_module_classes=["GPTNeoXLayer"],
-                dtype="float16"
-            )
-            # load the model with the above device_map
-            self._model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                device_map=device_map,
-                offload_folder="offload",  # optional offload-to-disk overflow directory (auto-created)
-                offload_state_dict=True,
-                torch_dtype=torch.float16
+                dtype="float16",
             )
+
+        self._model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,
+            device_map=device_map,
+            offload_folder="offload",
+            quantization_config=quantization_config,
+        )
+        if not load_in_8bit:
+            self._model.to(device)  # not supported by load_in_8bit
+
         self._tokenizer = AutoTokenizer.from_pretrained(model_name)

     def do_inference(self, prompt, max_new_tokens, do_sample, temperature, top_k, stream_callback=None):
@@ -110,7 +116,7 @@ class OpenChatKitShell(cmd.Cmd):
     intro = "Welcome to OpenChatKit shell. Type /help or /? to list commands.\n"
     prompt = ">>> "

-    def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature, top_k, retrieval, max_memory, do_stream):
+    def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature, top_k, retrieval, max_memory, do_stream, load_in_8bit):
         super().__init__()
         self._gpu_id = int(gpu_id)
         self._model_name_or_path = model_name_or_path
@@ -121,10 +127,11 @@ def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature,
         self._retrieval = retrieval
         self._max_memory = max_memory
         self._do_stream = do_stream
+        self._load_in_8bit = load_in_8bit

     def preloop(self):
         print(f"Loading {self._model_name_or_path} to cuda:{self._gpu_id}...")
-        self._model = ChatModel(self._model_name_or_path, self._gpu_id, self._max_memory)
+        self._model = ChatModel(self._model_name_or_path, self._gpu_id, self._max_memory, self._load_in_8bit)

         if self._retrieval:
             print(f"Loading retrieval index...")
@@ -253,6 +260,13 @@ def main():
         help='max CPU RAM to allocate',
         required=False
     )
+    # `pip install bitsandbytes` to use. No effect when used with -g or -r.
+    parser.add_argument(
+        '--load-in-8bit',
+        default=False,
+        action='store_true',
+        help='indicates whether to load model in 8 bit'
+    )
     args = parser.parse_args()

     # set max_memory dictionary if given
@@ -278,6 +292,7 @@ def main():
         args.retrieval,
         max_memory,
         not args.no_stream,
+        args.load_in_8bit,
     ).cmdloop()
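Taken together, the change routes both the single-device path and the max_memory path through one from_pretrained call that now accepts the quantization config, and skips the .to(device) move when 8-bit loading is requested, since bitsandbytes-quantized models are already placed by accelerate and cannot be moved afterwards. A standalone sketch of the same loading pattern, outside the ChatModel class; the model id is an illustrative assumption, not something this diff pins down:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    MODEL = "togethercomputer/GPT-NeoXT-Chat-Base-20B"  # assumed example id; any causal LM works

    # Same config the commit builds: 8-bit weights, with fp32 CPU offload
    # allowed for modules that do not fit on the GPU.
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_enable_fp32_cpu_offload=True,
    )

    model = AutoModelForCausalLM.from_pretrained(
        MODEL,
        torch_dtype=torch.float16,
        device_map="auto",            # mirrors the max_memory-is-None branch above
        offload_folder="offload",
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    # Note: no model.to(device) here; with load_in_8bit the weights are already
    # dispatched by accelerate and moving them again raises an error.

In the shell itself this corresponds to something like `python inference/bot.py --load-in-8bit` (after `pip install bitsandbytes`); per the comment in the argument block, the flag has no effect when combined with -g or -r.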
