Skip to content

Commit ad0476c

Browse files
committed
Add argument to load model in 8bit
1 parent 71dd823 commit ad0476c

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,6 @@ dependencies:
2424
- datasets==2.10.1
2525
- loguru==0.6.0
2626
- netifaces==0.11.0
27-
- transformers==4.21.1
27+
- transformers==4.27.4
2828
- wandb==0.13.10
2929
- zstandard==0.20.0

inference/bot.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,16 @@ class ChatModel:
4848
human_id = "<human>"
4949
bot_id = "<bot>"
5050

51-
def __init__(self, model_name, gpu_id, max_memory):
51+
def __init__(self, model_name, gpu_id, max_memory, load_in_8bit):
5252
device = torch.device('cuda', gpu_id) # TODO: allow sending to cpu
5353

5454
# recommended default for devices with > 40 GB VRAM
5555
# load model onto one device
5656
if max_memory is None:
5757
self._model = AutoModelForCausalLM.from_pretrained(
58-
model_name, torch_dtype=torch.float16, device_map="auto")
59-
self._model.to(device)
58+
model_name, torch_dtype=torch.float16, device_map="auto", load_in_8bit=load_in_8bit)
59+
if not load_in_8bit:
60+
self._model.to(device) # not supported by load_in_8bit
6061
# load the model with the given max_memory config (for devices with insufficient VRAM or multi-gpu)
6162
else:
6263
config = AutoConfig.from_pretrained(model_name)
@@ -110,7 +111,7 @@ class OpenChatKitShell(cmd.Cmd):
110111
intro = "Welcome to OpenChatKit shell. Type /help or /? to list commands.\n"
111112
prompt = ">>> "
112113

113-
def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature, top_k, retrieval, max_memory, do_stream):
114+
def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature, top_k, retrieval, max_memory, do_stream, load_in_8bit):
114115
super().__init__()
115116
self._gpu_id = int(gpu_id)
116117
self._model_name_or_path = model_name_or_path
@@ -121,10 +122,11 @@ def __init__(self, gpu_id, model_name_or_path, max_tokens, sample, temperature,
121122
self._retrieval = retrieval
122123
self._max_memory = max_memory
123124
self._do_stream = do_stream
125+
self._load_in_8bit = load_in_8bit
124126

125127
def preloop(self):
126128
print(f"Loading {self._model_name_or_path} to cuda:{self._gpu_id}...")
127-
self._model = ChatModel(self._model_name_or_path, self._gpu_id, self._max_memory)
129+
self._model = ChatModel(self._model_name_or_path, self._gpu_id, self._max_memory, self._load_in_8bit)
128130

129131
if self._retrieval:
130132
print(f"Loading retrieval index...")
@@ -253,6 +255,13 @@ def main():
253255
help='max CPU RAM to allocate',
254256
required=False
255257
)
258+
# `pip install bitsandbytes` to use. No effect when used with -g or -r.
259+
parser.add_argument(
260+
'--load-in-8bit',
261+
default=False,
262+
action='store_true',
263+
help='indicates whether to load model in 8 bit'
264+
)
256265
args = parser.parse_args()
257266

258267
# set max_memory dictionary if given
@@ -278,6 +287,7 @@ def main():
278287
args.retrieval,
279288
max_memory,
280289
not args.no_stream,
290+
args.load_in_8bit,
281291
).cmdloop()
282292

283293

0 commit comments

Comments (0)