Skip to content

Commit

Permalink
Install Acceleration Framework into Training Script (#157)
Browse files Browse the repository at this point in the history
* add acceleration framework

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* framework can add callbacks

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* add basic acceleration framework unit tests, lint.

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* add README, plugin installation tool

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* updates to readme

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* more readme updates

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* update fms-accel dep

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* add  more tests

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fixes after rebase + linting

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* make acceleration framework tests a module and lint,fmt.

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* clarify the usages flows

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* replace yaml with dataclass args

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fmt + lint

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* improve tests

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* test fixes

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* improve data parsing logic

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* add foak test

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fix bug and add bnb test

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* add missing peft config test

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* update README as per @Ssukriti's suggestions.

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* remove test helpers

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fix merge errors and other issues

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* add one more check in get_framework and other fixes.

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

* fix tests

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
  • Loading branch information
fabianlim authored Jun 20, 2024
1 parent fc87c74 commit fc8938d
Show file tree
Hide file tree
Showing 14 changed files with 1,368 additions and 6 deletions.
69 changes: 69 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ If you wish to use [aim](https://github.com/aimhubio/aim), then you need to inst
pip install -e ".[aim]"
```

If you wish to use [fms-acceleration](https://github.com/foundation-model-stack/fms-acceleration), you need to install it.
```
pip install -e ".[fms-accel]"
```
`fms-acceleration` is a collection of plugins that packages that accelerate fine-tuning / training of large models, as part of the `fms-hf-tuning` suite. For more details on see [this section below](#fms-acceleration).

## Data format
We support two data formats:

Expand Down Expand Up @@ -377,6 +383,69 @@ Equally you can pass in a JSON configuration for running tuning. See [build doc]
}
```

### FMS Acceleration

`fms-acceleration` is fuss-free approach to access a curated collection of acceleration plugins that acclerate your `tuning/sft-trainer.py` experience. Accelerations that apply to a variety of use-cases, e.g., PeFT / full-finetuning, are being planned for. As such, the accelerations are grouped into *plugins*; only install the plugins needed for the acceleration of interest. The plugins are housed in the [seperate repository found here](https://github.com/foundation-model-stack/fms-acceleration).

To access `fms-acceleration` features the `[fms-accel]` dependency must first be installed:
```
$ pip install -e .[fms-accel]
```

Furthermore, the required `fms-acceleration` plugin must be installed. This is done via the command line utility `fms_acceleration.cli`. To show available plugins:
```
$ python -m fms_acceleration.cli plugins
```
as well as to install the `fms_acceleration_peft`:

```
$ python -m fms_acceleration.cli install fms_acceleration_peft
```

If you do not know what plugin to install (or forget), the framework will remind

```
An acceleration feature is requested by specifying the '--auto_gptq' argument, but the this requires acceleration packages to be installed. Please do:
- python -m fms_acceleration.cli install fms_acceleration_peft
```

The list of configurations for various `fms_acceleration` plugins:
- [quantized_lora_config](./tuning/config/acceleration_configs/quantized_lora_config.py): For quantized 4bit LoRA training
- `--auto_gptq`: 4bit GPTQ-LoRA with AutoGPTQ
- `--bnb_qlora`: 4bit QLoRA with bitsandbytes
- [fused_ops_and_kernels](./tuning/config/acceleration_configs/fused_ops_and_kernels.py) (experimental):
- `--fused_lora`: fused lora for more efficient LoRA training.
- `--fast_kernels`: fast cross-entropy, rope, rms loss kernels.

Notes:
* `quantized_lora_config` requires that it be used along with LoRA tuning technique. See [LoRA tuning section](https://github.com/foundation-model-stack/fms-hf-tuning/tree/main?tab=readme-ov-file#lora-tuning-example) on the LoRA parameters to pass.
* When setting `--auto_gptq triton_v2` plus note to also pass `--torch_dtype float16` and `--fp16`, or an exception will be raised. This is because these kernels only support this dtype.
* Currently, the `fused_ops_and_kernels` is to be used used together QLoRA or GPTQ-LORA via the `quantized_lora_config`. In the future it may be made more flexible such that `fast_kernels` can even be used with full-finetuning.
* When using `fused_ops_and_kernels` together with `quantized_lora_config`,
make sure to appropriately set `--fused_lora auto_gptq True` or `bitsandbytes True`; the `True` sets `fast_lora==True`.
* Currently `fused_ops_and_kernels` only supports activating `fast_loss,fast_rsm_layernorm,fast_rope_embeddings` all to `True`, so pass `--fast_kernels True True True`.


Activate `TRANSFORMERS_VERBOSITY=info` to see the huggingface trainer printouts and verify that `AccelerationFramework` is activated!

```
# this printout will be seen in huggingface trainer logs if acceleration is activated
***** FMS AccelerationFramework *****
Active Plugin: AutoGPTQAccelerationPlugin. Python package: fms_acceleration_peft. Version: 0.0.1.
***** Running training *****
Num examples = 1,549
Num Epochs = 1
Instantaneous batch size per device = 4
Total train batch size (w. parallel, distributed & accumulation) = 4
Gradient Accumulation steps = 1
Total optimization steps = 200
Number of trainable parameters = 13,631,488
```

The `fms_acceleration.cli` can do more to search for all available configs, plugins and arguments, [see the advanced flow](https://github.com/foundation-model-stack/fms-acceleration#advanced-flow).



## Inference
Currently, we do *not* offer inference support as part of the library, but we provide a standalone script for running inference on tuned models for testing purposes. For a full list of options run `python scripts/run_inference.py --help`. Note that no data formatting / templating is applied at inference time.

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ dependencies = [
dev = ["wheel", "packaging", "ninja", "scikit-learn>=1.0, <2.0", "boto3"]
flash-attn = ["flash-attn"]
aim = ["aim==3.19.0"]
fms-accel = [
"fms_acceleration @ git+https://github.com/foundation-model-stack/fms-acceleration.git#subdirectory=plugins/framework"
]

[tool.setuptools.packages.find]
exclude = ["tests", "tests.*"]
Expand Down
13 changes: 13 additions & 0 deletions tests/acceleration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
47 changes: 47 additions & 0 deletions tests/acceleration/spying_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def create_mock_plugin_class_and_spy(class_name, plugin_cls):
"helper function to create plugin class"

spy = {
"model_loader_calls": 0,
"augmentation_calls": 0,
"get_ready_for_train_calls": 0,
}

def model_loader(self, *args, **kwargs):
spy["model_loader_calls"] += 1
return plugin_cls.model_loader(self, *args, **kwargs)

def augmentation(
self,
*args,
**kwargs,
):
spy["augmentation_calls"] += 1
return plugin_cls.augmentation(self, *args, **kwargs)

def get_callbacks_and_ready_for_train(self, *args, **kwargs):
spy["get_ready_for_train_calls"] += 1
return plugin_cls.get_callbacks_and_ready_for_train(self, args, **kwargs)

attributes = {
"model_loader": model_loader,
"augmentation": augmentation,
"get_callbacks_and_ready_for_train": get_callbacks_and_ready_for_train,
}

return type(class_name, (plugin_cls,), attributes), spy
135 changes: 135 additions & 0 deletions tests/acceleration/test_acceleration_dataclasses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard

# Third Party
import pytest
import transformers

# Local
from tuning.config.acceleration_configs import (
FusedOpsAndKernelsConfig,
QuantizedLoraConfig,
)
from tuning.config.acceleration_configs.fused_ops_and_kernels import (
FastKernelsConfig,
FusedLoraConfig,
)
from tuning.config.acceleration_configs.quantized_lora_config import (
AutoGPTQLoraConfig,
BNBQLoraConfig,
)


def test_dataclass_parse_successfully():
parser = transformers.HfArgumentParser(dataclass_types=QuantizedLoraConfig)

# if nothing is specified then it will parse into the null class
(cfg, _) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
assert cfg.auto_gptq is None
assert cfg.bnb_qlora is None

# 1.1 specifying "--auto_gptq" with the first item of AutoGPTQLoraConfig
# will parse
(cfg,) = parser.parse_args_into_dataclasses(
["--auto_gptq", "triton_v2"],
)
assert isinstance(cfg.auto_gptq, AutoGPTQLoraConfig)
assert cfg.bnb_qlora is None

# 1.2 specifying "--auto_gptq" with the two items of AutoGPTQLoraConfig
# will parse
(cfg,) = parser.parse_args_into_dataclasses(
["--auto_gptq", "triton_v2", "true"],
)
assert isinstance(cfg.auto_gptq, AutoGPTQLoraConfig)
assert cfg.bnb_qlora is None

# 2. specifying "--bnb_qlora" with the first item of BNBQLoraConfig
# will parse
(cfg,) = parser.parse_args_into_dataclasses(
["--bnb_qlora", "nf4"],
)
assert cfg.auto_gptq is None
assert isinstance(cfg.bnb_qlora, BNBQLoraConfig)


def test_two_dataclasses_parse_successfully_together():
"""Ensure that the two dataclasses can parse arguments successfully
together.
"""
parser = transformers.HfArgumentParser(
dataclass_types=(QuantizedLoraConfig, FusedOpsAndKernelsConfig)
)

# 1. specifying "--auto_gptq" together with "--fused_lora" and
# "--fast_kernels" will parse.
cfg, cfg2 = parser.parse_args_into_dataclasses(
[
"--auto_gptq",
"triton_v2",
"--fused_lora",
"auto_gptq",
"true",
"--fast_kernels",
"true",
"true",
"true",
],
)
assert isinstance(cfg.auto_gptq, AutoGPTQLoraConfig)
assert cfg.bnb_qlora is None
assert isinstance(cfg2.fused_lora, FusedLoraConfig)
assert isinstance(cfg2.fast_kernels, FastKernelsConfig)


def test_dataclass_will_fail_to_parse_with_no_args():
"""Ensure that the dataclass arg parser will refuse to parse if
only the key is specified without any following arguments.
"""
parser = transformers.HfArgumentParser(dataclass_types=QuantizedLoraConfig)

# 1. passing only the key without any body will fail
# - at least the first argument of the dataclass will be expected.
with pytest.raises(
SystemExit, # argparse will exit
):
(_,) = parser.parse_args_into_dataclasses(
["--auto_gptq"],
)


def test_dataclass_will_fail_to_accept_illegal_args():
"""Ensure that some basic rules that are put in the dataclasses will
fail at initialization of the class.
"""

# 1. auto_gptq does not support from_quantized at the moment.
with pytest.raises(
ValueError, match="only 'from_quantized' == True currently supported."
):
AutoGPTQLoraConfig(from_quantized=False)

# 1.1 auto_gptq only supports triton_v2 at the moment
with pytest.raises(
ValueError, match="only 'triton_v2' kernel currently supported."
):
AutoGPTQLoraConfig(kernel="fake-kernel")

# 2 bnb only supports two quant types
with pytest.raises(
ValueError, match="quant_type can only be either 'nf4' or 'fp4."
):
BNBQLoraConfig(quant_type="fake-quant-type")
Loading

0 comments on commit fc8938d

Please sign in to comment.