From 3be607bed91ca50ebc7ac16a8bf8bae7d3201545 Mon Sep 17 00:00:00 2001
From: Mehant Kammakomati
Date: Fri, 5 Jul 2024 20:38:27 +0530
Subject: [PATCH] Support custom embedding layer resizing to the desired
 multiple (#227)

* feat: embedding layer as a multiple

Signed-off-by: Mehant Kammakomati

* docs: add new argument docs

Signed-off-by: Mehant Kammakomati

* fix: allow resizing embedding layers for peft

Signed-off-by: Mehant Kammakomati

* fix: naming and error type

Signed-off-by: Mehant Kammakomati

* feat: add test case for embedding resize

Signed-off-by: Mehant Kammakomati

---------

Signed-off-by: Mehant Kammakomati
---
 scripts/run_inference.py             | 26 ++++++++--
 tests/test_sft_trainer.py            |  8 +--
 tests/utils/test_embedding_resize.py | 76 ++++++++++++++++++++++++++++
 tuning/config/configs.py             |  7 +++
 tuning/data/tokenizer_data_utils.py  | 12 ++---
 tuning/sft_trainer.py                |  1 +
 6 files changed, 117 insertions(+), 13 deletions(-)
 create mode 100644 tests/utils/test_embedding_resize.py

diff --git a/scripts/run_inference.py b/scripts/run_inference.py
index 70820049e..d64bf926b 100644
--- a/scripts/run_inference.py
+++ b/scripts/run_inference.py
@@ -28,11 +28,14 @@ import os
 
 # Third Party
-from peft import AutoPeftModelForCausalLM
+from peft import PeftModel
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 
+# Local
+from tuning.data import tokenizer_data_utils
+
 
 ### Utilities
 class AdapterConfigPatcher:
@@ -178,14 +181,31 @@ def load(
         try:
             with AdapterConfigPatcher(checkpoint_path, overrides):
                 try:
-                    model = AutoPeftModelForCausalLM.from_pretrained(
+                    if base_model_name_or_path is None:
+                        raise ValueError("base_model_name_or_path has to be passed")
+                    base_model = AutoModelForCausalLM.from_pretrained(
+                        base_model_name_or_path,
+                        attn_implementation="flash_attention_2"
+                        if use_flash_attn
+                        else None,
+                        torch_dtype=torch.bfloat16 if use_flash_attn else None,
+                    )
+                    # The peft library (PeftModelForCausalLM) does not handle cases
+                    # where the model's layers are modified; here the embedding layer
+                    # is modified, so we resize the backbone model's embedding layer
+                    # with our own utility before loading the PEFT model.
+                    tokenizer_data_utils.tokenizer_and_embedding_resize(
+                        {}, tokenizer=tokenizer, model=base_model
+                    )
+                    model = PeftModel.from_pretrained(
+                        base_model,
                         checkpoint_path,
                         attn_implementation="flash_attention_2"
                         if use_flash_attn
                         else None,
                         torch_dtype=torch.bfloat16 if use_flash_attn else None,
                     )
-                except OSError as e:
+                except (OSError, ValueError) as e:
                     print("Failed to initialize checkpoint model!")
                     raise e
         except FileNotFoundError:
diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index d7508d8de..32f3e8140 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -178,7 +178,7 @@ def test_run_causallm_pt_and_inference():
     _validate_adapter_config(adapter_config, "PROMPT_TUNING", PEFT_PT_ARGS)
 
     # Load the model
-    loaded_model = TunedCausalLM.load(checkpoint_path)
+    loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
 
     # Run inference on the text
     output_inference = loaded_model.run(
@@ -211,7 +211,7 @@ def test_run_causallm_pt_and_inference_with_formatting_data():
     _validate_adapter_config(adapter_config, "PROMPT_TUNING", PEFT_PT_ARGS)
 
     # Load the model
-    loaded_model = TunedCausalLM.load(checkpoint_path)
+    loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
 
     # Run inference on the text
     output_inference = loaded_model.run(
@@ -242,7 +242,7 @@ def test_run_causallm_pt_and_inference_JSON_file_formatter():
     _validate_adapter_config(adapter_config, "PROMPT_TUNING", PEFT_PT_ARGS)
 
     # Load the model
-    loaded_model = TunedCausalLM.load(checkpoint_path)
+    loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
 
     # Run inference on the text
     output_inference = loaded_model.run(
@@ -370,7 +370,7 @@ def test_run_causallm_lora_and_inference(request, target_modules, expected):
         assert module in adapter_config.get("target_modules")
 
     # Load the model
-    loaded_model = TunedCausalLM.load(checkpoint_path)
+    loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
 
     # Run inference on the text
     output_inference = loaded_model.run(
diff --git a/tests/utils/test_embedding_resize.py b/tests/utils/test_embedding_resize.py
new file mode 100644
index 000000000..9a72f397b
--- /dev/null
+++ b/tests/utils/test_embedding_resize.py
@@ -0,0 +1,76 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# SPDX-License-Identifier: Apache-2.0
+# https://spdx.dev/learn/handling-license-info/
+
+# Third Party
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+
+# Local
+from tuning.data import tokenizer_data_utils
+
+MODEL_NAME = "Maykeye/TinyLLama-v0"
+
+
+def _inference(
+    tokenizer: AutoTokenizer,
+    model: AutoModelForCausalLM,
+    input_text: str,
+    max_new_tokens: int,
+) -> str:
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    tokenized_input = tokenizer(input_text, return_tensors="pt").to(device)
+    generated_output = model.generate(
+        **tokenized_input,
+        max_new_tokens=max_new_tokens,
+    )
+    return tokenizer.decode(generated_output[0], skip_special_tokens=True)
+
+
+def test_output_unaltered_across_embedding_resizes():
+    input_text = "### Text: @NortonSupport Thanks much.\n\n### Label:"
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    model_not_resized = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+    model_resized = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
+
+    tokenizer_data_utils.tokenizer_and_embedding_resize(
+        special_tokens_dict={}, tokenizer=tokenizer, model=model_resized, multiple_of=8
+    )
+
+    tokenizer_data_utils.tokenizer_and_embedding_resize(
+        special_tokens_dict={},
+        tokenizer=tokenizer,
+        model=model_not_resized,
+        multiple_of=1,
+    )
+
+    # embedding size of the resized model should be a multiple of 8
+    assert model_resized.get_output_embeddings().out_features % 8 == 0
+
+    output_from_model_not_resized = _inference(
+        model=model_not_resized,
+        tokenizer=tokenizer,
+        input_text=input_text,
+        max_new_tokens=50,
+    )
+    output_from_model_resized = _inference(
+        model=model_resized,
+        tokenizer=tokenizer,
+        input_text=input_text,
+        max_new_tokens=50,
+    )
+
+    assert output_from_model_not_resized == output_from_model_resized
diff --git a/tuning/config/configs.py b/tuning/config/configs.py
index bccf5d15b..a6a015f61 100644
--- a/tuning/config/configs.py
+++ b/tuning/config/configs.py
@@ -38,6 +38,13 @@ class ModelArguments:
         metadata={"help": "Use Flash attention v2 from transformers, default is True"},
     )
     torch_dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16
+    embedding_size_multiple_of: Optional[int] = field(
+        default=8,
+        metadata={
+            "help": "Resize model embedding layer to the next multiple of \
+                the given number after tokenizer modifications."
+        },
+    )
 
 
 @dataclass
diff --git a/tuning/data/tokenizer_data_utils.py b/tuning/data/tokenizer_data_utils.py
index 7c314a187..62a615ace 100644
--- a/tuning/data/tokenizer_data_utils.py
+++ b/tuning/data/tokenizer_data_utils.py
@@ -14,6 +14,7 @@
 
 # Standard
 from typing import Dict
+import math
 
 # Third Party
 import transformers
@@ -23,14 +24,13 @@ def tokenizer_and_embedding_resize(
     special_tokens_dict: Dict,
     tokenizer: transformers.PreTrainedTokenizer,
     model: transformers.PreTrainedModel,
+    multiple_of: int = 8,
 ):
-    """Resize tokenizer and embedding.
-
-    TODO: In the future, make sure we can have vocab size divisible by 64.
-    """
+    """Resize tokenizer and embedding."""
     num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
-    model.resize_token_embeddings(len(tokenizer))
-
+    embedding_size = int(multiple_of * math.ceil(len(tokenizer) / multiple_of))
+    num_new_tokens = num_new_tokens + embedding_size - len(tokenizer)
+    model.resize_token_embeddings(embedding_size)
     if num_new_tokens > 0:
         input_embeddings = model.get_input_embeddings().weight.data
         output_embeddings = model.get_output_embeddings().weight.data
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index b221db898..bd7fcc4fe 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -228,6 +228,7 @@ def train(
         special_tokens_dict=special_tokens_dict,
         tokenizer=tokenizer,
         model=model,
+        multiple_of=model_args.embedding_size_multiple_of,
     )
 
     # Configure the collator and validate args related to packing prior to formatting the dataset
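
Reviewer note (not part of the patch): the sketch below illustrates how the new
multiple_of argument behaves after this change. It is a hypothetical usage
example, not code from the repository; the model name mirrors the one used in
the new test, and the pad-token addition is only for illustration.

# Standard
import math

# Third Party
from transformers import AutoModelForCausalLM, AutoTokenizer

# Local
from tuning.data import tokenizer_data_utils

MODEL_NAME = "Maykeye/TinyLLama-v0"  # small model, same as in the new test

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# Add a pad token (illustrative) and resize the embeddings to a multiple of 8.
tokenizer_data_utils.tokenizer_and_embedding_resize(
    special_tokens_dict={"pad_token": "<pad>"},
    tokenizer=tokenizer,
    model=model,
    multiple_of=8,
)

# The embedding matrix is rounded up to multiple_of * ceil(len(tokenizer) / multiple_of),
# so it is always at least len(tokenizer) and divisible by multiple_of; the padding rows
# are counted in num_new_tokens, so they are initialized by the function's existing
# new-token logic.
expected_size = 8 * math.ceil(len(tokenizer) / 8)
assert model.get_output_embeddings().out_features == expected_size
assert expected_size % 8 == 0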