tests: tracker unit tests (foundation-model-stack#172)
* Unit tests for the generic tracker API

Signed-off-by: Dushyant Behl <[email protected]>

* Add a happy path test for aim tracker which runs when it is installed

Signed-off-by: Dushyant Behl <[email protected]>

---------

Signed-off-by: Dushyant Behl <[email protected]>
dushyantbehl authored Jul 15, 2024
1 parent 77a195d commit 9e4b49f
Showing 6 changed files with 391 additions and 80 deletions.
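At a glance, the surface this commit tests is the tracker plumbing around sft_trainer.train(). The following is a minimal sketch assembled from the diffs below, not code taken from the commit itself; model_args, data_args, and train_args stand in for the fixtures the tests construct (MODEL_ARGS, DATA_ARGS, TRAIN_ARGS), and the repo path is illustrative:

from tuning import sft_trainer
from tuning.config.tracker_configs import (
    AimConfig,
    FileLoggingTrackerConfig,
    TrackerConfigFactory,
)

# Trackers are requested by name on the training arguments.
train_args.trackers = ["file_logger", "aim"]

# Per-tracker settings travel inside a TrackerConfigFactory.
tracker_configs = TrackerConfigFactory(
    file_logger_config=FileLoggingTrackerConfig(
        training_logs_filename="training_logs.jsonl"
    ),
    aim_config=AimConfig(experiment="unit_test", aim_repo="/path/to/aim/repo"),
)

sft_trainer.train(
    model_args,
    data_args,
    train_args,
    tracker_configs=tracker_configs,
    exp_metadata={"dead": "beef"},  # must be a dict, else train() raises ValueError
)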
258 changes: 237 additions & 21 deletions tests/test_sft_trainer.py
@@ -24,6 +24,7 @@
 # Third Party
 from datasets.exceptions import DatasetGenerationError
 from transformers.trainer_callback import TrainerCallback
+from transformers.utils.import_utils import _is_package_available
 import pytest
 import torch
 import transformers
@@ -40,6 +41,11 @@
 # Local
 from tuning import sft_trainer
 from tuning.config import configs, peft_config
+from tuning.config.tracker_configs import (
+    AimConfig,
+    FileLoggingTrackerConfig,
+    TrackerConfigFactory,
+)

 MODEL_NAME = "Maykeye/TinyLLama-v0"
 MODEL_ARGS = configs.ModelArguments(
@@ -399,35 +405,42 @@ def test_run_causallm_lora_and_inference(request, target_modules, expected):


 ############################# Finetuning Tests #############################


 def test_run_causallm_ft_and_inference():
     """Check if we can bootstrap and finetune tune causallm models"""
     with tempfile.TemporaryDirectory() as tempdir:
-        train_args = copy.deepcopy(TRAIN_ARGS)
-        train_args.output_dir = tempdir
-
-        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
-
-        # validate ft tuning configs
-        _validate_training(tempdir)
-        checkpoint_path = _get_checkpoint_path(tempdir)
-
-        # Load the model
-        loaded_model = TunedCausalLM.load(checkpoint_path)
-
-        # Run inference on the text
-        output_inference = loaded_model.run(
-            "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
-        )
-        assert len(output_inference) > 0
-        assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference
+        _test_run_causallm_ft(TRAIN_ARGS, MODEL_ARGS, DATA_ARGS, tempdir)
+        _test_run_inference(tempdir=tempdir)


 ############################# Helper functions #############################
-def _validate_training(tempdir, check_eval=False):
+def _test_run_causallm_ft(training_args, model_args, data_args, tempdir):
+    train_args = copy.deepcopy(training_args)
+    train_args.output_dir = tempdir
+    sft_trainer.train(model_args, data_args, train_args, None)
+
+    # validate ft tuning configs
+    _validate_training(tempdir)
+
+
+def _test_run_inference(tempdir):
+    checkpoint_path = _get_checkpoint_path(tempdir)
+
+    # Load the model
+    loaded_model = TunedCausalLM.load(checkpoint_path)
+
+    # Run inference on the text
+    output_inference = loaded_model.run(
+        "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
+    )
+    assert len(output_inference) > 0
+    assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference
+
+
+def _validate_training(
+    tempdir, check_eval=False, train_logs_file="training_logs.jsonl"
+):
     assert any(x.startswith("checkpoint-") for x in os.listdir(tempdir))
-    train_logs_file_path = "{}/training_logs.jsonl".format(tempdir)
+    train_logs_file_path = "{}/{}".format(tempdir, train_logs_file)
     train_log_contents = ""
     with open(train_logs_file_path, encoding="utf-8") as f:
         train_log_contents = f.read()
@@ -640,12 +653,215 @@ def test_run_with_additional_callbacks():
     with tempfile.TemporaryDirectory() as tempdir:
         train_args = copy.deepcopy(TRAIN_ARGS)
         train_args.output_dir = tempdir
-        model_args = copy.deepcopy(MODEL_ARGS)

         sft_trainer.train(
-            model_args,
+            MODEL_ARGS,
             DATA_ARGS,
             train_args,
             PEFT_PT_ARGS,
             additional_callbacks=[TrainerCallback()],
         )
+
+
+def test_run_with_bad_additional_callbacks():
+    """Ensure that train() raises error with bad additional_callbacks"""
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        with pytest.raises(
+            ValueError, match="additional callbacks should be of type TrainerCallback"
+        ):
+            sft_trainer.train(
+                MODEL_ARGS,
+                DATA_ARGS,
+                train_args,
+                PEFT_PT_ARGS,
+                additional_callbacks=["NotSupposedToBeHere"],
+            )
+
+
+def test_run_with_bad_experimental_metadata():
+    """Ensure that train() throws error with bad experimental metadata"""
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        metadata = "deadbeef"
+
+        with pytest.raises(
+            ValueError, match="exp metadata passed should be a dict with valid json"
+        ):
+            sft_trainer.train(
+                MODEL_ARGS,
+                DATA_ARGS,
+                train_args,
+                PEFT_PT_ARGS,
+                additional_callbacks=[TrainerCallback()],
+                exp_metadata=metadata,
+            )
+
+
+def test_run_with_good_experimental_metadata():
+    """Ensure that train() can work with good experimental metadata"""
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        metadata = {"dead": "beef"}
+
+        sft_trainer.train(
+            MODEL_ARGS,
+            DATA_ARGS,
+            train_args,
+            PEFT_PT_ARGS,
+            additional_callbacks=[TrainerCallback()],
+            exp_metadata=metadata,
+        )
+
+
+#### Tracker subsystem checks
+
+
+def test_run_with_bad_tracker_config():
+    """Ensure that train() raises error with bad tracker configs"""
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        with pytest.raises(
+            ValueError,
+            match="tracker configs should adhere to the TrackerConfigFactory type",
+        ):
+            sft_trainer.train(
+                MODEL_ARGS,
+                DATA_ARGS,
+                train_args,
+                PEFT_PT_ARGS,
+                tracker_configs="NotSupposedToBeHere",
+            )
+
+
+def test_run_with_bad_tracker_name():
+    """Ensure that train() raises error with bad tracker name"""
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        bad_name = "NotAValidTracker"
+        train_args.trackers = [bad_name]
+
+        # ensure bad tracker name gets called out
+        with pytest.raises(
+            ValueError, match=r"Requested Tracker {} not found.".format(bad_name)
+        ):
+            sft_trainer.train(
+                MODEL_ARGS,
+                DATA_ARGS,
+                train_args,
+                PEFT_PT_ARGS,
+            )
+
+
+def test_run_with_file_logging_tracker():
+    """Ensure that training succeeds with a good tracker name"""
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.trackers = ["file_logger"]
+
+        _test_run_causallm_ft(TRAIN_ARGS, MODEL_ARGS, DATA_ARGS, tempdir)
+        _test_run_inference(tempdir=tempdir)
+
+
+def test_sample_run_with_file_logger_updated_filename():
+    """Ensure that file_logger filename can be updated"""
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        train_args.trackers = ["file_logger"]
+
+        logs_file = "new_train_logs.jsonl"
+
+        tracker_configs = TrackerConfigFactory(
+            file_logger_config=FileLoggingTrackerConfig(
+                training_logs_filename=logs_file
+            )
+        )
+
+        sft_trainer.train(
+            MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs
+        )
+
+        # validate ft tuning configs
+        _validate_training(tempdir, train_logs_file=logs_file)
+
+
+is_aim_available = _is_package_available("aim")
+
+
+@pytest.mark.skipif(
+    not is_aim_available,
+    reason="This test is required only if aim is installed"
+    " else see test_run_with_bad_tracker_name.",
+)
+def test_run_with_good_tracker_name_but_no_args():
+    """Ensure that train() raises error with aim tracker name but no args"""
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        train_args.trackers = ["aim"]
+
+        with pytest.raises(
+            ValueError,
+            match="Aim tracker requested but repo or server is not specified.",
+        ):
+            sft_trainer.train(
+                MODEL_ARGS,
+                DATA_ARGS,
+                train_args,
+                PEFT_PT_ARGS,
+            )
+
+
+@pytest.mark.skipif(
+    not is_aim_available,
+    reason="E2E happy path test for aim tracker."
+    " Runs only when aim tracker is installed.",
+)
+def test_sample_run_with_aim_tracker():
+    """Ensure that training succeeds with aim tracker"""
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        # setup aim in the tempdir
+        os.system("cd " + tempdir + " ; aim init")
+
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        # This should not mean file logger is not present.
+        # code will add it by default
+        # The below validate_training check will test for that too.
+        train_args.trackers = ["aim"]
+
+        tracker_configs = TrackerConfigFactory(
+            aim_config=AimConfig(experiment="unit_test", aim_repo=tempdir + "/")
+        )
+
+        sft_trainer.train(
+            MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs
+        )
+
+        # validate ft tuning configs
+        _validate_training(tempdir)
+
+        # validate inference
+        _test_run_inference(tempdir)
5 changes: 4 additions & 1 deletion tuning/config/configs.py
@@ -20,6 +20,9 @@
 import torch
 import transformers

+# Local
+from tuning.trackers.tracker_factory import FILE_LOGGING_TRACKER
+
 DEFAULT_CONTEXT_LENGTH = 4096
 DEFAULT_OPTIMIZER = "adamw_torch"

@@ -126,7 +129,7 @@ class TrainingArguments(transformers.TrainingArguments):
         },
     )
     trackers: Optional[List[str.lower]] = field(
-        default_factory=lambda: ["file_logger"],
+        default_factory=lambda: [FILE_LOGGING_TRACKER],
         metadata={
             "help": "Experiment trackers to use.\n"
             + "Available trackers are - file_logger(default), aim, none\n"
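For callers, the practical effect is that the default tracker name and the factory's registry key can no longer drift apart. A small usage sketch, not from the diff, with an illustrative output_dir:

from tuning.config import configs
from tuning.trackers.tracker_factory import FILE_LOGGING_TRACKER

train_args = configs.TrainingArguments(
    output_dir="/tmp/tuning-run",  # illustrative path
    trackers=[FILE_LOGGING_TRACKER, "aim"],
)
assert FILE_LOGGING_TRACKER in train_args.trackers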
21 changes: 19 additions & 2 deletions tuning/sft_trainer.py
@@ -52,7 +52,7 @@
     TrackerConfigFactory,
 )
 from tuning.data import tokenizer_data_utils
-from tuning.trackers.tracker_factory import get_tracker
+from tuning.trackers.tracker_factory import FILE_LOGGING_TRACKER, get_tracker
 from tuning.trainercontroller import TrainerControllerCallback
 from tuning.utils.config_utils import get_hf_peft_config, get_json_config
 from tuning.utils.data_type_utils import get_torch_dtype
@@ -128,11 +128,23 @@ def train(
     trackers = []
     trainer_callbacks = []

+    if exp_metadata and (not isinstance(exp_metadata, dict)):
+        raise ValueError("exp metadata passed should be a dict with valid json")
+
     if train_args.trackers is not None:
         requested_trackers = set(train_args.trackers)
     else:
         requested_trackers = set()

+    # Ensure file logging is present
+    if FILE_LOGGING_TRACKER not in requested_trackers:
+        requested_trackers.add(FILE_LOGGING_TRACKER)
+
+    if not isinstance(tracker_configs, TrackerConfigFactory):
+        raise ValueError(
+            "tracker configs should adhere to the TrackerConfigFactory type"
+        )
+
     # Now initialize trackers one by one
     for name in requested_trackers:
         t = get_tracker(name, tracker_configs)
@@ -152,7 +164,12 @@

     # Add any extra callback if passed by users
     if additional_callbacks is not None:
-        trainer_callbacks.extend(additional_callbacks)
+        for cb in additional_callbacks:
+            if not isinstance(cb, TrainerCallback):
+                raise ValueError(
+                    "additional callbacks should be of type TrainerCallback"
+                )
+            trainer_callbacks.append(cb)

     framework = AccelerationFrameworkConfig.from_dataclasses(
         quantized_lora_config, fusedops_kernels_config
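One behavioral consequence of the train() changes, sketched below with the same stand-in fixtures as above: because FILE_LOGGING_TRACKER is re-added to whatever the caller requested, file logging cannot be switched off by omission.

from tuning import sft_trainer

train_args.trackers = []  # ask for no trackers at all...
sft_trainer.train(model_args, data_args, train_args)
# ...train() still adds FILE_LOGGING_TRACKER, so training_logs.jsonl
# is written under train_args.output_dir regardless.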
