import json
import logging
import os
from typing import List, Optional, Tuple, Union

import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
from transformers.utils import is_ipex_available


def find_all_linear_names(model, quantization: Optional[int] = None):
    if quantization is None:
        cls = torch.nn.Linear
    elif quantization == 4:
        from bitsandbytes.nn import Linear4bit

        cls = Linear4bit
    elif quantization == 8:
        from bitsandbytes.nn import Linear8bitLt

        cls = Linear8bitLt
    else:
        raise ValueError(f"Unknown quantization type: {quantization}")

    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if "lm_head" in lora_module_names:  # needed for 16-bit
        lora_module_names.remove("lm_head")
    return list(lora_module_names)
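
# Usage sketch (illustrative, not part of the original module): this helper is called
# by `load_model` when `lora_target_modules == ["all"]` so that LoRA is applied to
# every Linear / Linear4bit / Linear8bitLt layer of the model, e.g.
#
#   target_modules = find_all_linear_names(model, quantization=4)
#   # -> e.g. ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", ...]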


def get_trainable_parameters(model: PreTrainedModel) -> Tuple[int, int, float]:
    """
    Computes the number of trainable parameters in the model.

    Args:
        model (`PreTrainedModel`):
            The model to count the trainable parameters of.

    Returns:
        `Tuple[int, int, float]`:
            The number of trainable parameters, the total number of parameters and the
            percentage of trainable parameters.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        num_params = param.numel()
        # if using DS Zero 3 and the weights are initialized empty
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    return trainable_params, all_param, 100 * trainable_params / all_param
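
# Usage sketch (illustrative, not part of the original module):
#
#   trainable, total, pct = get_trainable_parameters(model)
#   logging.info(f"Trainable params: {trainable} / {total} ({pct:.4f}%)")
#
# With DeepSpeed ZeRO-3 the weights may be materialized lazily, which is why the
# function falls back to `param.ds_numel` when `param.numel()` returns 0.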


def get_device_map(
    force_auto_device_map: bool,
    max_memory_MB: Optional[int] = None,
    use_better_transformer: bool = False,
) -> Tuple[str, Union[int, List[int]]]:
    """
    Get the device map to use for loading the model.

    Args:
        force_auto_device_map (`bool`):
            Whether to force the use of the auto device map. If set to True, the model will be split across
            GPUs and CPU to fit the model in memory. If set to False, a full copy of the model will be loaded
            into each GPU.
        max_memory_MB (`int`):
            Free memory per GPU in MB. Used to compute the device map when force_auto_device_map is set to True.
        use_better_transformer (`bool`, optional):
            Whether to transform the model using the Better Transformer library:
            https://huggingface.co/docs/optimum/bettertransformer/overview. Requires optimum
            'https://huggingface.co/docs/optimum/installation'. Defaults to False.

    Returns:
        `Tuple[str, Union[int, List[int]]]`:
            The device map and max_memory configuration to use for loading the model.
    """
    if force_auto_device_map:
        if os.environ.get("LOCAL_RANK") is not None:
            # raise ValueError(
            #     "Found DDP environment and force_auto_device_map is set to True, this configuration "
            #     "is not supported. If you want to use DDP, set force_auto_device_map to False, so "
            #     "a copy of the model is loaded in each GPU. If you want to split the model across "
            #     "GPUs (force_auto_device_map=True), do not use DDP (launch your script with "
            #     "python -m src/run.py config.json). If you are not in a DDP environment but you see "
            #     "this error, you might have manually set the environment variable 'LOCAL_WORLD_SIZE' to a "
            #     "number different than 1, please, remove this environment variable or set it to 1"
            # )
            if torch.cuda.is_available():
                n_gpus = torch.cuda.device_count()
            elif is_ipex_available() and torch.xpu.is_available():
                n_gpus = torch.xpu.device_count()
            else:
                logging.warning(
                    "You are in a DDP environment but no GPU is available, this may cause errors later on"
                )
                n_gpus = 0

            max_memory = {i: max_memory_MB for i in range(n_gpus)}
            local_rank = int(os.environ.get("LOCAL_RANK", "0"))
            device_map = {"": local_rank}
            max_memory = (
                {"": max_memory[local_rank]} if max_memory_MB is not None else None
            )
        else:
            logging.warning(
                "Using auto device map, we will split the model across GPUs and CPU to fit the model in memory."
            )
            device_map = "auto"
            max_memory = max_memory_MB
    else:
        max_memory = None
        world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))
        if world_size > 1:
            logging.warning(
                "Found DDP environment and force_auto_device_map is set to False, we will load a copy of the model "
                "on each GPU."
            )
            device_map = None  # {"": int(os.environ.get("LOCAL_RANK", 0))}
        else:
            if not use_better_transformer:
                device_map = None
            else:
                logging.warning(
                    "Setting device map to 'auto' to use the Better Transformers library."
                )
                device_map = "auto"

    logging.info(
        f"We will load the model using the following device map: {device_map} and max_memory: {max_memory}"
    )

    return device_map, max_memory
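
# Usage sketch (values are illustrative, not part of the original module):
#
#   # Split a large model across the available GPUs/CPU, capping each GPU at ~40 GB:
#   device_map, max_memory = get_device_map(force_auto_device_map=True, max_memory_MB=40000)
#
#   # DDP-style training (one full copy of the model per GPU):
#   device_map, max_memory = get_device_map(force_auto_device_map=False)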


def merge_lora_model(
    weights_path: str,
    lora_weights_name_or_path: str,
    output_path: str,
    torch_dtype: Optional[str] = None,
):
    """
    Given a model path and the path to the LoRA weights, merge the LoRA weights into the model
    and save the merged model.

    Args:
        weights_path (`str`):
            The path to your local model weights and tokenizer. You can also provide a
            huggingface hub model name.
        lora_weights_name_or_path (`str`):
            If the model has been trained with LoRA, path or huggingface hub name to the
            pretrained weights. Defaults to `None`.
        output_path (`str`):
            The path to the output directory where the merged model will be saved.
        torch_dtype (`Optional[str]`, optional):
            The torch dtype to use for the model. If set to `"auto"`, the dtype will be
            automatically derived.
    """
    logging.info(
        f"We will merge the LoRA weights from {lora_weights_name_or_path} into the model {weights_path}"
    )
    # load_model returns (model, tokenizer, model_type); the model type is not needed here.
    model, tokenizer, _ = load_model(
        inference=True,
        model_weights_name_or_path=weights_path,
        use_lora=True,
        lora_weights_name_or_path=lora_weights_name_or_path,
        torch_dtype=torch_dtype,
    )

    model.config.save_pretrained(output_path)
    model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)

    logging.info(f"Model merged and saved in {output_path}")


def find_end_turn_token(tokenizer: PreTrainedTokenizerBase) -> int:
    """
    Find the end of turn token in the tokenizer chat template.
    If the model does not have a chat template, the end of turn token will be the eos token.

    Args:
        tokenizer (`PreTrainedTokenizerBase`):
            The tokenizer to find the end of turn token in.

    Returns:
        `int`:
            The id of the end of turn token.
    """
    if tokenizer.chat_template is None:
        return tokenizer.eos_token_id

    conversation = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": "a"},
            {"role": "assistant", "content": "a"},
        ],
        tokenize=True,
        add_special_tokens=False,
    )

    end_turn_token = tokenizer.eos_token_id
    for token in conversation[::-1]:
        if tokenizer.decode(token) == "a":
            break
        end_turn_token = token

    return end_turn_token
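
# How this works (descriptive note, not part of the original module): the chat template
# is applied to a dummy two-turn conversation whose assistant reply is the single
# character "a". Walking the tokenized conversation backwards, the last token seen
# before reaching that "a" is whatever the template appends to close the assistant
# turn (e.g. an "<|im_end|>"-style token or the eos token), so its id is returned as
# the end-of-turn token.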


def load_model(
    inference: bool,
    model_weights_name_or_path: str,
    add_labels_as_tokens: bool = False,
    labels: List[str] = None,
    quantization: Optional[int] = None,
    use_lora: bool = False,
    lora_weights_name_or_path: Optional[str] = None,
    lora_target_modules: Optional[List[str]] = None,
    lora_r: Optional[int] = 8,
    lora_alpha: Optional[int] = 16,
    lora_dropout: Optional[float] = 0.05,
    torch_dtype: Optional[str] = None,
    force_auto_device_map: bool = False,
    use_gradient_checkpointing: bool = False,
    trust_remote_code: bool = False,
    use_flash_attention: bool = False,
    use_better_transformer: bool = False,
    fsdp_training: bool = False,
    max_memory_MB: Optional[int] = None,
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase, str]:
    """
    Load any decoder-only or encoder-decoder model for training or inference.

    Args:
        inference (`bool`):
            Whether to load the model for inference or training. If set to `True`, the model will be loaded
            in evaluation mode. In this case, if use_lora is set to `True`, you must provide the path to the
            LoRA weights.
        model_weights_name_or_path (`str`):
            The path to your local model weights and tokenizer or huggingface model name.
        add_labels_as_tokens (`bool`, optional):
            Whether to add the labels as tokens to the tokenizer. Defaults to `False`.
        labels (`List[str]`, optional):
            The list of labels to add to the tokenizer. Defaults to `None`.
        quantization (`int`, optional):
            '4' or '8' for 4 bits or 8 bits quantization or None for 16/32 bits training. Defaults to `None`.
            Requires the bitsandbytes library: https://github.com/TimDettmers/bitsandbytes
        use_lora (`bool`, optional):
            Whether to use LoRA. Defaults to False.
            See https://arxiv.org/pdf/2106.09685.pdf for more details.
            Requires the huggingface PEFT library: https://github.com/huggingface/peft
        lora_weights_name_or_path (`Optional[str]`, optional):
            The name or path to the pre-trained LoRA model weights. You can also provide
            a huggingface hub model name to load the weights from there. If not provided,
            the weights will be initialized randomly; this requires training the model.
            Defaults to `None`.
        lora_target_modules (`Optional[List[str]]`, optional):
            The list of modules to apply LoRA to. If not provided, we will use the PEFT
            default modules. Defaults to `None`.
        lora_r (`Optional[int]`, optional):
            LoRA attention dimension. Defaults to `8`.
        lora_alpha (`Optional[int]`, optional):
            The alpha parameter for LoRA scaling. Defaults to `16`.
        lora_dropout (`Optional[float]`, optional):
            The dropout probability for LoRA layers. Defaults to 0.05.
        torch_dtype (`Optional[str]`, optional):
            Override the default `torch.dtype` and load the model under this dtype. If
            `auto` is passed, the dtype will be automatically derived from the model's
            weights. Defaults to `None`.
        force_auto_device_map (`bool`, optional):
            Whether to force the use of the auto device map. If set to True, the model will be split across
            GPUs and CPU to fit the model in memory. If set to False, a full copy of the model will be loaded
            into each GPU. Defaults to False.
        use_gradient_checkpointing (`bool`, optional):
            Whether to use gradient checkpointing for training.
        trust_remote_code (`bool`, optional):
            Trust the remote code from the HuggingFace model hub. Defaults to False.
        use_flash_attention (`bool`, optional):
            Whether to use Flash Attention. Defaults to False. Flash attention must be installed, see:
            'https://github.com/Dao-AILab/flash-attention' for more details.
        use_better_transformer (`bool`, optional):
            Whether to transform the model using the Better Transformer library:
            https://huggingface.co/docs/optimum/bettertransformer/overview. Requires optimum
            'https://huggingface.co/docs/optimum/installation'. Only supported for inference!
            Defaults to False.
        fsdp_training (`bool`, optional):
            Whether Fully Sharded Data Parallelism is enabled for training. Defaults to False.
            Used to prevent casting layers to fp32 if the model is already in fp16, which causes
            an error: ValueError: Must flatten tensors with uniform dtype but got torch.float16 and torch.float32
        max_memory_MB (`int`):
            Free memory per GPU in MB. Used to compute the device map when force_auto_device_map is set to True.

    Raises:
        `ValueError`:
            Raised when `quantization` is set to 4/8 bits but `use_lora=False` and `inference=False`.

    Returns:
        `Tuple[PreTrainedModel, PreTrainedTokenizerBase, str]`:
            The loaded model, tokenizer and model_type.
    """
    # Sanity checks
    if isinstance(quantization, str):
        quantization = int(quantization)

    assert (
        (quantization is None) or (quantization in [4, 8])
    ), f"Quantization must be 4 or 8, or None for FP32/FP16 training. You passed: {quantization}"

    if not inference and quantization is not None and not use_lora:
        raise ValueError(
            "'Quantization' == 4/8 is only supported with LoRA. If you want "
            "to train a 4/8 bits quantized model, you must set `use_lora=True`. If you want to "
            "use a 4/8 bits optimizer, set `quantization=None` and choose a 4/8 bit optimizer using the 'optim' "
            "argument (e.g. 'adamw_bnb_8bit', 'lion_8bit', 'paged_adamw_8bit', ...)."
        )

    if inference and use_lora and lora_weights_name_or_path is None:
        raise ValueError(
            "You must provide the path to the LoRA weights when loading the model for inference."
        )

    if use_better_transformer and not inference:
        logging.warning(
            "Better Transformer is only supported for inference. Better Transformers does not support "
            "attention mask for training, therefore it is not compatible with CoLLIE training. See "
            "https://huggingface.co/docs/optimum/bettertransformer/overview for more details. We will "
            "set use_better_transformer=False."
        )
        use_better_transformer = False

    if use_better_transformer and use_flash_attention:
        raise ValueError(
            "You cannot use both Flash Attention and Better Transformer flags. Flash Attention is already part of"
            " Better Transformers, so you can just set use_better_transformer=True to use Flash Attention. The Flash"
            " Attention flag is intended for patching HuggingFace models."
        )

    if lora_weights_name_or_path is not None and not use_lora:
        logging.warning(
            "You provided a path to LoRA weights but use_lora is set to False. We will set use_lora=True."
        )
        use_lora = True

    logging.info(f"Loading model from {model_weights_name_or_path}")

    # Get the device map config
    device_map, max_memory = get_device_map(
        force_auto_device_map=force_auto_device_map,
        max_memory_MB=max_memory_MB,
        use_better_transformer=use_better_transformer,
    )

    # Load the model config
    if use_lora:
        config = AutoConfig.from_pretrained(
            model_weights_name_or_path,
            trust_remote_code=trust_remote_code,
            pretraining_tp=1,  # Fix mat1 and mat2 shapes cannot be multiplied error with LLaMA-2
            # See https://github.com/huggingface/transformers/pull/24906
        )
    else:
        config = AutoConfig.from_pretrained(
            model_weights_name_or_path,
            trust_remote_code=trust_remote_code,
        )

    # Load the model tokenizer
    tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
        model_weights_name_or_path,
        add_eos_token=True,
        trust_remote_code=trust_remote_code,
        legacy=True,  # This library was developed with the legacy tokenizer.
        # It might or might not work with the latest updates to the T5 tokenizers. So we set legacy=True to be safe.
    )

    if tokenizer.pad_token_id is None:
        if "<|padding|>" in tokenizer.get_vocab():
            # StabilityLM specific fix
            tokenizer.add_special_tokens({"pad_token": "<|padding|>"})
        elif tokenizer.unk_token is not None:
            logging.warning(
                "Tokenizer does not have a pad token, we will use the unk token as pad token."
            )
            tokenizer.pad_token_id = tokenizer.unk_token_id
        else:
            logging.warning(
                "Tokenizer does not have a pad token. We will use the eos token as pad token."
            )
            tokenizer.pad_token_id = tokenizer.eos_token_id

    # Load the model weights
    # Get the quantization config
    torch_dtype = (
        torch_dtype if torch_dtype in ["auto", None] else getattr(torch, torch_dtype)
    )

    if quantization is not None:
        if quantization == 4:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
                if torch_dtype in ["auto", None]
                else torch_dtype,
            )
        else:
            bnb_config = BitsAndBytesConfig(
                load_in_8bit=True,
            )
        logging.info(
            f"Bits and Bytes config: {json.dumps(bnb_config.to_dict(), indent=4, ensure_ascii=False)}"
        )
    else:
        logging.info(f"Loading model with dtype: {torch_dtype}")
        bnb_config = None
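
    # Descriptive note (comment added for clarity, not in the original module): 4-bit
    # loading uses the QLoRA-style settings above -- NF4 quantization with double
    # quantization, and bfloat16 as the compute dtype unless an explicit torch_dtype
    # was requested. 8-bit loading only needs `load_in_8bit=True`.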

    # Get the correct load function for each model_type
    if config.model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        logging.warning(
            f"Model {model_weights_name_or_path} is an encoder-decoder model. We will load it as a Seq2SeqLM model."
        )
        load_fn = AutoModelForSeq2SeqLM
        model_type = "seq2seq"
    elif config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        logging.warning(
            f"Model {model_weights_name_or_path} is a decoder-only model. We will load it as a CausalLM model."
        )
        load_fn = AutoModelForCausalLM
        tokenizer.padding_side = "left"
        model_type = "causal"
    else:
        logging.warning(
            f"Model {model_weights_name_or_path} is not in the "
            f"MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES or MODEL_FOR_CAUSAL_LM_MAPPING_NAMES. "
            f"We will attempt to load it as a CausalLM model. This will fail if the model is not a CausalLM model."
        )
        load_fn = AutoModelForCausalLM
        tokenizer.padding_side = "left"
        model_type = "causal"

    # Load the model weights
    # Flash attention 2 was added to HuggingFace transformers very recently. Let's add it as kwargs to the load
    # function, so if it is set to False, we can load the model in older versions of transformers.
    if use_flash_attention:
        kwargs = {"use_flash_attention_2": True}
        logging.info("Loading the model with flash attention 2")
    else:
        kwargs = {}
        logging.info(
            "Loading the model without flash attention. If the model supports it, "
            "you can enable it by adding 'use_flash_attention: true' to "
            "your config file."
        )

    logging.info(
        "Loading model with config:\n"
        f"pretrained_model_name_or_path: {model_weights_name_or_path}\n"
        f"device_map: {device_map}\n"
        f"max_memory: {max_memory}\n"
        f"quantization_config: {bnb_config}\n"
        f"torch_dtype: {torch_dtype}\n"
        f"config: {config}\n"
        f"trust_remote_code: {trust_remote_code}\n"
        f"kwargs: {kwargs}\n"
    )

    model: PreTrainedModel = load_fn.from_pretrained(
        pretrained_model_name_or_path=model_weights_name_or_path,
        device_map=device_map,
        max_memory=max_memory,
        quantization_config=bnb_config,
        torch_dtype=torch_dtype,
        config=config,
        trust_remote_code=trust_remote_code,
        **kwargs,
    )

    if add_labels_as_tokens:
        print(f"Adding labels as tokens: {labels}")
        print(f"Model has {len(tokenizer)} tokens before adding labels.")
        tokenizer.add_tokens(labels)
        model.resize_token_embeddings(len(tokenizer))
        print(f"Model has {len(tokenizer)} tokens after adding labels.")

    logging.info(f"Model dtype: {model.dtype}")
    logging.info(
        "Total model memory footprint: "
        + str(model.get_memory_footprint() / 1e6)
        + " MB"
    )

    # Prepare the model for k-bit training and enable gradient checkpointing
    if quantization is not None and not inference:
        from peft import prepare_model_for_kbit_training

        model = prepare_model_for_kbit_training(
            model, use_gradient_checkpointing=use_gradient_checkpointing
        )
    else:
        if use_gradient_checkpointing and not inference:
            model.gradient_checkpointing_enable()

    # Load LoRA weights
    if use_lora:
        from peft import LoraConfig, PeftModel, TaskType, get_peft_model

        if not inference:
            model.enable_input_require_grads()  # Enables the gradients for the input embeddings

        if lora_weights_name_or_path is None:
            logging.info(
                "No pretrained LoRA weights provided, we will initialize the weights randomly."
            )

            if lora_target_modules is None or (
                lora_target_modules is not None and len(lora_target_modules) == 0
            ):
                logging.warning(
                    "No target modules provided, will use the default modules for the"
                    " model in the huggingface PEFT library. "
                )
                lora_target_modules = None

            if lora_target_modules == ["all"]:
                logging.warning(
                    "You provided 'all' as target modules, we will apply LoRA to all the linear layers of the model."
                )
                lora_target_modules = find_all_linear_names(
                    model, quantization=quantization
                )
                logging.warning(f"LoRA target modules: {lora_target_modules}")

            lora_config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                bias="none",
                task_type=TaskType.CAUSAL_LM
                if model_type == "causal"
                else TaskType.SEQ_2_SEQ_LM,
                target_modules=lora_target_modules,
                modules_to_save=["embed_tokens", "lm_head"],
            )

            model = get_peft_model(model, lora_config)
        else:
            logging.info(
                f"Loading pretrained LoRA weights from {lora_weights_name_or_path}"
            )
            model = PeftModel.from_pretrained(model, lora_weights_name_or_path)

        logging.info(f"\nLoRA config:\n{model.peft_config}\n")

    if inference:
        if quantization is None and use_lora:
            # If we are not using quantization, we merge the LoRA layers into the model for faster inference.
            # This is not possible if we are using 4/8 bit quantization.
            logging.info("Merging LoRA layers into the model for faster inference.")
            model = model.merge_and_unload()
        else:
            logging.info(
                "Quantization is enabled, we will not merge LoRA layers into the model. Inference will be slower."
            )
    else:
        trainable_params, total_params, trainable_percentage = get_trainable_parameters(
            model
        )
        logging.info(
            f"---> Trainable params: {trainable_params} || all params: {total_params} ||"
            f" trainable%: {round(trainable_percentage, 6)}\n"
        )

    return model, tokenizer, model_type
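

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration; the model name and hyperparameters
    # below are assumptions, not part of the original module). Loads a causal LM with
    # 4-bit quantization and a randomly initialized LoRA adapter, ready for training.
    model, tokenizer, model_type = load_model(
        inference=False,
        model_weights_name_or_path="meta-llama/Llama-2-7b-hf",  # assumed example checkpoint
        quantization=4,
        use_lora=True,
        lora_target_modules=["all"],
        torch_dtype="bfloat16",
        use_gradient_checkpointing=True,
    )
    print(f"Loaded a {model_type} model with {len(tokenizer)} tokens in the tokenizer.")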