From 537215f6c5be6a124c422a07eae790c13be46c7e Mon Sep 17 00:00:00 2001
From: Sukriti Sharma
Date: Mon, 29 Jul 2024 14:20:02 -0600
Subject: [PATCH] fix: remove lm_head for granite with llama arch models (#258)

* initial code for deleting lm_head

Signed-off-by: Anh-Uong

* fix logic for copying checkpoint

Signed-off-by: Anh-Uong

* fix check that embed_tokens and lm_head weights are the same

Signed-off-by: Anh-Uong

* fix warning assertion

Signed-off-by: Anh-Uong

* fix lm_head check, remove test

Signed-off-by: Anh-Uong

* small fixes from code review

Signed-off-by: Anh-Uong

* fmt

Signed-off-by: Anh-Uong

---------

Signed-off-by: Anh-Uong
Co-authored-by: Anh-Uong
---
 build/accelerate_launch.py | 103 +++++++++++++++++++++++++++++++++----
 1 file changed, 94 insertions(+), 9 deletions(-)

diff --git a/build/accelerate_launch.py b/build/accelerate_launch.py
index 9af5ad809..ee8718b5d 100644
--- a/build/accelerate_launch.py
+++ b/build/accelerate_launch.py
@@ -26,9 +26,13 @@
 import tempfile
 import shutil
 from pathlib import Path
+import json
 
 # Third Party
 from accelerate.commands.launch import launch_command
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from peft import PeftModel
+from torch import bfloat16
 
 # Local
 from build.utils import (
@@ -44,10 +48,18 @@
     USER_ERROR_EXIT_CODE,
     INTERNAL_ERROR_EXIT_CODE,
 )
+from tuning.data import tokenizer_data_utils
 
 ERROR_LOG = "/dev/termination-log"
 
 
+def get_base_model_from_adapter_config(adapter_config):
+    """Given path to adapter_config.json file, returns the base model name"""
+    with open(adapter_config, "r", encoding="utf-8") as config_file:
+        adapter_config = json.load(config_file)
+    return adapter_config.get("base_model_name_or_path")
+
+
 def main():
     LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
     logging.basicConfig(level=LOGLEVEL)
@@ -118,16 +130,89 @@ def main():
         sys.exit(INTERNAL_ERROR_EXIT_CODE)
 
     try:
-        # copy last checkpoint into mounted output dir
-        pt_checkpoint_dir = get_highest_checkpoint(tempdir)
-        logging.info(
-            "Copying last checkpoint %s into output dir %s",
-            pt_checkpoint_dir,
-            original_output_dir,
-        )
-        copy_checkpoint(
-            os.path.join(tempdir, pt_checkpoint_dir), original_output_dir
+        last_checkpoint_dir = get_highest_checkpoint(tempdir)
+        last_checkpoint_path = os.path.join(tempdir, last_checkpoint_dir)
+
+        use_flash_attn = job_config.get("use_flash_attn", True)
+        adapter_config_path = os.path.join(
+            last_checkpoint_path, "adapter_config.json"
         )
+        tokenizer = AutoTokenizer.from_pretrained(last_checkpoint_path)
+
+        if os.path.exists(adapter_config_path):
+            base_model_path = get_base_model_from_adapter_config(
+                adapter_config_path
+            )
+            base_model = AutoModelForCausalLM.from_pretrained(
+                base_model_path,
+                attn_implementation="flash_attention_2" if use_flash_attn else None,
+                torch_dtype=bfloat16 if use_flash_attn else None,
+            )
+
+            # since the peft library (PEFTModelForCausalLM) does not handle cases
+            # where the model's layers are modified, in our case the embedding layer
+            # is modified, so we resize the backbone model's embedding layer with our own
+            # utility before passing it along to load the PEFT model.
+            tokenizer_data_utils.tokenizer_and_embedding_resize(
+                {}, tokenizer=tokenizer, model=base_model
+            )
+            model = PeftModel.from_pretrained(
+                base_model,
+                last_checkpoint_path,
+                attn_implementation="flash_attention_2" if use_flash_attn else None,
+                torch_dtype=bfloat16 if use_flash_attn else None,
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                last_checkpoint_path,
+                attn_implementation="flash_attention_2" if use_flash_attn else None,
+                torch_dtype=bfloat16 if use_flash_attn else None,
+            )
+
+        model_arch = model.config.model_type
+        # check that it is a granite model with llama architecture with tied weights
+        # ie. lm_head is duplicate of embeddings
+
+        # a fine tuned model will have params_dict.get("model.embed_tokens.weight")
+        # a prompt adapter has params_dict.get("base_model.model.embed_tokens.weight")
+        # a lora adapter has params_dict.get("base_model.model.model.embed_tokens.weight")
+        copy_checkpoint_bool = True
+        if model_arch == "llama" and hasattr(model, "lm_head"):
+            if (
+                # lora tuned model has an addt model layer
+                (
+                    hasattr(model.model, "model")
+                    and model.lm_head.weight.untyped_storage().data_ptr()
+                    == model.model.model.embed_tokens.weight.untyped_storage().data_ptr()
+                )
+                # prompt tuned model or fine tuned model
+                or (
+                    hasattr(model.model, "embed_tokens")
+                    and model.lm_head.weight.untyped_storage().data_ptr()
+                    == model.model.embed_tokens.weight.untyped_storage().data_ptr()
+                )
+            ):
+
+                copy_checkpoint_bool = False
+                logging.info("Removing lm_head from checkpoint")
+                del model.lm_head.weight
+
+                if hasattr(model, "lm_head.weight"):
+                    logging.warning("Failed to delete lm_head.weight from model")
+
+                logging.info("Saving checkpoint to %s", original_output_dir)
+                model.save_pretrained(original_output_dir)
+                # save tokenizer with model
+                tokenizer.save_pretrained(original_output_dir)
+
+        # copy last checkpoint into mounted output dir
+        if copy_checkpoint_bool:
+            logging.info(
+                "Copying last checkpoint %s into output dir %s",
+                last_checkpoint_dir,
+                original_output_dir,
+            )
+            copy_checkpoint(last_checkpoint_path, original_output_dir)
     except Exception as e:  # pylint: disable=broad-except
         logging.error(traceback.format_exc())
         write_termination_log(
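Note on the tied-weights check in the hunk above: for Granite models that use the llama architecture, lm_head is tied to the input embeddings, so both parameters share one underlying storage, and comparing untyped_storage().data_ptr() values detects the duplication. Below is a minimal standalone sketch of that detection, not part of the patch itself; the helper name lm_head_is_tied and the "my-granite-checkpoint" path are hypothetical, and it assumes transformers plus a torch version that provides Tensor.untyped_storage().

# Illustrative sketch only (not part of the patch above): check whether
# lm_head shares storage with the input embeddings via data_ptr() comparison.
from transformers import AutoModelForCausalLM


def lm_head_is_tied(model) -> bool:
    """Return True when lm_head and the input embeddings share one storage."""
    lm_head = model.get_output_embeddings()    # typically model.lm_head
    embeddings = model.get_input_embeddings()  # typically model.model.embed_tokens
    if lm_head is None or embeddings is None:
        return False
    return (
        lm_head.weight.untyped_storage().data_ptr()
        == embeddings.weight.untyped_storage().data_ptr()
    )


if __name__ == "__main__":
    # Placeholder path; substitute a real checkpoint directory or model id.
    model = AutoModelForCausalLM.from_pretrained("my-granite-checkpoint")
    print("tied lm_head:", lm_head_is_tied(model))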