Install Acceleration Framework into Training Script #157

Merged · 24 commits · Jun 20, 2024
Changes from all commits
69 changes: 69 additions & 0 deletions README.md
@@ -23,6 +23,12 @@ If you wish to use [aim](https://github.com/aimhubio/aim), then you need to inst
pip install -e ".[aim]"
```

If you wish to use [fms-acceleration](https://github.com/foundation-model-stack/fms-acceleration), you need to install it.
```
pip install -e ".[fms-accel]"
```
`fms-acceleration` is a collection of plugins that accelerate fine-tuning / training of large models as part of the `fms-hf-tuning` suite. For more details, see [this section below](#fms-acceleration).

## Data format
We support two data formats:

@@ -377,6 +383,69 @@ Equally you can pass in a JSON configuration for running tuning. See [build doc]
}
```

### FMS Acceleration

`fms-acceleration` is a fuss-free approach to accessing a curated collection of acceleration plugins that accelerate your `tuning/sft_trainer.py` experience. Accelerations that apply to a variety of use-cases (e.g., PEFT / full fine-tuning) are being planned. The accelerations are grouped into *plugins*; only install the plugins needed for the acceleration of interest. The plugins are housed in a [separate repository found here](https://github.com/foundation-model-stack/fms-acceleration).

To access `fms-acceleration` features, the `[fms-accel]` dependency must first be installed:
```
$ pip install -e .[fms-accel]
```

Furthermore, the required `fms-acceleration` plugin must be installed. This is done via the command line utility `fms_acceleration.cli`. To show available plugins:
```
$ python -m fms_acceleration.cli plugins
```
and to install the `fms_acceleration_peft` plugin:

```
$ python -m fms_acceleration.cli install fms_acceleration_peft
```

If you do not know (or forget) which plugin to install, the framework will remind you:

```
An acceleration feature is requested by specifying the '--auto_gptq' argument, but the this requires acceleration packages to be installed. Please do:
- python -m fms_acceleration.cli install fms_acceleration_peft
```

The list of configurations for various `fms_acceleration` plugins:
- [quantized_lora_config](./tuning/config/acceleration_configs/quantized_lora_config.py): For quantized 4bit LoRA training
- `--auto_gptq`: 4bit GPTQ-LoRA with AutoGPTQ
- `--bnb_qlora`: 4bit QLoRA with bitsandbytes
- [fused_ops_and_kernels](./tuning/config/acceleration_configs/fused_ops_and_kernels.py) (experimental):
- `--fused_lora`: fused LoRA for more efficient LoRA training.
- `--fast_kernels`: fast cross-entropy, rope, rms loss kernels.

Notes:
* `quantized_lora_config` must be used together with the LoRA tuning technique. See the [LoRA tuning section](https://github.com/foundation-model-stack/fms-hf-tuning/tree/main?tab=readme-ov-file#lora-tuning-example) for the LoRA parameters to pass.
* When setting `--auto_gptq triton_v2`, also pass `--torch_dtype float16` and `--fp16`, or an exception will be raised; these kernels only support this dtype.
* Currently, `fused_ops_and_kernels` is to be used together with QLoRA or GPTQ-LoRA via the `quantized_lora_config`. In the future it may be made more flexible, such that `fast_kernels` can even be used with full fine-tuning.
* When using `fused_ops_and_kernels` together with `quantized_lora_config`, make sure to set `--fused_lora auto_gptq True` or `--fused_lora bitsandbytes True` as appropriate; the `True` sets `fast_lora==True`.
* Currently `fused_ops_and_kernels` only supports activating all of `fast_loss`, `fast_rsm_layernorm` and `fast_rope_embeddings` to `True`, so pass `--fast_kernels True True True`.
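
As an illustration only, a GPTQ-LoRA run combining the flags above might look roughly like the following sketch; the entrypoint, model path, and remaining LoRA / training arguments are placeholders, so consult the LoRA tuning section linked above for the authoritative parameters:

```
# illustrative sketch only; adapt the model path and LoRA / training arguments to your setup
python tuning/sft_trainer.py \
  --model_name_or_path <path-to-gptq-quantized-model> \
  --peft_method lora \
  --auto_gptq triton_v2 \
  --torch_dtype float16 \
  --fp16 \
  --fused_lora auto_gptq True \
  --fast_kernels True True True \
  <remaining data and training arguments>
```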


Set `TRANSFORMERS_VERBOSITY=info` to see the Hugging Face trainer printouts and verify that `AccelerationFramework` is activated!
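
For example, assuming the same illustrative entrypoint as above, the variable can be set inline for a single run:

```
TRANSFORMERS_VERBOSITY=info python tuning/sft_trainer.py <your training arguments>
```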

```
# this printout will be seen in huggingface trainer logs if acceleration is activated
***** FMS AccelerationFramework *****
Active Plugin: AutoGPTQAccelerationPlugin. Python package: fms_acceleration_peft. Version: 0.0.1.
***** Running training *****
Num examples = 1,549
Num Epochs = 1
Instantaneous batch size per device = 4
Total train batch size (w. parallel, distributed & accumulation) = 4
Gradient Accumulation steps = 1
Total optimization steps = 200
Number of trainable parameters = 13,631,488
```

The `fms_acceleration.cli` utility can do more, such as searching for all available configs, plugins and arguments; [see the advanced flow](https://github.com/foundation-model-stack/fms-acceleration#advanced-flow).



## Inference
Currently, we do *not* offer inference support as part of the library, but we provide a standalone script for running inference on tuned models for testing purposes. For a full list of options run `python scripts/run_inference.py --help`. Note that no data formatting / templating is applied at inference time.

3 changes: 3 additions & 0 deletions pyproject.toml
@@ -44,6 +44,9 @@ dependencies = [
dev = ["wheel", "packaging", "ninja", "scikit-learn>=1.0, <2.0", "boto3"]
flash-attn = ["flash-attn"]
aim = ["aim==3.19.0"]
fms-accel = [
"fms_acceleration @ git+https://github.com/foundation-model-stack/fms-acceleration.git#subdirectory=plugins/framework"
]

[tool.setuptools.packages.find]
exclude = ["tests", "tests.*"]
13 changes: 13 additions & 0 deletions tests/acceleration/__init__.py
@@ -0,0 +1,13 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
47 changes: 47 additions & 0 deletions tests/acceleration/spying_utils.py
@@ -0,0 +1,47 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def create_mock_plugin_class_and_spy(class_name, plugin_cls):
"helper function to create plugin class"

spy = {
"model_loader_calls": 0,
"augmentation_calls": 0,
"get_ready_for_train_calls": 0,
}

def model_loader(self, *args, **kwargs):
spy["model_loader_calls"] += 1
return plugin_cls.model_loader(self, *args, **kwargs)

def augmentation(
self,
*args,
**kwargs,
):
spy["augmentation_calls"] += 1
return plugin_cls.augmentation(self, *args, **kwargs)

def get_callbacks_and_ready_for_train(self, *args, **kwargs):
spy["get_ready_for_train_calls"] += 1
        return plugin_cls.get_callbacks_and_ready_for_train(self, *args, **kwargs)

attributes = {
"model_loader": model_loader,
"augmentation": augmentation,
"get_callbacks_and_ready_for_train": get_callbacks_and_ready_for_train,
}

return type(class_name, (plugin_cls,), attributes), spy
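
A minimal usage sketch follows; the `DummyPlugin` class below is a hypothetical stand-in for a real acceleration plugin and is not part of this diff:

```
# Hypothetical usage of create_mock_plugin_class_and_spy; assumes the repository
# root is on sys.path so that the test package is importable.
from tests.acceleration.spying_utils import create_mock_plugin_class_and_spy


class DummyPlugin:
    """Stand-in for a real acceleration plugin class (illustration only)."""

    def model_loader(self, *args, **kwargs):
        return "model"

    def augmentation(self, *args, **kwargs):
        return args, kwargs

    def get_callbacks_and_ready_for_train(self, *args, **kwargs):
        return []


# Build a subclass of DummyPlugin whose hook methods are counted in `spy`.
MockPlugin, spy = create_mock_plugin_class_and_spy("MockPlugin", DummyPlugin)

plugin = MockPlugin()
plugin.model_loader("some-model-path")

# Each wrapper increments its counter before delegating to the wrapped class.
assert spy["model_loader_calls"] == 1
assert spy["augmentation_calls"] == 0
assert spy["get_ready_for_train_calls"] == 0
```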
135 changes: 135 additions & 0 deletions tests/acceleration/test_acceleration_dataclasses.py
@@ -0,0 +1,135 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard

# Third Party
import pytest
import transformers

# Local
from tuning.config.acceleration_configs import (
FusedOpsAndKernelsConfig,
QuantizedLoraConfig,
)
from tuning.config.acceleration_configs.fused_ops_and_kernels import (
FastKernelsConfig,
FusedLoraConfig,
)
from tuning.config.acceleration_configs.quantized_lora_config import (
AutoGPTQLoraConfig,
BNBQLoraConfig,
)


def test_dataclass_parse_successfully():
parser = transformers.HfArgumentParser(dataclass_types=QuantizedLoraConfig)

# if nothing is specified then it will parse into the null class
(cfg, _) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
assert cfg.auto_gptq is None
assert cfg.bnb_qlora is None

# 1.1 specifying "--auto_gptq" with the first item of AutoGPTQLoraConfig
# will parse
(cfg,) = parser.parse_args_into_dataclasses(
["--auto_gptq", "triton_v2"],
)
assert isinstance(cfg.auto_gptq, AutoGPTQLoraConfig)
assert cfg.bnb_qlora is None

# 1.2 specifying "--auto_gptq" with the two items of AutoGPTQLoraConfig
# will parse
(cfg,) = parser.parse_args_into_dataclasses(
["--auto_gptq", "triton_v2", "true"],
)
assert isinstance(cfg.auto_gptq, AutoGPTQLoraConfig)
assert cfg.bnb_qlora is None

# 2. specifying "--bnb_qlora" with the first item of BNBQLoraConfig
# will parse
(cfg,) = parser.parse_args_into_dataclasses(
["--bnb_qlora", "nf4"],
)
assert cfg.auto_gptq is None
assert isinstance(cfg.bnb_qlora, BNBQLoraConfig)


def test_two_dataclasses_parse_successfully_together():
"""Ensure that the two dataclasses can parse arguments successfully
together.
"""
parser = transformers.HfArgumentParser(
dataclass_types=(QuantizedLoraConfig, FusedOpsAndKernelsConfig)
)

# 1. specifying "--auto_gptq" together with "--fused_lora" and
# "--fast_kernels" will parse.
cfg, cfg2 = parser.parse_args_into_dataclasses(
[
"--auto_gptq",
"triton_v2",
"--fused_lora",
"auto_gptq",
"true",
"--fast_kernels",
"true",
"true",
"true",
],
)
assert isinstance(cfg.auto_gptq, AutoGPTQLoraConfig)
assert cfg.bnb_qlora is None
assert isinstance(cfg2.fused_lora, FusedLoraConfig)
assert isinstance(cfg2.fast_kernels, FastKernelsConfig)


def test_dataclass_will_fail_to_parse_with_no_args():
"""Ensure that the dataclass arg parser will refuse to parse if
only the key is specified without any following arguments.
"""
parser = transformers.HfArgumentParser(dataclass_types=QuantizedLoraConfig)

# 1. passing only the key without any body will fail
# - at least the first argument of the dataclass will be expected.
with pytest.raises(
SystemExit, # argparse will exit
):
(_,) = parser.parse_args_into_dataclasses(
["--auto_gptq"],
)


def test_dataclass_will_fail_to_accept_illegal_args():
"""Ensure that some basic rules that are put in the dataclasses will
fail at initialization of the class.
"""

# 1. auto_gptq does not support from_quantized at the moment.
with pytest.raises(
ValueError, match="only 'from_quantized' == True currently supported."
):
AutoGPTQLoraConfig(from_quantized=False)

# 1.1 auto_gptq only supports triton_v2 at the moment
with pytest.raises(
ValueError, match="only 'triton_v2' kernel currently supported."
):
AutoGPTQLoraConfig(kernel="fake-kernel")

# 2 bnb only supports two quant types
with pytest.raises(
ValueError, match="quant_type can only be either 'nf4' or 'fp4."
):
BNBQLoraConfig(quant_type="fake-quant-type")