Skip to content

Commit

Permalink
Support custom embedding layer resizing to the desired multiple (foun…
Browse files Browse the repository at this point in the history
…dation-model-stack#227)

* feat: embedding layer as a multiple

Signed-off-by: Mehant Kammakomati <[email protected]>

* docs: add new argument docs

Signed-off-by: Mehant Kammakomati <[email protected]>

* fix: allow resizing embedding layers for peft

Signed-off-by: Mehant Kammakomati <[email protected]>

* fix: naming and error type

Signed-off-by: Mehant Kammakomati <[email protected]>

* feat: add test case for embedding resize

Signed-off-by: Mehant Kammakomati <[email protected]>

---------

Signed-off-by: Mehant Kammakomati <[email protected]>
  • Loading branch information
kmehant authored Jul 5, 2024
1 parent 06e8cbc commit 3be607b
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 13 deletions.
26 changes: 23 additions & 3 deletions scripts/run_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@
import os

# Third Party
from peft import AutoPeftModelForCausalLM
from peft import PeftModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Local
from tuning.data import tokenizer_data_utils


### Utilities
class AdapterConfigPatcher:
Expand Down Expand Up @@ -178,14 +181,31 @@ def load(
try:
with AdapterConfigPatcher(checkpoint_path, overrides):
try:
model = AutoPeftModelForCausalLM.from_pretrained(
if base_model_name_or_path is None:
raise ValueError("base_model_name_or_path has to be passed")
base_model = AutoModelForCausalLM.from_pretrained(
base_model_name_or_path,
attn_implementation="flash_attention_2"
if use_flash_attn
else None,
torch_dtype=torch.bfloat16 if use_flash_attn else None,
)
# The peft library (PeftModelForCausalLM) does not handle cases where the
# base model's layers have been modified. In our case the embedding layer
# is modified, so we resize the backbone model's embedding layer with our
# own utility before passing it along to load the PEFT model.
tokenizer_data_utils.tokenizer_and_embedding_resize(
{}, tokenizer=tokenizer, model=base_model
)
model = PeftModel.from_pretrained(
base_model,
checkpoint_path,
attn_implementation="flash_attention_2"
if use_flash_attn
else None,
torch_dtype=torch.bfloat16 if use_flash_attn else None,
)
except OSError as e:
except (OSError, ValueError) as e:
print("Failed to initialize checkpoint model!")
raise e
except FileNotFoundError:
Expand Down
8 changes: 4 additions & 4 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def test_run_causallm_pt_and_inference():
_validate_adapter_config(adapter_config, "PROMPT_TUNING", PEFT_PT_ARGS)

# Load the model
loaded_model = TunedCausalLM.load(checkpoint_path)
loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)

# Run inference on the text
output_inference = loaded_model.run(
Expand Down Expand Up @@ -211,7 +211,7 @@ def test_run_causallm_pt_and_inference_with_formatting_data():
_validate_adapter_config(adapter_config, "PROMPT_TUNING", PEFT_PT_ARGS)

# Load the model
loaded_model = TunedCausalLM.load(checkpoint_path)
loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)

# Run inference on the text
output_inference = loaded_model.run(
Expand Down Expand Up @@ -242,7 +242,7 @@ def test_run_causallm_pt_and_inference_JSON_file_formatter():
_validate_adapter_config(adapter_config, "PROMPT_TUNING", PEFT_PT_ARGS)

# Load the model
loaded_model = TunedCausalLM.load(checkpoint_path)
loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)

# Run inference on the text
output_inference = loaded_model.run(
Expand Down Expand Up @@ -370,7 +370,7 @@ def test_run_causallm_lora_and_inference(request, target_modules, expected):
assert module in adapter_config.get("target_modules")

# Load the model
loaded_model = TunedCausalLM.load(checkpoint_path)
loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)

# Run inference on the text
output_inference = loaded_model.run(
Expand Down
76 changes: 76 additions & 0 deletions tests/utils/test_embedding_resize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Third Party
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Local
from tuning.data import tokenizer_data_utils

# Model identifier passed to `from_pretrained` for both the tokenizer and the
# two model copies loaded by the test in this module.
MODEL_NAME = "Maykeye/TinyLLama-v0"


def _inference(
    tokenizer: AutoTokenizer,
    model: AutoModelForCausalLM,
    input_text: str,
    max_new_tokens: int,
) -> str:
    """Run generation on ``input_text`` and return the decoded first sequence.

    Args:
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        model: Causal LM whose ``generate`` method produces the output ids.
        input_text: Prompt text to tokenize and feed to the model.
        max_new_tokens: Forwarded to ``model.generate`` as ``max_new_tokens``.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    # Put the inputs on the device the model already lives on. Selecting
    # "cuda" whenever it is available fails on GPU hosts, because the model
    # itself is never moved and the inputs end up on a different device.
    tokenized_input = tokenizer(input_text, return_tensors="pt").to(model.device)
    generated_output = model.generate(
        **tokenized_input,
        max_new_tokens=max_new_tokens,
    )
    return tokenizer.decode(generated_output[0], skip_special_tokens=True)


def test_output_unaltered_across_embedding_resizes():
    """Check that padding the embedding layer does not change generated text.

    Two copies of the same checkpoint are loaded. One is resized with
    ``multiple_of=8`` and the other with ``multiple_of=1``; no special tokens
    are added to the tokenizer in either call. Both models must then produce
    the same decoded output for the same prompt.
    """
    input_text = "### Text: @NortonSupport Thanks much.\n\n### Label:"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model_not_resized = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    model_resized = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

    tokenizer_data_utils.tokenizer_and_embedding_resize(
        special_tokens_dict={}, tokenizer=tokenizer, model=model_resized, multiple_of=8
    )

    tokenizer_data_utils.tokenizer_and_embedding_resize(
        special_tokens_dict={},
        tokenizer=tokenizer,
        model=model_not_resized,
        multiple_of=1,
    )

    # embedding size of the resized model should be a multiple of 8
    assert model_resized.get_output_embeddings().out_features % 8 == 0

    output_from_model_not_resized = _inference(
        model=model_not_resized,
        tokenizer=tokenizer,
        input_text=input_text,
        max_new_tokens=50,
    )
    # Generate with the resized model here. Passing the unresized model again
    # would compare a model with itself and make the final assertion vacuous.
    output_from_model_resized = _inference(
        model=model_resized,
        tokenizer=tokenizer,
        input_text=input_text,
        max_new_tokens=50,
    )

    assert output_from_model_not_resized == output_from_model_resized
7 changes: 7 additions & 0 deletions tuning/config/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ class ModelArguments:
metadata={"help": "Use Flash attention v2 from transformers, default is True"},
)
torch_dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16
embedding_size_multiple_of: Optional[int] = field(
default=8,
metadata={
"help": "Resize model embedding layer to the nearest multiple of \
the given number after tokenizer modifications."
},
)


@dataclass
Expand Down
12 changes: 6 additions & 6 deletions tuning/data/tokenizer_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

# Standard
from typing import Dict
import math

# Third Party
import transformers
Expand All @@ -23,14 +24,13 @@ def tokenizer_and_embedding_resize(
special_tokens_dict: Dict,
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
multiple_of: int = 8,
):
"""Resize tokenizer and embedding.
TODO: In the future, make sure we can have vocab size divisible by 64.
"""
"""Resize tokenizer and embedding."""
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))

embedding_size = int(multiple_of * math.ceil(len(tokenizer) / multiple_of))
num_new_tokens = num_new_tokens + embedding_size - len(tokenizer)
model.resize_token_embeddings(embedding_size)
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
Expand Down
1 change: 1 addition & 0 deletions tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def train(
special_tokens_dict=special_tokens_dict,
tokenizer=tokenizer,
model=model,
multiple_of=model_args.embedding_size_multiple_of,
)

# Configure the collator and validate args related to packing prior to formatting the dataset
Expand Down

0 comments on commit 3be607b

Please sign in to comment.