diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py
index 699188ca1..4f3c39418 100644
--- a/tests/test_sft_trainer.py
+++ b/tests/test_sft_trainer.py
@@ -24,6 +24,7 @@
 # Third Party
 from datasets.exceptions import DatasetGenerationError
 from transformers.trainer_callback import TrainerCallback
+from transformers.utils.import_utils import _is_package_available
 import pytest
 import torch
 import transformers
@@ -40,6 +41,11 @@
 # Local
 from tuning import sft_trainer
 from tuning.config import configs, peft_config
+from tuning.config.tracker_configs import (
+    AimConfig,
+    FileLoggingTrackerConfig,
+    TrackerConfigFactory,
+)
 
 MODEL_NAME = "Maykeye/TinyLLama-v0"
 MODEL_ARGS = configs.ModelArguments(
@@ -399,35 +405,42 @@ def test_run_causallm_lora_and_inference(request, target_modules, expected):
 ############################# Finetuning Tests #############################
-
-
 def test_run_causallm_ft_and_inference():
     """Check if we can bootstrap and finetune tune causallm models"""
     with tempfile.TemporaryDirectory() as tempdir:
-        train_args = copy.deepcopy(TRAIN_ARGS)
-        train_args.output_dir = tempdir
+        _test_run_causallm_ft(TRAIN_ARGS, MODEL_ARGS, DATA_ARGS, tempdir)
+        _test_run_inference(tempdir=tempdir)
-        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
-        # validate ft tuning configs
-        _validate_training(tempdir)
-        checkpoint_path = _get_checkpoint_path(tempdir)
+############################# Helper functions #############################
+def _test_run_causallm_ft(training_args, model_args, data_args, tempdir):
+    train_args = copy.deepcopy(training_args)
+    train_args.output_dir = tempdir
+    sft_trainer.train(model_args, data_args, train_args, None)
-        # Load the model
-        loaded_model = TunedCausalLM.load(checkpoint_path)
+    # validate ft tuning configs
+    _validate_training(tempdir)
-        # Run inference on the text
-        output_inference = loaded_model.run(
-            "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
-        )
-        assert len(output_inference) > 0
-        assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference
+def _test_run_inference(tempdir):
+    checkpoint_path = _get_checkpoint_path(tempdir)
-############################# Helper functions #############################
-def _validate_training(tempdir, check_eval=False):
+    # Load the model
+    loaded_model = TunedCausalLM.load(checkpoint_path)
+
+    # Run inference on the text
+    output_inference = loaded_model.run(
+        "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
+    )
+    assert len(output_inference) > 0
+    assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference
+
+
+def _validate_training(
+    tempdir, check_eval=False, train_logs_file="training_logs.jsonl"
+):
     assert any(x.startswith("checkpoint-") for x in os.listdir(tempdir))
-    train_logs_file_path = "{}/training_logs.jsonl".format(tempdir)
+    train_logs_file_path = "{}/{}".format(tempdir, train_logs_file)
     train_log_contents = ""
     with open(train_logs_file_path, encoding="utf-8") as f:
         train_log_contents = f.read()
@@ -640,12 +653,215 @@ def test_run_with_additional_callbacks():
     with tempfile.TemporaryDirectory() as tempdir:
         train_args = copy.deepcopy(TRAIN_ARGS)
         train_args.output_dir = tempdir
-        model_args = copy.deepcopy(MODEL_ARGS)
 
         sft_trainer.train(
-            model_args,
+            MODEL_ARGS,
+            DATA_ARGS,
+            train_args,
+            PEFT_PT_ARGS,
+            additional_callbacks=[TrainerCallback()],
+        )
+
+
+def test_run_with_bad_additional_callbacks():
+    """Ensure that train() raises error with bad additional_callbacks"""
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        with pytest.raises(
+            ValueError, match="additional callbacks should be of type TrainerCallback"
+        ):
+            sft_trainer.train(
+                MODEL_ARGS,
+                DATA_ARGS,
+                train_args,
+                PEFT_PT_ARGS,
+                additional_callbacks=["NotSupposedToBeHere"],
+            )
+
+
+def test_run_with_bad_experimental_metadata():
+    """Ensure that train() throws error with bad experimental metadata"""
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        metadata = "deadbeef"
+
+        with pytest.raises(
+            ValueError, match="exp metadata passed should be a dict with valid json"
+        ):
+            sft_trainer.train(
+                MODEL_ARGS,
+                DATA_ARGS,
+                train_args,
+                PEFT_PT_ARGS,
+                additional_callbacks=[TrainerCallback()],
+                exp_metadata=metadata,
+            )
+
+
+def test_run_with_good_experimental_metadata():
+    """Ensure that train() can work with good experimental metadata"""
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        metadata = {"dead": "beef"}
+
+        sft_trainer.train(
+            MODEL_ARGS,
+            DATA_ARGS,
+            train_args,
+            PEFT_PT_ARGS,
+            additional_callbacks=[TrainerCallback()],
+            exp_metadata=metadata,
+        )
+
+
+#### Tracker subsystem checks
+
+
+def test_run_with_bad_tracker_config():
+    """Ensure that train() raises error with bad tracker configs"""
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        with pytest.raises(
+            ValueError,
+            match="tracker configs should adhere to the TrackerConfigFactory type",
+        ):
+            sft_trainer.train(
+                MODEL_ARGS,
+                DATA_ARGS,
+                train_args,
+                PEFT_PT_ARGS,
+                tracker_configs="NotSupposedToBeHere",
+            )
+
+
+def test_run_with_bad_tracker_name():
+    """Ensure that train() raises error with bad tracker name"""
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        bad_name = "NotAValidTracker"
+        train_args.trackers = [bad_name]
+
+        # ensure bad tracker name gets called out
+        with pytest.raises(
+            ValueError, match=r"Requested Tracker {} not found.".format(bad_name)
+        ):
+            sft_trainer.train(
+                MODEL_ARGS,
+                DATA_ARGS,
+                train_args,
+                PEFT_PT_ARGS,
+            )
+
+
+def test_run_with_file_logging_tracker():
+    """Ensure that training succeeds with a good tracker name"""
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.trackers = ["file_logger"]
+
+        _test_run_causallm_ft(train_args, MODEL_ARGS, DATA_ARGS, tempdir)
+        _test_run_inference(tempdir=tempdir)
+
+
+def test_sample_run_with_file_logger_updated_filename():
+    """Ensure that file_logger filename can be updated"""
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        train_args.trackers = ["file_logger"]
+
+        logs_file = "new_train_logs.jsonl"
+
+        tracker_configs = TrackerConfigFactory(
+            file_logger_config=FileLoggingTrackerConfig(
+                training_logs_filename=logs_file
+            )
+        )
+
+        sft_trainer.train(
+            MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs
+        )
+
+        # validate ft tuning configs
+        _validate_training(tempdir, train_logs_file=logs_file)
+
+
+is_aim_available = _is_package_available("aim")
+
+
+@pytest.mark.skipif(
+    not is_aim_available,
+    reason="This test is required only if aim is installed"
+    " else see test_run_with_bad_tracker_name.",
+)
+def test_run_with_good_tracker_name_but_no_args():
+    """Ensure that train() raises error with aim tracker name but no args"""
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        train_args.trackers = ["aim"]
+
+        with pytest.raises(
+            ValueError,
+            match="Aim tracker requested but repo or server is not specified.",
+        ):
+            sft_trainer.train(
+                MODEL_ARGS,
+                DATA_ARGS,
+                train_args,
+                PEFT_PT_ARGS,
+            )
+
+
+@pytest.mark.skipif(
+    not is_aim_available,
+    reason="E2E happy path test for aim tracker."
+    " Runs only when aim tracker is installed.",
+)
+def test_sample_run_with_aim_tracker():
+    """Ensure that training succeeds with aim tracker"""
+
+    with tempfile.TemporaryDirectory() as tempdir:
+        # setup aim in the tempdir
+        os.system("cd " + tempdir + " ; aim init")
+
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        # Listing only aim here does not remove the file logger;
+        # the code adds it back by default.
+        # The _validate_training check below verifies that too.
+        train_args.trackers = ["aim"]
+
+        tracker_configs = TrackerConfigFactory(
+            aim_config=AimConfig(experiment="unit_test", aim_repo=tempdir + "/")
+        )
+
+        sft_trainer.train(
+            MODEL_ARGS, DATA_ARGS, train_args, tracker_configs=tracker_configs
+        )
+
+        # validate ft tuning configs
+        _validate_training(tempdir)
+
+        # validate inference
+        _test_run_inference(tempdir)
diff --git a/tuning/config/configs.py b/tuning/config/configs.py
index a990f1e43..92fb4f8f8 100644
--- a/tuning/config/configs.py
+++ b/tuning/config/configs.py
@@ -20,6 +20,9 @@
 import torch
 import transformers
 
+# Local
+from tuning.trackers.tracker_factory import FILE_LOGGING_TRACKER
+
 DEFAULT_CONTEXT_LENGTH = 4096
 DEFAULT_OPTIMIZER = "adamw_torch"
 
@@ -126,7 +129,7 @@ class TrainingArguments(transformers.TrainingArguments):
         },
     )
     trackers: Optional[List[str.lower]] = field(
-        default_factory=lambda: ["file_logger"],
+        default_factory=lambda: [FILE_LOGGING_TRACKER],
         metadata={
             "help": "Experiment trackers to use.\n"
             + "Available trackers are - file_logger(default), aim, none\n"
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index d866337f0..6e7f2eb67 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -52,7 +52,7 @@
     TrackerConfigFactory,
 )
 from tuning.data import tokenizer_data_utils
-from tuning.trackers.tracker_factory import get_tracker
+from tuning.trackers.tracker_factory import FILE_LOGGING_TRACKER, get_tracker
 from tuning.trainercontroller import TrainerControllerCallback
 from tuning.utils.config_utils import get_hf_peft_config, get_json_config
 from tuning.utils.data_type_utils import get_torch_dtype
@@ -128,11 +128,23 @@ def train(
     trackers = []
     trainer_callbacks = []
 
+    if exp_metadata and (not isinstance(exp_metadata, dict)):
+        raise ValueError("exp metadata passed should be a dict with valid json")
+
     if train_args.trackers is not None:
         requested_trackers = set(train_args.trackers)
     else:
         requested_trackers = set()
 
+    # Ensure file logging is present
+    if FILE_LOGGING_TRACKER not in requested_trackers:
+        requested_trackers.add(FILE_LOGGING_TRACKER)
+
+    if not isinstance(tracker_configs, TrackerConfigFactory):
+        raise ValueError(
+            "tracker configs should adhere to the TrackerConfigFactory type"
+        )
+
     # Now initialize trackers one by one
     for name in requested_trackers:
         t = get_tracker(name, tracker_configs)
@@ -152,7 +164,12 @@
 
     # Add any extra callback if passed by users
     if additional_callbacks is not None:
-        trainer_callbacks.extend(additional_callbacks)
+        for cb in additional_callbacks:
+            if not isinstance(cb, TrainerCallback):
+                raise ValueError(
+                    "additional callbacks should be of type TrainerCallback"
+                )
+            trainer_callbacks.append(cb)
 
     framework = AccelerationFrameworkConfig.from_dataclasses(
         quantized_lora_config, fusedops_kernels_config
diff --git a/tuning/trackers/aimstack_tracker.py b/tuning/trackers/aimstack_tracker.py
index 342983698..33eefd4ba 100644
--- a/tuning/trackers/aimstack_tracker.py
+++ b/tuning/trackers/aimstack_tracker.py
@@ -23,15 +23,25 @@
 class AimStackTracker(Tracker):
     def __init__(self, tracker_config: AimConfig):
-        """
-        Tracker which uses Aimstack to collect and store metrics.
+        """Tracker which uses Aimstack to collect and store metrics.
+
+        Args:
+            tracker_config (AimConfig): A valid AimConfig which contains either
+                information about the repo or the server and port where the aim db is present.
         """
         super().__init__(name="aim", tracker_config=tracker_config)
         self.logger = logging.get_logger("aimstack_tracker")
 
     def get_hf_callback(self):
-        """
-        Returns the aim.hugging_face.AimCallback object associated with this tracker.
+        """Returns the aim.hugging_face.AimCallback object associated with this tracker.
+
+        Raises:
+            ValueError: If the config passed at initialization does not contain either
+                aim_repo or the server and port where the aim db is present.
+
+        Returns:
+            aim.hugging_face.AimCallback: The AimCallback initialized with the config
+                provided at init time.
         """
         c = self.config
         exp = c.experiment
@@ -43,25 +53,34 @@ def get_hf_callback(self):
         if repo:
             aim_callback = AimCallback(repo=repo, experiment=exp)
         else:
-            self.logger.warning(
+            self.logger.error(
                 "Aim tracker requested but repo or server is not specified. "
                 + "Please specify either aim repo or aim server ip and port for using Aim."
             )
-            aim_callback = None
+            raise ValueError(
+                "Aim tracker requested but repo or server is not specified."
+            )
         self.hf_callback = aim_callback
         return self.hf_callback
 
     def track(self, metric, name, stage="additional_metrics"):
-        """
-        Track any additional `metric` with `name` under Aimstack tracker.
-        Expects metric and name to not be None.
-        stage can be used to pass the metadata associated with metric,
-        like, training metric or eval metric or additional metric
+        """Track any additional metric with name under Aimstack tracker.
+
+        Args:
+            metric (int/float): Expected metrics to be tracked by Aimstack.
+            name (str): Name of the metric being tracked.
+            stage (str, optional): Can be used to pass the namespace/metadata to
+                associate with the metric, e.g. the stage at which the metric was generated, like train or eval.
+                Defaults to "additional_metrics".
+
+        Raises:
+            ValueError: If the metric or name is passed as None.
         """
         if metric is None or name is None:
-            self.logger.warning("Tracked metric value or name should not be None")
-            return
+            raise ValueError(
+                "aimstack track function should not be called with None metric value or name"
+            )
         context = {"subset": stage}
         callback = self.hf_callback
         run = callback.experiment
@@ -69,13 +88,20 @@ def track(self, metric, name, stage="additional_metrics"):
         run.track(metric, name=name, context=context)
 
     def set_params(self, params, name="extra_params"):
+        """Attach any extra params with the run information stored in Aimstack tracker.
+
+        Args:
+            params (dict): A dict of k:v pairs of parameters to be stored in the tracker.
+            name (str, optional): Represents the namespace under which parameters
+                will be associated in Aim. Defaults to "extra_params".
+
+        Raises:
+            ValueError: If the params passed are None or not of type dict.
         """
-        Attach any extra params with the run information stored in Aimstack tracker.
-        Expects params to be a dict of k:v pairs of parameters to store.
-        name represents the namespace under which parameters will be associated in Aim.
-        """
-        if params is None:
-            return
+        if params is None or (not isinstance(params, dict)):
+            raise ValueError(
+                "set_params passed to aimstack should be called with a dict of params"
+            )
         callback = self.hf_callback
         run = callback.experiment
         if run is not None:
diff --git a/tuning/trackers/filelogging_tracker.py b/tuning/trackers/filelogging_tracker.py
index 66934191f..213377d96 100644
--- a/tuning/trackers/filelogging_tracker.py
+++ b/tuning/trackers/filelogging_tracker.py
@@ -72,16 +72,22 @@ def _track_loss(self, loss_key, log_name, log_file, logs, state):
 
 class FileLoggingTracker(Tracker):
     def __init__(self, tracker_config: FileLoggingTrackerConfig):
-        """
-        Tracker which encodes callback to record metric, e.g., training loss
+        """Tracker which encodes callback to record metric, e.g., training loss
         to a file in the checkpoint directory.
+
+        Args:
+            tracker_config (FileLoggingTrackerConfig): An instance of the file logging tracker config
+                which contains the location of the file where logs are recorded.
         """
         super().__init__(name="file_logger", tracker_config=tracker_config)
         self.logger = logging.get_logger("file_logging_tracker")
 
     def get_hf_callback(self):
-        """
-        Returns the FileLoggingCallback object associated with this tracker.
+        """Returns the FileLoggingCallback object associated with this tracker.
+
+        Returns:
+            FileLoggingCallback: The file logging callback which inherits
+                transformers.TrainerCallback and records the metrics to a file.
         """
         file = self.config.training_logs_filename
         self.hf_callback = FileLoggingCallback(logs_filename=file)
diff --git a/tuning/trackers/tracker_factory.py b/tuning/trackers/tracker_factory.py
index 3ba127b7f..6b1e9787c 100644
--- a/tuning/trackers/tracker_factory.py
+++ b/tuning/trackers/tracker_factory.py
@@ -21,16 +21,20 @@
 # Local
 from .filelogging_tracker import FileLoggingTracker
-from .tracker import Tracker
 from tuning.config.tracker_configs import FileLoggingTrackerConfig, TrackerConfigFactory
 
 logger = logging.get_logger("tracker_factory")
 
+
 # Information about all registered trackers
-AVAILABLE_TRACKERS = {}
+AIMSTACK_TRACKER = "aim"
+FILE_LOGGING_TRACKER = "file_logger"
+
+AVAILABLE_TRACKERS = [AIMSTACK_TRACKER, FILE_LOGGING_TRACKER]
+
 
-
-AIMSTACK_TRACKER_NAME = "aim"
-FILE_LOGGING_TRACKER_NAME = "file_logger"
+# Trackers which have been registered and can be used
+REGISTERED_TRACKERS = {}
 
 # One time package check for list of external trackers.
 _is_aim_available = _is_package_available("aim")
@@ -49,7 +53,7 @@ def _register_aim_tracker():
 
         AimTracker = _get_tracker_class(AimStackTracker, AimConfig)
 
-        AVAILABLE_TRACKERS[AIMSTACK_TRACKER_NAME] = AimTracker
+        REGISTERED_TRACKERS[AIMSTACK_TRACKER] = AimTracker
         logger.info("Registered aimstack tracker")
     else:
         logger.info(
@@ -59,9 +63,15 @@ def _register_aim_tracker():
     )
 
 
+def _is_tracker_installed(t):
+    if t == "aim":
+        return _is_aim_available
+    return False
+
+
 def _register_file_logging_tracker():
     FileTracker = _get_tracker_class(FileLoggingTracker, FileLoggingTrackerConfig)
-    AVAILABLE_TRACKERS[FILE_LOGGING_TRACKER_NAME] = FileTracker
+    REGISTERED_TRACKERS[FILE_LOGGING_TRACKER] = FileTracker
     logger.info("Registered file logging tracker")
 
@@ -70,9 +80,9 @@ def _register_file_logging_tracker():
 
 # aim - Aimstack Tracker
 def _register_trackers():
     logger.info("Registering trackers")
-    if AIMSTACK_TRACKER_NAME not in AVAILABLE_TRACKERS:
+    if AIMSTACK_TRACKER not in REGISTERED_TRACKERS:
         _register_aim_tracker()
-    if FILE_LOGGING_TRACKER_NAME not in AVAILABLE_TRACKERS:
+    if FILE_LOGGING_TRACKER not in REGISTERED_TRACKERS:
         _register_file_logging_tracker()
 
@@ -87,32 +97,65 @@ def _get_tracker_config_by_name(name: str, tracker_configs: TrackerConfigFactory
 
 
 def get_tracker(name: str, tracker_configs: TrackerConfigFactory):
+    """Returns an instance of the tracker object based on the requested name.
+
+    Args:
+        name (str): name of the tracker requested.
+        tracker_configs (tuning.config.tracker_configs.TrackerConfigFactory):
+            An instance of TrackerConfigFactory which contains a
+            non-None config instance for the requested tracker.
+    Raises:
+        ValueError: If the requested tracker is not registered or not installed, this function raises a ValueError
+        ValueError: If a valid tracker is found but its config is not passed, the tracker itself might
+            raise a ValueError. See tuning.trackers.aimstack_tracker.AimStackTracker
+
+    Returns:
+        tuning.trackers.tracker.Tracker: A subclass of tuning.trackers.tracker.Tracker.
+        Valid classes available are:
+        tuning.trackers.aimstack_tracker.AimStackTracker,
+        tuning.trackers.filelogging_tracker.FileLoggingTracker
+
+    Examples:
+        file_logging_tracker = get_tracker("file_logger", TrackerConfigFactory(
+            file_logger_config=FileLoggingTrackerConfig(
+                training_logs_filename=logs_file
+            )
+        ))
+        aim_tracker = get_tracker("aim", TrackerConfigFactory(
+            aim_config=AimConfig(
+                experiment="unit_test",
+                aim_repo=tempdir + "/"
+            )
+        ))
     """
-    Returns an instance of the tracker object based on the requested `name`.
-    Expects tracker config to be present as part of the TrackerConfigFactory
-    object passed as `tracker_configs` argument.
-    If a valid tracker config is not found this function tries tracker with
-    default config else returns an empty Tracker()
-    """
-    if not AVAILABLE_TRACKERS:
+    if not REGISTERED_TRACKERS:
         # a one time step.
         _register_trackers()
 
-    if name in AVAILABLE_TRACKERS:
-        meta = AVAILABLE_TRACKERS[name]
-        C = meta["config"]
-        T = meta["tracker"]
-
-        if tracker_configs is not None:
-            _conf = _get_tracker_config_by_name(name, tracker_configs)
-            if _conf is not None:
-                config = C(**_conf)
-            else:
-                config = C()
-            return T(config)
-
-    logger.warning(
-        "Requested Tracker %s not found. Please check the argument before proceeding.",
-        name,
-    )
-    return Tracker()
+    if name not in REGISTERED_TRACKERS:
+        if name in AVAILABLE_TRACKERS and (not _is_tracker_installed(name)):
+            err = (
+                "Requested tracker " + name + " is not installed.\n"
+                "List of available trackers is "
+                + (",".join(str(t) for t in AVAILABLE_TRACKERS))
+            )
+        else:
+            err = (
+                "Requested Tracker "
+                + name
+                + " not found. Please check the argument before proceeding."
+            )
+        logger.error(err)
+        raise ValueError(err)
+
+    meta = REGISTERED_TRACKERS[name]
+    C = meta["config"]
+    T = meta["tracker"]
+
+    if tracker_configs is not None:
+        _conf = _get_tracker_config_by_name(name, tracker_configs)
+        if _conf is not None:
+            config = C(**_conf)
+        else:
+            config = C()
+    return T(config)
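
Reviewer note: the snippet below is a minimal sketch of how the reworked tracker plumbing is meant to be driven from user code, mirroring the tests added in this patch. MODEL_ARGS, DATA_ARGS and TRAIN_ARGS are assumed to be built exactly as in tests/test_sft_trainer.py; the output directory and log filename are illustrative placeholders, not part of this change.

# Usage sketch based on the tests above (assumptions noted in the lead-in).
import copy

from tuning import sft_trainer
from tuning.config.tracker_configs import FileLoggingTrackerConfig, TrackerConfigFactory

train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = "/tmp/ft-out"      # any writable directory
train_args.trackers = ["file_logger"]      # file_logger is also added by default

tracker_configs = TrackerConfigFactory(
    file_logger_config=FileLoggingTrackerConfig(training_logs_filename="train_logs.jsonl")
)

# The fourth positional argument is the peft config; None means full fine tuning.
sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None, tracker_configs=tracker_configs)

Passing a TrackerConfigFactory here (rather than raw dicts) is what the new isinstance check in train() enforces, and the file logger is re-added even when the caller requests only "aim", which is why _validate_training still finds a training logs file in the aim test above.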