From fc8938dc9e7748b77935adfeb02e05d9983ce36c Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 20 Jun 2024 23:16:51 +0800 Subject: [PATCH] Install Acceleration Framework into Training Script (#157) * add acceleration framework Signed-off-by: Yu Chin Fabian Lim * framework can add callbacks Signed-off-by: Yu Chin Fabian Lim * add basic acceleration framework unit tests, lint. Signed-off-by: Yu Chin Fabian Lim * add README, plugin installation tool Signed-off-by: Yu Chin Fabian Lim * updates to readme Signed-off-by: Yu Chin Fabian Lim * more readme updates Signed-off-by: Yu Chin Fabian Lim * update fms-accel dep Signed-off-by: Yu Chin Fabian Lim * add more tests Signed-off-by: Yu Chin Fabian Lim * fixes after rebase + linting Signed-off-by: Yu Chin Fabian Lim * make acceleration framework tests a module and lint,fmt. Signed-off-by: Yu Chin Fabian Lim * clarify the usages flows Signed-off-by: Yu Chin Fabian Lim * replace yaml with dataclass args Signed-off-by: Yu Chin Fabian Lim * fmt + lint Signed-off-by: Yu Chin Fabian Lim * improve tests Signed-off-by: Yu Chin Fabian Lim * test fixes Signed-off-by: Yu Chin Fabian Lim * improve data parsing logic Signed-off-by: Yu Chin Fabian Lim * add foak test Signed-off-by: Yu Chin Fabian Lim * fix bug and add bnb test Signed-off-by: Yu Chin Fabian Lim * add missing peft config test Signed-off-by: Yu Chin Fabian Lim * update README as per @Ssukriti's suggestions. Signed-off-by: Yu Chin Fabian Lim * remove test helpers Signed-off-by: Yu Chin Fabian Lim * fix merge errors and other issues Signed-off-by: Yu Chin Fabian Lim * add one more check in get_framework and other fixes. Signed-off-by: Yu Chin Fabian Lim * fix tests Signed-off-by: Yu Chin Fabian Lim --------- Signed-off-by: Yu Chin Fabian Lim --- README.md | 69 +++ pyproject.toml | 3 + tests/acceleration/__init__.py | 13 + tests/acceleration/spying_utils.py | 47 ++ .../test_acceleration_dataclasses.py | 135 ++++++ .../test_acceleration_framework.py | 448 ++++++++++++++++++ tests/test_sft_trainer.py | 8 +- .../config/acceleration_configs/__init__.py | 18 + .../acceleration_framework_config.py | 265 +++++++++++ .../fused_ops_and_kernels.py | 106 +++++ .../quantized_lora_config.py | 82 ++++ tuning/config/acceleration_configs/utils.py | 88 ++++ tuning/sft_trainer.py | 58 ++- tuning/utils/import_utils.py | 34 ++ 14 files changed, 1368 insertions(+), 6 deletions(-) create mode 100644 tests/acceleration/__init__.py create mode 100644 tests/acceleration/spying_utils.py create mode 100644 tests/acceleration/test_acceleration_dataclasses.py create mode 100644 tests/acceleration/test_acceleration_framework.py create mode 100644 tuning/config/acceleration_configs/__init__.py create mode 100644 tuning/config/acceleration_configs/acceleration_framework_config.py create mode 100644 tuning/config/acceleration_configs/fused_ops_and_kernels.py create mode 100644 tuning/config/acceleration_configs/quantized_lora_config.py create mode 100644 tuning/config/acceleration_configs/utils.py create mode 100644 tuning/utils/import_utils.py diff --git a/README.md b/README.md index e29f45440..8725cd8a0 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,12 @@ If you wish to use [aim](https://github.com/aimhubio/aim), then you need to inst pip install -e ".[aim]" ``` +If you wish to use [fms-acceleration](https://github.com/foundation-model-stack/fms-acceleration), you need to install it. 
+```
+pip install -e ".[fms-accel]"
+```
+`fms-acceleration` is a collection of plugins that accelerate fine-tuning / training of large models as part of the `fms-hf-tuning` suite. For more details, see [this section below](#fms-acceleration).
+
 ## Data format

 We support two data formats:
@@ -377,6 +383,69 @@ Equally you can pass in a JSON configuration for running tuning. See [build doc]
 }
 ```
 
+### FMS Acceleration
+
+`fms-acceleration` is a fuss-free way to access a curated collection of acceleration plugins that accelerate your `tuning/sft_trainer.py` experience. Accelerations that apply to a variety of use-cases (e.g., PEFT, full fine-tuning) are planned. The accelerations are grouped into *plugins*; install only the plugins needed for the accelerations of interest. The plugins are housed in a [separate repository found here](https://github.com/foundation-model-stack/fms-acceleration).
+
+To access `fms-acceleration` features, the `[fms-accel]` dependency must first be installed:
+  ```
+  $ pip install -e .[fms-accel]
+  ```
+
+Furthermore, the required `fms-acceleration` plugin must be installed. This is done via the command line utility `fms_acceleration.cli`. To show the available plugins:
+  ```
+  $ python -m fms_acceleration.cli plugins
+  ```
+and to install, for example, the `fms_acceleration_peft` plugin:
+
+  ```
+  $ python -m fms_acceleration.cli install fms_acceleration_peft
+  ```
+
+If you do not know which plugin to install (or forget to install one), the framework will remind you:
+
+```
+An acceleration feature is requested by specifying the '--auto_gptq' argument, but the this requires acceleration packages to be installed. Please do:
+- python -m fms_acceleration.cli install fms_acceleration_peft
+```
+
+The configurations for the various `fms_acceleration` plugins are:
+- [quantized_lora_config](./tuning/config/acceleration_configs/quantized_lora_config.py): for quantized 4bit LoRA training
+  - `--auto_gptq`: 4bit GPTQ-LoRA with AutoGPTQ
+  - `--bnb_qlora`: 4bit QLoRA with bitsandbytes
+- [fused_ops_and_kernels](./tuning/config/acceleration_configs/fused_ops_and_kernels.py) (experimental):
+  - `--fused_lora`: fused LoRA for more efficient LoRA training.
+  - `--fast_kernels`: fast cross-entropy, RoPE, and RMS layernorm kernels.
+
+Notes:
+ * `quantized_lora_config` must be used together with the LoRA tuning technique. See the [LoRA tuning section](https://github.com/foundation-model-stack/fms-hf-tuning/tree/main?tab=readme-ov-file#lora-tuning-example) for the LoRA parameters to pass.
+ * When setting `--auto_gptq triton_v2`, also pass `--torch_dtype float16` and `--fp16`, or an exception will be raised; these kernels only support this dtype.
+ * Currently, `fused_ops_and_kernels` must be used together with QLoRA or GPTQ-LoRA via the `quantized_lora_config`. In the future it may be made more flexible, so that `fast_kernels` can also be used with full fine-tuning.
+ * When using `fused_ops_and_kernels` together with `quantized_lora_config`, set `--fused_lora auto_gptq True` or `--fused_lora bitsandbytes True` as appropriate; the `True` sets `fast_lora==True`.
+ * Currently `fused_ops_and_kernels` only supports setting `fast_loss`, `fast_rsm_layernorm` and `fast_rope_embeddings` all to `True`, so pass `--fast_kernels True True True`.
+
+
+Set `TRANSFORMERS_VERBOSITY=info` to see the Hugging Face trainer printouts and verify that `AccelerationFramework` is activated!
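+
+For illustration only, a GPTQ-LoRA run might be launched roughly as follows (the model and data paths are placeholders, and the usual LoRA arguments from the section above still apply):
+
+```
+TRANSFORMERS_VERBOSITY=info python tuning/sft_trainer.py \
+  --model_name_or_path <path-to-gptq-quantized-model> \
+  --training_data_path <path-to-train-data> \
+  --output_dir <output-dir> \
+  --peft_method lora \
+  --torch_dtype float16 \
+  --fp16 \
+  --auto_gptq triton_v2 \
+  --fused_lora auto_gptq True \
+  --fast_kernels True True True
+```
+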
+ +``` +# this printout will be seen in huggingface trainer logs if acceleration is activated +***** FMS AccelerationFramework ***** +Active Plugin: AutoGPTQAccelerationPlugin. Python package: fms_acceleration_peft. Version: 0.0.1. +***** Running training ***** +Num examples = 1,549 +Num Epochs = 1 +Instantaneous batch size per device = 4 +Total train batch size (w. parallel, distributed & accumulation) = 4 +Gradient Accumulation steps = 1 +Total optimization steps = 200 +Number of trainable parameters = 13,631,488 +``` + +The `fms_acceleration.cli` can do more to search for all available configs, plugins and arguments, [see the advanced flow](https://github.com/foundation-model-stack/fms-acceleration#advanced-flow). + + + ## Inference Currently, we do *not* offer inference support as part of the library, but we provide a standalone script for running inference on tuned models for testing purposes. For a full list of options run `python scripts/run_inference.py --help`. Note that no data formatting / templating is applied at inference time. diff --git a/pyproject.toml b/pyproject.toml index 052de5ee0..126d43252 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,9 @@ dependencies = [ dev = ["wheel", "packaging", "ninja", "scikit-learn>=1.0, <2.0", "boto3"] flash-attn = ["flash-attn"] aim = ["aim==3.19.0"] +fms-accel = [ + "fms_acceleration @ git+https://github.com/foundation-model-stack/fms-acceleration.git#subdirectory=plugins/framework" +] [tool.setuptools.packages.find] exclude = ["tests", "tests.*"] diff --git a/tests/acceleration/__init__.py b/tests/acceleration/__init__.py new file mode 100644 index 000000000..38a9531ef --- /dev/null +++ b/tests/acceleration/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/acceleration/spying_utils.py b/tests/acceleration/spying_utils.py new file mode 100644 index 000000000..ce1ae1f9a --- /dev/null +++ b/tests/acceleration/spying_utils.py @@ -0,0 +1,47 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
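+
+# The helper below wraps an fms_acceleration plugin class so that tests can count
+# how many times the framework invokes each lifecycle hook
+# (model_loader, augmentation, get_callbacks_and_ready_for_train).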
+ + +def create_mock_plugin_class_and_spy(class_name, plugin_cls): + "helper function to create plugin class" + + spy = { + "model_loader_calls": 0, + "augmentation_calls": 0, + "get_ready_for_train_calls": 0, + } + + def model_loader(self, *args, **kwargs): + spy["model_loader_calls"] += 1 + return plugin_cls.model_loader(self, *args, **kwargs) + + def augmentation( + self, + *args, + **kwargs, + ): + spy["augmentation_calls"] += 1 + return plugin_cls.augmentation(self, *args, **kwargs) + + def get_callbacks_and_ready_for_train(self, *args, **kwargs): + spy["get_ready_for_train_calls"] += 1 + return plugin_cls.get_callbacks_and_ready_for_train(self, args, **kwargs) + + attributes = { + "model_loader": model_loader, + "augmentation": augmentation, + "get_callbacks_and_ready_for_train": get_callbacks_and_ready_for_train, + } + + return type(class_name, (plugin_cls,), attributes), spy diff --git a/tests/acceleration/test_acceleration_dataclasses.py b/tests/acceleration/test_acceleration_dataclasses.py new file mode 100644 index 000000000..fc031298b --- /dev/null +++ b/tests/acceleration/test_acceleration_dataclasses.py @@ -0,0 +1,135 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard + +# Third Party +import pytest +import transformers + +# Local +from tuning.config.acceleration_configs import ( + FusedOpsAndKernelsConfig, + QuantizedLoraConfig, +) +from tuning.config.acceleration_configs.fused_ops_and_kernels import ( + FastKernelsConfig, + FusedLoraConfig, +) +from tuning.config.acceleration_configs.quantized_lora_config import ( + AutoGPTQLoraConfig, + BNBQLoraConfig, +) + + +def test_dataclass_parse_successfully(): + parser = transformers.HfArgumentParser(dataclass_types=QuantizedLoraConfig) + + # if nothing is specified then it will parse into the null class + (cfg, _) = parser.parse_args_into_dataclasses(return_remaining_strings=True) + assert cfg.auto_gptq is None + assert cfg.bnb_qlora is None + + # 1.1 specifying "--auto_gptq" with the first item of AutoGPTQLoraConfig + # will parse + (cfg,) = parser.parse_args_into_dataclasses( + ["--auto_gptq", "triton_v2"], + ) + assert isinstance(cfg.auto_gptq, AutoGPTQLoraConfig) + assert cfg.bnb_qlora is None + + # 1.2 specifying "--auto_gptq" with the two items of AutoGPTQLoraConfig + # will parse + (cfg,) = parser.parse_args_into_dataclasses( + ["--auto_gptq", "triton_v2", "true"], + ) + assert isinstance(cfg.auto_gptq, AutoGPTQLoraConfig) + assert cfg.bnb_qlora is None + + # 2. specifying "--bnb_qlora" with the first item of BNBQLoraConfig + # will parse + (cfg,) = parser.parse_args_into_dataclasses( + ["--bnb_qlora", "nf4"], + ) + assert cfg.auto_gptq is None + assert isinstance(cfg.bnb_qlora, BNBQLoraConfig) + + +def test_two_dataclasses_parse_successfully_together(): + """Ensure that the two dataclasses can parse arguments successfully + together. + """ + parser = transformers.HfArgumentParser( + dataclass_types=(QuantizedLoraConfig, FusedOpsAndKernelsConfig) + ) + + # 1. 
specifying "--auto_gptq" together with "--fused_lora" and + # "--fast_kernels" will parse. + cfg, cfg2 = parser.parse_args_into_dataclasses( + [ + "--auto_gptq", + "triton_v2", + "--fused_lora", + "auto_gptq", + "true", + "--fast_kernels", + "true", + "true", + "true", + ], + ) + assert isinstance(cfg.auto_gptq, AutoGPTQLoraConfig) + assert cfg.bnb_qlora is None + assert isinstance(cfg2.fused_lora, FusedLoraConfig) + assert isinstance(cfg2.fast_kernels, FastKernelsConfig) + + +def test_dataclass_will_fail_to_parse_with_no_args(): + """Ensure that the dataclass arg parser will refuse to parse if + only the key is specified without any following arguments. + """ + parser = transformers.HfArgumentParser(dataclass_types=QuantizedLoraConfig) + + # 1. passing only the key without any body will fail + # - at least the first argument of the dataclass will be expected. + with pytest.raises( + SystemExit, # argparse will exit + ): + (_,) = parser.parse_args_into_dataclasses( + ["--auto_gptq"], + ) + + +def test_dataclass_will_fail_to_accept_illegal_args(): + """Ensure that some basic rules that are put in the dataclasses will + fail at initialization of the class. + """ + + # 1. auto_gptq does not support from_quantized at the moment. + with pytest.raises( + ValueError, match="only 'from_quantized' == True currently supported." + ): + AutoGPTQLoraConfig(from_quantized=False) + + # 1.1 auto_gptq only supports triton_v2 at the moment + with pytest.raises( + ValueError, match="only 'triton_v2' kernel currently supported." + ): + AutoGPTQLoraConfig(kernel="fake-kernel") + + # 2 bnb only supports two quant types + with pytest.raises( + ValueError, match="quant_type can only be either 'nf4' or 'fp4." + ): + BNBQLoraConfig(quant_type="fake-quant-type") diff --git a/tests/acceleration/test_acceleration_framework.py b/tests/acceleration/test_acceleration_framework.py new file mode 100644 index 000000000..c907b8f06 --- /dev/null +++ b/tests/acceleration/test_acceleration_framework.py @@ -0,0 +1,448 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Standard +from dataclasses import dataclass, replace +from typing import Annotated +from unittest.mock import patch +import copy +import tempfile + +# Third Party +import pytest +import torch + +# First Party +from tests.test_sft_trainer import DATA_ARGS, MODEL_ARGS, PEFT_LORA_ARGS, TRAIN_ARGS + +# Local +from .spying_utils import create_mock_plugin_class_and_spy +from tuning import sft_trainer +from tuning.config.acceleration_configs import ( + AccelerationFrameworkConfig, + FusedOpsAndKernelsConfig, + QuantizedLoraConfig, +) +from tuning.config.acceleration_configs.acceleration_framework_config import ( + ConfigAnnotation, +) +from tuning.config.acceleration_configs.fused_ops_and_kernels import ( + FastKernelsConfig, + FusedLoraConfig, +) +from tuning.config.acceleration_configs.quantized_lora_config import ( + AutoGPTQLoraConfig, + BNBQLoraConfig, +) +from tuning.utils.import_utils import is_fms_accelerate_available + +# pylint: disable=import-error +if is_fms_accelerate_available(): + + # Third Party + from fms_acceleration.utils.test_utils import build_framework_and_maybe_instantiate + + if is_fms_accelerate_available(plugins="peft"): + # Third Party + from fms_acceleration_peft import ( + AutoGPTQAccelerationPlugin, + BNBAccelerationPlugin, + ) + + if is_fms_accelerate_available(plugins="foak"): + # Third Party + from fms_acceleration_foak import FastQuantizedPeftAccelerationPlugin + + +# There are more extensive unit tests in the +# https://github.com/foundation-model-stack/fms-acceleration +# repository. +# - see plugins/framework/tests/test_framework.py + + +@pytest.mark.skipif( + not is_fms_accelerate_available(plugins="peft"), + reason="Only runs if fms-accelerate is installed along with accelerated-peft plugin", +) +def test_acceleration_framework_fail_construction(): + """Ensure that construct of the framework will fail if rules regarding + the dataclasess are violated. + """ + + # 1. Rule 1: No two standalone dataclasses can exist at the same path + # - Test that the framework will fail to construct if there are multiple + # standalone plugins under the same path that are simultaneously requested. + invalid_quantized_lora_config = QuantizedLoraConfig( + auto_gptq=AutoGPTQLoraConfig(), bnb_qlora=BNBQLoraConfig() + ) + with pytest.raises( + ValueError, + match="Configuration path 'peft.quantization' already has one standalone config.", + ): + AccelerationFrameworkConfig.from_dataclasses( + invalid_quantized_lora_config + ).get_framework() + + def peft_unavailable(plugin=None): + if plugin == "peft": + return False + return True + + quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig()) + + # 2. Rule 2: Dataclass cannot request a plugin that is not yet installed. + # - Test that framework will fail to construct if trying to activate a plugin + # that is not yet installed + with pytest.raises( + ValueError, + match="An acceleration feature is requested by specifying the '--auto_gptq' argument, " + "but the this requires acceleration packages to be installed.", + ): + with patch( + "tuning.config.acceleration_configs.acceleration_framework_config." + "is_fms_accelerate_available", + peft_unavailable, + ): + AccelerationFrameworkConfig.from_dataclasses( + quantized_lora_config + ).get_framework() + + # 3. Rule 3: Dataclass that corresponds to experimental plugin will + # give user a warning. 
+ # - Test that if a plugin is experimental the user will be warned + + # - create a dataclas with an experimental annotation that to be + # used for mocking + # - mocked auto_gptq here to be experimental + @dataclass + class DataClassWithExperimental: + auto_gptq: Annotated[ + AutoGPTQLoraConfig, + ConfigAnnotation(path="peft.quantization", experimental=True), + ] = None + + with pytest.warns( + UserWarning, + match="An experimental acceleration feature is requested by specifying the " + "'--auto_gptq' argument. Please note this feature may not support certain " + "edge cases at this juncture. When the feature matures this " + "message will be turned off.", + ): + with patch.dict( + "tuning.config.acceleration_configs.acceleration_framework_config." + "AccelerationFrameworkConfig.__dataclass_fields__", + DataClassWithExperimental.__dataclass_fields__, + ): + + AccelerationFrameworkConfig.from_dataclasses( + quantized_lora_config + ).get_framework() + + +def test_acceleration_framework_pass_construction_with_no_active_configs(): + """Ensure framework is properly constructed in the null pattern where + no configs are active + """ + + # for the fallback, if the dataclasses + AccelerationFrameworkConfig.from_dataclasses(QuantizedLoraConfig) + assert QuantizedLoraConfig.auto_gptq is None + assert QuantizedLoraConfig.bnb_qlora is None + + +@pytest.mark.skip( + """ NOTE: this scenario will actually never happen, since in the code we always + provide at least one dataclass (can consider to remove this test). + """ +) +def test_construct_framework_config_raise_if_constructing_with_no_dataclassess(): + """Ensure that framework configuration config will refused to construct + if no dataclasses are provided. + """ + + with pytest.raises( + ValueError, + match="AccelerationFrameworkConfig construction requires at least one dataclass", + ): + AccelerationFrameworkConfig.from_dataclasses() + + +@pytest.mark.skipif( + not is_fms_accelerate_available(plugins="peft"), + reason="Only runs if fms-accelerate is installed along with accelerated-peft plugin", +) +def test_construct_framework_with_auto_gptq_peft_successfully(): + "Ensure that framework object is correctly configured." + + # 1. correctly initialize a set of quantized lora config dataclass + # with auto-gptq + quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig()) + + # - instantiate the acceleration config + acceleration_config = AccelerationFrameworkConfig.from_dataclasses( + quantized_lora_config + ) + + # build the framework by + # - passing acceleration configuration contents (via .to_dict()). + # - NOTE: we skip the required packages check in the framework since it is + # not necessary for this test (e.g., we do not need auto_gptq installed) + # - check that the plugin is correctly activated + with build_framework_and_maybe_instantiate( + [], + acceleration_config.to_dict(), # pass in contents + reset_registrations=False, + require_packages_check=False, # not required + ) as framework: + + # plugin activated! + assert len(framework.active_plugins) == 1 + + +@pytest.mark.skipif( + not is_fms_accelerate_available(), + reason="Only runs if fms-accelerate is installed", +) +def test_framework_raises_if_used_with_missing_package(): + """Ensure that trying the use the framework, without first installing fms_acceleration + will raise. 
+ """ + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = None + + quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig()) + + # patch is_fms_accelerate_available to return False inside sft_trainer + # to simulate fms_acceleration not installed + with patch( + "tuning.config.acceleration_configs.acceleration_framework_config." + "is_fms_accelerate_available", + return_value=False, + ): + with pytest.raises( + ValueError, match="No acceleration framework package found." + ): + sft_trainer.train( + MODEL_ARGS, + DATA_ARGS, + TRAIN_ARGS, + PEFT_LORA_ARGS, + quantized_lora_config=quantized_lora_config, + ) + + +invalid_kwargs_map = [ + ( + { + "model_name_or_path": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", + }, + PEFT_LORA_ARGS, + AssertionError, + "need to run in fp16 mixed precision or load model in fp16", + ), + ( + { + "model_name_or_path": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", + "torch_dtype": torch.float16, + }, + None, + AssertionError, + "need peft_config to install PEFT adapters", + ), +] + + +@pytest.mark.skipif( + not is_fms_accelerate_available(plugins="peft"), + reason="Only runs if fms-accelerate is installed along with accelerated-peft plugin", +) +@pytest.mark.parametrize( + "bad_kwargs,peft_config,exception,exception_msg", + invalid_kwargs_map, + ids=["triton_v2 requires fp16", "accelerated peft requires peft config"], +) +def test_framework_raises_due_to_invalid_arguments( + bad_kwargs, peft_config, exception, exception_msg +): + """Ensure that invalid arguments will be checked by the activated framework + plugin. + """ + with tempfile.TemporaryDirectory() as tempdir: + model_args = copy.deepcopy(MODEL_ARGS) + model_args = replace(model_args, **bad_kwargs) + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + + quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig()) + + # 1. activate the accelerated peft plugin + # 2. demonstrate that the invalid arguments will be checked + with pytest.raises(exception, match=exception_msg): + sft_trainer.train( + model_args, + DATA_ARGS, + train_args, + peft_config, + quantized_lora_config=quantized_lora_config, + ) + + +acceleration_configs_map = [ + ( + QuantizedLoraConfig(bnb_qlora=BNBQLoraConfig()), + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + ( + "peft.quantization.bitsandbytes", + create_mock_plugin_class_and_spy( + "PluginMock", + BNBAccelerationPlugin + if is_fms_accelerate_available(plugins="peft") + else object, + ), + ), + ), + ( + QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig()), + "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", + ( + "peft.quantization.auto_gptq", + create_mock_plugin_class_and_spy( + "PluginMock", + AutoGPTQAccelerationPlugin + if is_fms_accelerate_available(plugins="peft") + else object, + ), + ), + ), +] + + +@pytest.mark.skipif( + not is_fms_accelerate_available(plugins="peft"), + reason="Only runs if fms-accelerate is installed along with accelerated-peft plugin", +) +@pytest.mark.parametrize( + "quantized_lora_config,model_name_or_path,mock_and_spy", + acceleration_configs_map, + ids=["bitsandbytes", "auto_gptq"], +) +def test_framework_intialized_properly_peft( + quantized_lora_config, model_name_or_path, mock_and_spy +): + """Ensure that specifying a properly configured acceleration dataclass + properly activates the framework plugin and runs the train sucessfully. 
+ """ + with tempfile.TemporaryDirectory() as tempdir: + model_args = copy.deepcopy(MODEL_ARGS) + model_args.model_name_or_path = model_name_or_path + model_args.torch_dtype = torch.float16 + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + train_args.save_strategy = "no" + train_args.fp16 = True + + installation_path, (MockedPlugin, spy) = mock_and_spy + + # 1. mock a plugin class + # 2. register the mocked plugin + # 3. call sft_trainer.train + with build_framework_and_maybe_instantiate( + [([installation_path], MockedPlugin)], + instantiate=False, + ): + sft_trainer.train( + model_args, + DATA_ARGS, + train_args, + PEFT_LORA_ARGS, + quantized_lora_config=quantized_lora_config, + ) + + # spy inside the train to ensure that the acceleration plugin + # was called. In the context of the AutoGPTQ plugin + # 1. Will sucessfully load the model (to load AutoGPTQ 4-bit linear) + # 2. Will successfully agument the model (to install PEFT) + # 3. Will succesfully call get_ready_for_train + assert spy["model_loader_calls"] == 1 + assert spy["augmentation_calls"] == 1 + assert spy["get_ready_for_train_calls"] == 1 + + +@pytest.mark.skipif( + not is_fms_accelerate_available(plugins=["peft", "foak"]), + reason=( + "Only runs if fms-accelerate is installed along with accelerated-peft " + "and foak plugins" + ), +) +def test_framework_intialized_properly_foak(): + """Ensure that specifying a properly configured acceleration dataclass + properly activates the framework plugin and runs the train sucessfully. + """ + with tempfile.TemporaryDirectory() as tempdir: + + model_args = copy.deepcopy(MODEL_ARGS) + model_args.model_name_or_path = "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ" + model_args.torch_dtype = torch.float16 + train_args = copy.deepcopy(TRAIN_ARGS) + train_args.output_dir = tempdir + train_args.save_strategy = "no" + train_args.fp16 = True + + # setup default quantized lora args dataclass + # - with auth gptq as the quantized method + quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig()) + fusedops_kernels_config = FusedOpsAndKernelsConfig( + fused_lora=FusedLoraConfig(base_layer="auto_gptq", fused_lora=True), + fast_kernels=FastKernelsConfig( + fast_loss=True, fast_rsm_layernorm=True, fast_rope_embeddings=True + ), + ) + + # create mocked plugin class for spying + MockedPlugin1, spy = create_mock_plugin_class_and_spy( + "AutoGPTQMock", AutoGPTQAccelerationPlugin + ) + MockedPlugin2, spy2 = create_mock_plugin_class_and_spy( + "FastPeftMock", FastQuantizedPeftAccelerationPlugin + ) + + # 1. mock a plugin class + # 2. register the mocked plugins + # 3. 
call sft_trainer.train + with build_framework_and_maybe_instantiate( + [ + (["peft.quantization.auto_gptq"], MockedPlugin1), + (["peft.quantization.fused_ops_and_kernels"], MockedPlugin2), + ], + instantiate=False, + ): + sft_trainer.train( + model_args, + DATA_ARGS, + train_args, + PEFT_LORA_ARGS, + quantized_lora_config=quantized_lora_config, + fusedops_kernels_config=fusedops_kernels_config, + ) + + # spy inside the train to ensure that the AutoGPTQ plugin is called + assert spy["model_loader_calls"] == 1 + assert spy["augmentation_calls"] == 1 + assert spy["get_ready_for_train_calls"] == 1 + + # also test that the FusedOpsPlugin is called + assert spy2["model_loader_calls"] == 0 + assert spy2["augmentation_calls"] == 1 + assert spy2["get_ready_for_train_calls"] == 1 diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 19db3402e..c02146bf1 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -119,6 +119,8 @@ def test_parse_arguments(job_config): _, _, _, + _, + _, ) = sft_trainer.parse_arguments(parser, job_config_copy) assert str(model_args.torch_dtype) == "torch.bfloat16" assert data_args.dataset_text_field == "output" @@ -132,7 +134,7 @@ def test_parse_arguments_defaults(job_config): assert "torch_dtype" not in job_config_defaults assert job_config_defaults["use_flash_attn"] is False assert "save_strategy" not in job_config_defaults - model_args, _, training_args, _, _, _, _, _ = sft_trainer.parse_arguments( + model_args, _, training_args, _, _, _, _, _, _, _ = sft_trainer.parse_arguments( parser, job_config_defaults ) assert str(model_args.torch_dtype) == "torch.bfloat16" @@ -144,14 +146,14 @@ def test_parse_arguments_peft_method(job_config): parser = sft_trainer.get_parser() job_config_pt = copy.deepcopy(job_config) job_config_pt["peft_method"] = "pt" - _, _, _, _, tune_config, _, _, _ = sft_trainer.parse_arguments( + _, _, _, _, tune_config, _, _, _, _, _ = sft_trainer.parse_arguments( parser, job_config_pt ) assert isinstance(tune_config, peft_config.PromptTuningConfig) job_config_lora = copy.deepcopy(job_config) job_config_lora["peft_method"] = "lora" - _, _, _, _, tune_config, _, _, _ = sft_trainer.parse_arguments( + _, _, _, _, tune_config, _, _, _, _, _ = sft_trainer.parse_arguments( parser, job_config_lora ) assert isinstance(tune_config, peft_config.LoraConfig) diff --git a/tuning/config/acceleration_configs/__init__.py b/tuning/config/acceleration_configs/__init__.py new file mode 100644 index 000000000..f971e2108 --- /dev/null +++ b/tuning/config/acceleration_configs/__init__.py @@ -0,0 +1,18 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Local +from .acceleration_framework_config import AccelerationFrameworkConfig +from .fused_ops_and_kernels import FusedOpsAndKernelsConfig +from .quantized_lora_config import QuantizedLoraConfig diff --git a/tuning/config/acceleration_configs/acceleration_framework_config.py b/tuning/config/acceleration_configs/acceleration_framework_config.py new file mode 100644 index 000000000..be8627e5f --- /dev/null +++ b/tuning/config/acceleration_configs/acceleration_framework_config.py @@ -0,0 +1,265 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from dataclasses import asdict, dataclass, fields, is_dataclass +from typing import Annotated, Dict, List, Type +import warnings + +# Third Party +import yaml + +# Local +from .fused_ops_and_kernels import FastKernelsConfig, FusedLoraConfig +from .quantized_lora_config import AutoGPTQLoraConfig, BNBQLoraConfig +from tuning.utils.import_utils import is_fms_accelerate_available + +if is_fms_accelerate_available(): + # Third Party + from fms_acceleration import AccelerationFramework # pylint: disable=import-error + from fms_acceleration.framework import KEY_PLUGINS # pylint: disable=import-error + +# these are optional annotations that describe different behavior +@dataclass +class ConfigAnnotation: + + # AccelerationFramework configuration path + path: str + + # if omitted, will take the field name + key: str = None + + # only one that has single=True may exist under its path + # - this is used to indicate conflicting configurations + # - we do not allow two configurations that load the model to be + # activated at the same time + standalone: bool = False + + # set to true to throw a user warning + experimental: bool = False + + # set to indicate what acceeleration packages are needed + required_packages: List[str] = None + + def __post_init__(self): + if self.required_packages is None: + self.required_packages = [] + + +@dataclass +class AccelerationFrameworkConfig: + "Dataclass that manages configuration of AccelerationFramework" + + PACKAGE_PREFIX = "fms_acceleration_" + + # each field will a single-level use case dataclass + auto_gptq: Annotated[ + AutoGPTQLoraConfig, + ConfigAnnotation( + path="peft.quantization", standalone=True, required_packages=["peft"] + ), + ] = None + + bitsandbytes: Annotated[ + BNBQLoraConfig, + ConfigAnnotation( + path="peft.quantization", standalone=True, required_packages=["peft"] + ), + ] = None + + fused_lora: Annotated[ + FusedLoraConfig, + ConfigAnnotation( + path="peft.quantization", + key="fused_ops_and_kernels", + experimental=True, + required_packages=["foak"], + ), + ] = None + + fast_kernels: Annotated[ + FastKernelsConfig, + ConfigAnnotation( + path="peft.quantization", + key="fused_ops_and_kernels", + experimental=True, + required_packages=["foak"], + ), + ] = None + + @staticmethod + def from_dataclasses(*dataclasses: Type): + "Convert one or many FMS config dataclasses to a monolithic AccelerationConfig" + + # Assumption: AccelerationFrameworkConfig 
only has fields that are + # single level dataclasses + # Assumption: dataclasses is a list of nested dataclasses + # - each dc in dataclasses is a nested dataclass. + # - each dc.field in dc is a non-nested dataclass. + + if len(dataclasses) == 0: + raise ValueError( + "AccelerationFrameworkConfig construction requires at least one dataclass." + ) + + # first unroll all the dataclases into a single level + nested_dataclasses = [] + for dc in dataclasses: + if dc is None: + continue + + # make sure that it every field is a dataclass + for fi in fields(dc): + attr = getattr(dc, fi.name) + if attr is None: + continue # skip the None attributes + + if not is_dataclass(attr): + raise ValueError( + f"field '{fi.name}' is specified but not a dataclass" + ) + + # NOTE: should we also check that these are non-nested + # dataclasses? + nested_dataclasses.append(attr) + + config = AccelerationFrameworkConfig() + rem_fields = {fi.name: fi for fi in fields(config)} # these need to be parsed + + # process the dataclasses that were nested + # by assumption these are non-nested dataclasses + for dc in nested_dataclasses: + + # check the fields that are yet to be populated + found_field = False + for fi in rem_fields.values(): + + # check if it is an AccelerationFrameworkConfig field + if isinstance(dc, fi.type.__origin__): + found_field = True + break + + if not found_field: + raise ValueError( + f"dataclass '{dc}' cannot be placed into AccelerationFrameworkConfig." + ) + + # assign the dataclass + setattr(config, fi.name, dc) + del rem_fields[fi.name] # remove the field + + return config + + def get_framework(self): + + if is_fms_accelerate_available(): + + # to be eventually be made to be passed as a dict to Acceleration + # Framework + # Standard + from tempfile import ( # pylint: disable=import-outside-toplevel + NamedTemporaryFile, + ) + + with NamedTemporaryFile("w") as f: + self.to_yaml(f.name) + return AccelerationFramework(f.name) + else: + if not self.is_empty(): + raise ValueError( + "No acceleration framework package found. To use, first " + "ensure that 'pip install -e.[fms-accel]' is done first to " + "obtain the acceleration framework dependency. Additional " + "acceleration plugins make be required depending on the requsted " + "acceleration. See README.md for instructions." + ) + + def is_empty(self): + "check if the configuration is empty" + for fi in fields(self): + if getattr(self, fi.name) is not None: + return False + return True + + def to_dict(self): + """convert a valid AccelerationFrameworkConfig dataclass into a schema-less dictionary + as dictated by the header annotations. + """ + + # populate a dictionary + configuration_contents = {} + + # helper function to populate + def _descend_and_set(path: List[str], d: Dict): + r = configuration_contents + for p in path[:-1]: + if p not in r: + r[p] = {} # new branch + r = r[p] + + p = path[-1] + r[p] = {**r.get(p, {}), **d} # merge dict if exists + + # parse each field + already_set = set() + for fi in fields(self): + datacls = getattr(self, fi.name) + if datacls is not None: + # this is the documented way to get annotations + # https://docs.python.org/3/library/typing.html#typing.Annotated + annotate: ConfigAnnotation + (annotate,) = fi.type.__metadata__ + prefix_path = tuple(annotate.path.split(".")) + if annotate.standalone and prefix_path in already_set: + raise ValueError( + f"Configuration path '{'.'.join(prefix_path)}' " + "already has one standalone config." 
+ ) + + if annotate.experimental: + warnings.warn( + "An experimental acceleration feature is requested by specifying the " + f"'--{fi.name}' argument. Please note this feature may not support certain " + "edge cases at this juncture. When the feature matures this " + "message will be turned off." + ) + + if not all( + is_fms_accelerate_available(x) for x in annotate.required_packages + ): + raise ValueError( + "An acceleration feature is requested by specifying the " + f"'--{fi.name}' argument, but the this requires acceleration packages " + "to be installed. Please do:\n" + + "\n".join( + [ + "- python -m fms_acceleration install " + f"{AccelerationFrameworkConfig.PACKAGE_PREFIX + x}" + for x in annotate.required_packages + ] + ) + ) + + key = annotate.key if annotate.key is not None else fi.name + path = prefix_path + (key,) + already_set.add(prefix_path) + _descend_and_set(path, asdict(datacls)) + + return configuration_contents + + def to_yaml(self, filename: str): + "convert a valid AccelerationConfig dataclass into a yaml" + configuration_contents = self.to_dict() + with open(filename, "w", encoding="utf-8") as f: + yaml.dump({KEY_PLUGINS: configuration_contents}, f) diff --git a/tuning/config/acceleration_configs/fused_ops_and_kernels.py b/tuning/config/acceleration_configs/fused_ops_and_kernels.py new file mode 100644 index 000000000..91df8c9dc --- /dev/null +++ b/tuning/config/acceleration_configs/fused_ops_and_kernels.py @@ -0,0 +1,106 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Standard +from dataclasses import dataclass +from typing import List + +# Local +from .utils import ( + EnsureTypes, + ensure_nested_dataclasses_initialized, + parsable_dataclass, +) + + +@parsable_dataclass +@dataclass +class FusedLoraConfig(List): + + # to help the HfArgumentParser arrive at correct types + __args__ = [EnsureTypes(str, bool)] + + # load unsloth optimizations for these 4bit base layer weights. 
+ # currently only support "auto_gptq" and "bitsandbytes" + base_layer: str = None + + # fused kernels for lora linear layers + fused_lora: bool = False + + def __post_init__(self): + + # reset for another parse + self.__args__[0].reset() + + if self.base_layer is not None and self.base_layer not in { + "auto_gptq", + "bitsandbytes", + }: + raise ValueError(f"base_layer set to invalid value '{self.base_layer}'") + + if self.base_layer is not None and not self.fused_lora: + raise ValueError( + f"base_layer set to '{self.base_layer}' so fused_lora must be set to True" + ) + + +@parsable_dataclass +@dataclass +class FastKernelsConfig(List): + + # to help the HfArgumentParser arrive at correct types + __args__ = [EnsureTypes(bool, bool, bool)] + + # fast loss triton kernels + fast_loss: bool = False + + # fast rms norm triton kernels + fast_rsm_layernorm: bool = False + + # fast RoPE embedding triton kernels + fast_rope_embeddings: bool = False + + def __post_init__(self): + + # reset for another parse + self.__args__[0].reset() + + if not self.fast_loss == self.fast_rsm_layernorm == self.fast_rope_embeddings: + raise ValueError( + "fast_loss, fast_rms_layernorm and fast_rope_embedding must be enabled " + "together. This restriction may be relaxed in the future." + ) + + +@dataclass +class FusedOpsAndKernelsConfig: + + # fused lora ops + fused_lora: FusedLoraConfig = None + + # fast kernels + fast_kernels: FastKernelsConfig = None + + def __post_init__(self): + if (self.fused_lora is not None and self.fast_kernels is None) or ( + self.fused_lora is None and self.fast_kernels is not None + ): + raise ValueError( + "fused lora and fast_kernels must be used together. " + "This restriction may be relaxed in the future." + ) + + # ensure nested dataclasses initialized + ensure_nested_dataclasses_initialized(self) diff --git a/tuning/config/acceleration_configs/quantized_lora_config.py b/tuning/config/acceleration_configs/quantized_lora_config.py new file mode 100644 index 000000000..d8174438c --- /dev/null +++ b/tuning/config/acceleration_configs/quantized_lora_config.py @@ -0,0 +1,82 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Standard +from dataclasses import dataclass +from typing import List + +# Local +from .utils import ( + EnsureTypes, + ensure_nested_dataclasses_initialized, + parsable_dataclass, +) + + +@parsable_dataclass +@dataclass +class AutoGPTQLoraConfig: + + # auto_gptq supports various kernels, to select the kernel to use. + kernel: str = "triton_v2" + + # allow auto_gptq to quantize a model before training commences. + # NOTE: currently this is not allowed. 
+ from_quantized: bool = True + + def __post_init__(self): + + if self.kernel != "triton_v2": + raise ValueError("only 'triton_v2' kernel currently supported.") + + if not self.from_quantized: + raise ValueError("only 'from_quantized' == True currently supported.") + + +@parsable_dataclass +@dataclass +class BNBQLoraConfig(List): + + # to help the HfArgumentParser arrive at correct types + __args__ = [EnsureTypes(str, bool)] + + # type of quantization applied + quant_type: str = "nf4" + + # if we only want to quantize the base layer, and defer to the + # huggingface to prepare the peft (i.e. lora) model + no_peft_model: bool = False + + def __post_init__(self): + + # reset for another parse + self.__args__[0].reset() + + if self.quant_type not in ["nf4", "fp4"]: + raise ValueError("quant_type can only be either 'nf4' or 'fp4.") + + +@dataclass +class QuantizedLoraConfig: + + # to use auto_gptq 4bit lora base layers + auto_gptq: AutoGPTQLoraConfig = None + + # to use auto_gptq 4bit lora base layers + bnb_qlora: BNBQLoraConfig = None + + def __post_init__(self): + # ensure nested dataclasses initialized + ensure_nested_dataclasses_initialized(self) diff --git a/tuning/config/acceleration_configs/utils.py b/tuning/config/acceleration_configs/utils.py new file mode 100644 index 000000000..3085a9761 --- /dev/null +++ b/tuning/config/acceleration_configs/utils.py @@ -0,0 +1,88 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from dataclasses import fields, is_dataclass +from typing import Dict, List, Type, get_type_hints + +# Third Party +from transformers.hf_argparser import DataClass, string_to_bool + + +def ensure_nested_dataclasses_initialized(dataclass: DataClass): + """HfArgumentParser will think of the dataclass as a List with + multiple inputs, but it will not call the constructor, so + this is to be called at the top-level class to init all the + nested dataclasses. + """ + type_hints: Dict[str, type] = get_type_hints(dataclass) + for f in fields(dataclass): + nested_type = type_hints[f.name] + values = getattr(dataclass, f.name) + if values is not None and not is_dataclass(values): + values = nested_type(*values) + setattr(dataclass, f.name, values) + + +class EnsureTypes: + """EnsureTypes is a caster with an internal state to memorize the + the casting order, so that we can apply the correct casting type. 
+ + e.g., EnsureTypes(int, str) will cast [x1, x2] as [int(x1), str(x2)] + """ + + def __init__(self, *types: Type): + _map = {bool: string_to_bool} + self.types = [_map.get(t, t) for t in types] + self.reset() + + def reset(self): + self.cnt = 0 + + def __call__(self, val): + if self.cnt >= len(self.types): + raise ValueError("EnsureTypes require 'reset' to be called to be re-used.") + + t = self.types[self.cnt] + self.cnt += 1 + return t(val) + + +def parsable_dataclass(cls): + """dataset decorator to masquarade as a list type, so that + HfArgumentParser will take in multiple arguments after the + --key arg1 arg2, ..., + + * when we override __args__, we can ensure the parseds + - arg1 arg2 .. will get casted to the correct type + + """ + + if not is_dataclass(cls): + raise ValueError("parsable only works with dataclass") + + types = [fi.type for fi in fields(cls)] + + class ParsableDataclass(cls, List): + + # to help the HfArgumentParser arrive at correct types + __args__ = [EnsureTypes(*types)] + + def __post_init__(self): + # reset for another parse + ParsableDataclass.__args__[0].reset() + + super().__post_init__() + + return ParsableDataclass diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 2ab1ece9d..de616fd22 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -32,7 +32,7 @@ LlamaTokenizerFast, TrainerCallback, ) -from transformers.utils import logging +from transformers.utils import is_accelerate_available, logging from trl import DataCollatorForCompletionOnlyLM, SFTTrainer import datasets import fire @@ -40,6 +40,11 @@ # Local from tuning.config import configs, peft_config +from tuning.config.acceleration_configs import ( + AccelerationFrameworkConfig, + FusedOpsAndKernelsConfig, + QuantizedLoraConfig, +) from tuning.config.tracker_configs import ( AimConfig, FileLoggingTrackerConfig, @@ -71,6 +76,8 @@ def train( ), additional_callbacks: Optional[List[TrainerCallback]] = None, exp_metadata: Optional[Dict] = None, + quantized_lora_config: Optional[QuantizedLoraConfig] = None, + fusedops_kernels_config: Optional[FusedOpsAndKernelsConfig] = None, ): """Call the SFTTrainer @@ -93,6 +100,11 @@ def train( or TrainerControllers. Callbacks associated with \ tracker with automatically be added. exp_metadata: Dict of key value pairs passed to train to be recoreded by the tracker. + quantized_lora_config: tuning.config.acceleration_configs.QuantizedLoraConfig \ + Should be used in combination with peft_config.LoraConfig for Lora tuning \ + fusedops_kernels_config: tuning.config.acceleration_configs.FusedOpsAndKernelsConfig \ + Should be used in combination with quantized_lora_config. Also currently + fused_lora and fast_kernels must used together (may change in future). 
\ """ logger = logging.get_logger("sft_trainer") @@ -140,8 +152,15 @@ def train( if additional_callbacks is not None: trainer_callbacks.append(additional_callbacks) + framework = AccelerationFrameworkConfig.from_dataclasses( + quantized_lora_config, fusedops_kernels_config + ).get_framework() + + model_loader = AutoModelForCausalLM.from_pretrained + if framework is not None and framework.requires_custom_loading: + model_loader = framework.model_loader # drop-in new loader model_load_time = time.time() - model = AutoModelForCausalLM.from_pretrained( + model = model_loader( model_args.model_name_or_path, cache_dir=train_args.cache_dir, torch_dtype=get_torch_dtype(model_args.torch_dtype), @@ -291,6 +310,11 @@ def train( "Validation dataset length is %s", len(formatted_validation_dataset) ) + if framework is not None and framework.requires_agumentation: + model, (peft_config,) = framework.augmentation( + model, train_args, modifiable_args=(peft_config,) + ) + trainer = SFTTrainer( model=model, tokenizer=tokenizer, @@ -324,6 +348,14 @@ def train( trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy( model ) + + if framework is not None: + accelerator = None if not is_accelerate_available else trainer.accelerator + + # ready for train may produce additional callbacks for the trainer + for x in framework.get_callbacks_and_ready_for_train(model, accelerator): + trainer.add_callback(x) + trainer.train() @@ -339,6 +371,8 @@ def get_parser(): peft_config.PromptTuningConfig, FileLoggingTrackerConfig, AimConfig, + QuantizedLoraConfig, + FusedOpsAndKernelsConfig, ) ) parser.add_argument( @@ -381,6 +415,10 @@ def parse_arguments(parser, json_config=None): Configuration for training log file. AimConfig Configuration for AIM stack. + QuantizedLoraConfig + Configuration for quantized LoRA (a form of PEFT). + FusedOpsAndKernelsConfig + Configuration for fused operations and kernels. dict[str, str] Extra AIM metadata. 
""" @@ -394,6 +432,8 @@ def parse_arguments(parser, json_config=None): prompt_tuning_config, file_logger_config, aim_config, + quantized_lora_config, + fusedops_kernels_config, ) = parser.parse_dict(json_config, allow_extra_keys=True) peft_method = json_config.get("peft_method") exp_metadata = json_config.get("exp_metadata") @@ -407,6 +447,8 @@ def parse_arguments(parser, json_config=None): prompt_tuning_config, file_logger_config, aim_config, + quantized_lora_config, + fusedops_kernels_config, additional, _, ) = parser.parse_args_into_dataclasses(return_remaining_strings=True) @@ -429,6 +471,8 @@ def parse_arguments(parser, json_config=None): tune_config, file_logger_config, aim_config, + quantized_lora_config, + fusedops_kernels_config, exp_metadata, ) @@ -449,12 +493,16 @@ def main(**kwargs): # pylint: disable=unused-argument tune_config, file_logger_config, aim_config, + quantized_lora_config, + fusedops_kernels_config, exp_metadata, ) = parse_arguments(parser, job_config) logger.debug( "Input args parsed: \ model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \ - tune_config %s, file_logger_config, %s aim_config %s, exp_metadata %s", + tune_config %s, file_logger_config, %s aim_config %s, \ + quantized_lora_config %s, fusedops_kernels_config %s, \ + exp_metadata %s", model_args, data_args, training_args, @@ -462,6 +510,8 @@ def main(**kwargs): # pylint: disable=unused-argument tune_config, file_logger_config, aim_config, + quantized_lora_config, + fusedops_kernels_config, exp_metadata, ) except Exception as e: # pylint: disable=broad-except @@ -501,6 +551,8 @@ def main(**kwargs): # pylint: disable=unused-argument tracker_configs=combined_tracker_configs, additional_callbacks=None, exp_metadata=metadata, + quantized_lora_config=quantized_lora_config, + fusedops_kernels_config=fusedops_kernels_config, ) except (MemoryError, OutOfMemoryError) as e: logger.error(traceback.format_exc()) diff --git a/tuning/utils/import_utils.py b/tuning/utils/import_utils.py new file mode 100644 index 000000000..36dd606c6 --- /dev/null +++ b/tuning/utils/import_utils.py @@ -0,0 +1,34 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +from typing import List, Union + +# Third Party +from transformers.utils.import_utils import _is_package_available + + +def is_fms_accelerate_available( + plugins: Union[str, List[str]] = None, package_name: str = "fms_acceleration" +): + names = [package_name] + if plugins is not None: + if isinstance(plugins, str): + plugins = [plugins] + names.extend([package_name + "_" + x for x in plugins]) + + for n in names: + if not _is_package_available(n): + return False + return True