Commit

Merge remote-tracking branch 'upstream/main'

dchourasia committed Jun 13, 2024
2 parents 1368642 + 087e0c4 commit c5c0a99
Showing 2 changed files with 76 additions and 64 deletions.
75 changes: 19 additions & 56 deletions build/accelerate_launch.py
@@ -37,7 +37,6 @@
get_highest_checkpoint,
)
from tuning.utils.config_utils import get_json_config
from tuning.utils.merge_model_utils import create_merged_model
from tuning.config.tracker_configs import FileLoggingTrackerConfig
from tuning.utils.error_logging import (
write_termination_log,
@@ -117,61 +116,25 @@ def main():
write_termination_log(f"Unhandled exception during training. {e}")
sys.exit(INTERNAL_ERROR_EXIT_CODE)

merge_model = False
if job_config.get("peft_method") == "lora":
merge_model = True

if merge_model:
try:
export_path = os.getenv(
"LORA_MERGE_MODELS_EXPORT_PATH", original_output_dir
)

# get the highest checkpoint dir (last checkpoint)
lora_checkpoint_dir = get_highest_checkpoint(tempdir)
full_checkpoint_dir = os.path.join(tempdir, lora_checkpoint_dir)

logging.info(
"Merging lora tuned checkpoint %s with base model into output path: %s",
lora_checkpoint_dir,
export_path,
)

# ensure checkpoint dir has correct files, important with multi-gpu tuning
if os.path.exists(
os.path.join(full_checkpoint_dir, "adapter_config.json")
):
create_merged_model(
checkpoint_models=full_checkpoint_dir,
export_path=export_path,
save_tokenizer=True,
)
except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
write_termination_log(
f"Exception encountered merging base model with checkpoint. {e}"
)
sys.exit(INTERNAL_ERROR_EXIT_CODE)
else:
try:
# copy last checkpoint into mounted output dir
pt_checkpoint_dir = get_highest_checkpoint(tempdir)
logging.info(
"Copying last checkpoint %s into output dir %s",
pt_checkpoint_dir,
original_output_dir,
)
shutil.copytree(
os.path.join(tempdir, pt_checkpoint_dir),
original_output_dir,
dirs_exist_ok=True,
)
except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
write_termination_log(
f"Exception encountered writing output model to storage: {e}"
)
sys.exit(INTERNAL_ERROR_EXIT_CODE)
try:
# copy last checkpoint into mounted output dir
pt_checkpoint_dir = get_highest_checkpoint(tempdir)
logging.info(
"Copying last checkpoint %s into output dir %s",
pt_checkpoint_dir,
original_output_dir,
)
shutil.copytree(
os.path.join(tempdir, pt_checkpoint_dir),
original_output_dir,
dirs_exist_ok=True,
)
except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
write_termination_log(
f"Exception encountered writing output model to storage: {e}"
)
sys.exit(INTERNAL_ERROR_EXIT_CODE)

# copy over any loss logs
try:
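
For context, the block removed above was the LoRA post-processing path: when peft_method was "lora", the launcher merged the tuned adapter checkpoint back into the base model before exporting it (via tuning.utils.merge_model_utils.create_merged_model, whose import is also dropped in the first hunk). After this change, every tuning mode simply copies the highest checkpoint into the mounted output directory. As a rough illustration only, the sketch below shows what such a merge step amounts to with the peft and transformers APIs; the function name and arguments are illustrative, not the repository's actual helper.

# Minimal sketch of merging a LoRA adapter checkpoint into its base model.
# Assumes the checkpoint directory was written by peft (i.e. it contains adapter_config.json).
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

def merge_lora_checkpoint(base_model_path: str, checkpoint_dir: str, export_path: str) -> None:
    base = AutoModelForCausalLM.from_pretrained(base_model_path)
    # Attach the adapter, fold its low-rank deltas into the base weights, then save.
    merged = PeftModel.from_pretrained(base, checkpoint_dir).merge_and_unload()
    merged.save_pretrained(export_path)
    # The removed code also saved the tokenizer alongside the merged model (save_tokenizer=True).
    AutoTokenizer.from_pretrained(base_model_path).save_pretrained(export_path)
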
65 changes: 57 additions & 8 deletions tests/build/test_launch_script.py
@@ -18,6 +18,7 @@
# Standard
import os
import tempfile
import glob

# Third Party
import pytest
@@ -33,7 +34,7 @@

SCRIPT = "tuning/sft_trainer.py"
MODEL_NAME = "Maykeye/TinyLLama-v0"
BASE_PEFT_KWARGS = {
BASE_KWARGS = {
"model_name_or_path": MODEL_NAME,
"training_data_path": TWITTER_COMPLAINTS_DATA,
"num_train_epochs": 5,
@@ -52,13 +53,27 @@
"use_flash_attn": False,
"torch_dtype": "float32",
"max_seq_length": 4096,
"peft_method": "pt",
"prompt_tuning_init": "RANDOM",
"num_virtual_tokens": 8,
"prompt_tuning_init_text": "hello",
"tokenizer_name_or_path": MODEL_NAME,
"save_strategy": "epoch",
"output_dir": "tmp",
}
BASE_PEFT_KWARGS = {
**BASE_KWARGS,
**{
"peft_method": "pt",
"prompt_tuning_init": "RANDOM",
"num_virtual_tokens": 8,
"prompt_tuning_init_text": "hello",
"tokenizer_name_or_path": MODEL_NAME,
"save_strategy": "epoch",
"output_dir": "tmp",
},
}
BASE_LORA_KWARGS = {
**BASE_KWARGS,
**{
"peft_method": "lora",
"r": 8,
"lora_alpha": 32,
"lora_dropout": 0.05,
},
}


@@ -74,6 +89,22 @@ def cleanup_env():
os.environ.pop("TERMINATION_LOG_FILE", None)


def test_successful_ft():
"""Check if we can bootstrap and fine tune causallm models"""
with tempfile.TemporaryDirectory() as tempdir:
setup_env(tempdir)
TRAIN_KWARGS = {**BASE_KWARGS, **{"output_dir": tempdir}}
serialized_args = serialize_args(TRAIN_KWARGS)
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args

assert main() == 0
# check termination log and .complete files
assert os.path.exists(tempdir + "/termination-log") is False
assert os.path.exists(os.path.join(tempdir, ".complete")) is True
assert os.path.exists(tempdir + "/adapter_config.json") is False
assert len(glob.glob(f"{tempdir}/model*.safetensors")) > 0


def test_successful_pt():
"""Check if we can bootstrap and peft tune causallm models"""
with tempfile.TemporaryDirectory() as tempdir:
@@ -86,6 +117,24 @@ def test_successful_pt():
# check termination log and .complete files
assert os.path.exists(tempdir + "/termination-log") is False
assert os.path.exists(os.path.join(tempdir, ".complete")) is True
assert os.path.exists(tempdir + "/adapter_model.safetensors") is True
assert os.path.exists(tempdir + "/adapter_config.json") is True


def test_successful_lora():
"""Check if we can bootstrap and LoRA tune causallm models"""
with tempfile.TemporaryDirectory() as tempdir:
setup_env(tempdir)
TRAIN_KWARGS = {**BASE_LORA_KWARGS, **{"output_dir": tempdir}}
serialized_args = serialize_args(TRAIN_KWARGS)
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args

assert main() == 0
# check termination log and .complete files
assert os.path.exists(tempdir + "/termination-log") is False
assert os.path.exists(os.path.join(tempdir, ".complete")) is True
assert os.path.exists(tempdir + "/adapter_model.safetensors") is True
assert os.path.exists(tempdir + "/adapter_config.json") is True


def test_bad_script_path():
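
The new tests above (test_successful_ft, test_successful_lora, and the extended test_successful_pt) follow the same pattern: compose a kwargs dict from BASE_KWARGS, serialize it into the SFT_TRAINER_CONFIG_JSON_ENV_VAR environment variable, call main() from build/accelerate_launch.py, and assert on the files left in the output directory (model*.safetensors shards for full fine tuning, adapter_model.safetensors plus adapter_config.json for prompt tuning and LoRA). Below is a minimal sketch of that flow, assuming serialize_args base64-encodes the JSON config so it can travel through an environment variable; that encoding is an assumption based on the helper names here, not something shown in this diff.

import base64
import json
import os

def serialize_args_sketch(config: dict) -> str:
    # Hypothetical stand-in for the test helper serialize_args:
    # pack the config as base64-encoded JSON so it survives the env-var round trip.
    return base64.b64encode(json.dumps(config).encode("utf-8")).decode("utf-8")

# Illustrative config in the shape the tests build from BASE_KWARGS / BASE_LORA_KWARGS.
train_kwargs = {
    "model_name_or_path": "Maykeye/TinyLLama-v0",
    "peft_method": "lora",
    "r": 8,
    "output_dir": "/tmp/lora-test",
}
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialize_args_sketch(train_kwargs)
# main() in build/accelerate_launch.py would then read and decode this variable before launching training.
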
