
Commit

fmt + lint
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed Jun 13, 2024
1 parent 9580b99 commit 891cef8
Showing 8 changed files with 180 additions and 170 deletions.
44 changes: 26 additions & 18 deletions README.md
@@ -336,29 +336,37 @@ tuning/sft_trainer.py \

`fms-acceleration` is a fuss-free way to access a curated collection of acceleration plugins that accelerate your `tuning/sft_trainer.py` experience. Accelerations that apply to a variety of use cases (e.g., PEFT / full-finetuning) are planned. The accelerations are grouped into *plugins*; install only the plugins needed for the acceleration of interest. The plugins are housed in a [separate repository found here](https://github.com/foundation-model-stack/fms-acceleration).

Basic usage includes these steps:
To access `fms-acceleration` features, the `[fms-accel]` dependency must first be installed:
```
$ pip install -e .[fms-accel]
```

1. Install the `[fms-accel]` dependency:
```
$ pip install -e .[fms-accel]
```
Furthermore, the required `fms-acceleration` plugin must be installed. This is done via the command line utility `fms_acceleration.cli`. To show available plugins:
```
$ python -m fms_acceleration.cli plugins
```
and to install the `fms_acceleration_peft` plugin:

This installs the command line utility `fms_acceleration.cli`, which is used to install plugins.
3. `install` the required framework plugins; here we install the `fms-acceleration-peft` plugin for GPTQ-LoRA tuning with triton v2:
```
python -m fms_acceleration.cli install fms_acceleration_peft
```
```
$ python -m fms_acceleration.cli install fms_acceleration_peft
```
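
After installation, a quick sanity check can be run from Python using the availability helper that the tests in this commit rely on. This is only an illustrative sketch; `is_fms_accelerate_available` and its `plugins="peft"` argument are taken from the test code shown further below:
```python
# Sanity check (illustrative): verify the acceleration framework and the
# accelerated-peft plugin can be imported by the tuning package.
from tuning.utils.import_utils import is_fms_accelerate_available

print(is_fms_accelerate_available())                # base fms-acceleration package present?
print(is_fms_accelerate_available(plugins="peft"))  # accelerated-peft plugin present?
```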

5. Run `sft_trainer.py`, providing the acceleration configuration and arguments; the basic flow assumes we simply re-use the same `sft_trainer.py` arguments as we would without the `fms_acceleration` package:
```
python sft_trainer.py \
--acceleration_framework_config_file framework_config.yaml \
... # arguments
```
If you do not know (or forget) which plugin to install, the framework will remind you:

See [this sample configuration for GPTQ-LoRA with triton v2](./fixtures/accelerated-peft-autogptq-sample-configuration.yaml) for an example to pass to `--acceleration_framework_config_file` above.
```
An acceleration feature is requested by specifying the '--auto_gptq' argument, but this requires acceleration packages to be installed. Please do:
- python -m fms_acceleration install fms_acceleration_peft
```

The list of configurations for the various `fms_acceleration` plugins (a short usage sketch follows this list):
- [quantized_lora_config](./tuning/config/acceleration_configs/quantized_lora_config.py): For quantized 4bit LoRA training
- `--auto_gptq`: 4bit GPTQ_LoRA with AutoGPTQ
- `--bnb_qlora`: 4bit QLoRA with bitsandbytes
- [fused_ops_and_kernels](./tuning/config/acceleration_configs/fused_ops_and_kernels.py) (experimental):
- `--fused_lora`: fused LoRA for more efficient LoRA training.
- `--fast_kernels`: fast cross-entropy, rope, rms loss kernels.
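
To make the mapping concrete, here is a minimal sketch, mirroring the accompanying tests in this commit, of how these dataclasses compose into an acceleration framework; building the framework object itself assumes the `fms_acceleration_peft` plugin is installed:
```python
# Illustrative sketch (mirroring tests/acceleration/test_acceleration_framework.py):
# compose the dataclass configs into an AccelerationFrameworkConfig.
from tuning.config.acceleration_configs import (
    AccelerationFrameworkConfig,
    QuantizedLoraConfig,
)
from tuning.config.acceleration_configs.quantized_lora_config import AutoGPTQLoraConfig

# the `--auto_gptq` flag corresponds to populating this dataclass
quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig())

# collect the dataclasses into a single framework configuration ...
acceleration_config = AccelerationFrameworkConfig.from_dataclasses(quantized_lora_config)

# ... and build the AccelerationFramework from it (requires fms_acceleration_peft)
framework = acceleration_config.get_framework()
```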

That's it! Activate `TRANSFORMERS_VERBOSITY=info` to see the Hugging Face trainer printouts and verify that `AccelerationFramework` is activated!
Activate `TRANSFORMERS_VERBOSITY=info` to see the Hugging Face trainer printouts and verify that `AccelerationFramework` is activated!

```
# this printout will be seen in huggingface trainer logs if acceleration is activated
```
41 changes: 25 additions & 16 deletions tests/acceleration/test_acceleration_framework.py
@@ -18,28 +18,27 @@

# Third Party
import pytest
import yaml

# First Party
from tests.helpers import causal_lm_train_kwargs
from tests.test_sft_trainer import BASE_LORA_KWARGS

# Local
from tuning import sft_trainer
from tuning.utils.import_utils import is_fms_accelerate_available
import tuning.config.configs as config
from tuning.config.acceleration_configs import (
AccelerationFrameworkConfig, QuantizedLoraConfig
AccelerationFrameworkConfig,
QuantizedLoraConfig,
)
from tuning.config.acceleration_configs.quantized_lora_config import (
AutoGPTQLoraConfig, BNBQLoraConfig
AutoGPTQLoraConfig,
BNBQLoraConfig,
)
from tuning.utils.import_utils import is_fms_accelerate_available

# pylint: disable=import-error
if is_fms_accelerate_available():

# Third Party
from fms_acceleration.framework import KEY_PLUGINS, AccelerationFramework
from fms_acceleration.utils.test_utils import build_framework_and_maybe_instantiate

if is_fms_accelerate_available(plugins="peft"):
@@ -92,7 +91,8 @@ def test_construct_framework_config_with_incorrect_configurations():
"Ensure that framework configuration cannot have empty body"

with pytest.raises(
ValueError, match="AccelerationFrameworkConfig construction requires at least one dataclass"
ValueError,
match="AccelerationFrameworkConfig construction requires at least one dataclass",
):
AccelerationFrameworkConfig.from_dataclasses()

@@ -102,14 +102,18 @@ def test_construct_framework_config_with_incorrect_configurations():
):
AutoGPTQLoraConfig(from_quantized=False)

# test an invalid activation of two standalone configs.
# test an invalid activation of two standalone configs.
quantized_lora_config = QuantizedLoraConfig(
auto_gptq=AutoGPTQLoraConfig(), bnb_qlora=BNBQLoraConfig()
)
with pytest.raises(
ValueError, match="Configuration path 'peft.quantization' already has one standalone config."
ValueError,
match="Configuration path 'peft.quantization' already has one standalone config.",
):
AccelerationFrameworkConfig.from_dataclasses(quantized_lora_config).get_framework()
AccelerationFrameworkConfig.from_dataclasses(
quantized_lora_config
).get_framework()


@pytest.mark.skipif(
not is_fms_accelerate_available(plugins="peft"),
@@ -119,7 +123,9 @@ def test_construct_framework_with_auto_gptq_peft():
"Ensure that framework object is correctly configured."

quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig())
acceleration_config = AccelerationFrameworkConfig.from_dataclasses(quantized_lora_config)
acceleration_config = AccelerationFrameworkConfig.from_dataclasses(
quantized_lora_config
)

# for this test we skip the require package check as second order package
# dependencies of accelerated_peft is not required
@@ -133,6 +139,7 @@ def test_construct_framework_with_auto_gptq_peft():
# the configuration file should successfully activate the plugin
assert len(framework.active_plugins) == 1


@pytest.mark.skipif(
not is_fms_accelerate_available(),
reason="Only runs if fms-accelerate is installed",
@@ -156,20 +163,22 @@ def test_framework_not_installed_or_initalized_properly():
# patch is_fms_accelerate_available to return False inside sft_trainer
# to simulate fms_acceleration not installed
with patch(
"tuning.config.acceleration_configs.acceleration_framework_config.is_fms_accelerate_available", return_value=False
"tuning.config.acceleration_configs.acceleration_framework_config."
"is_fms_accelerate_available",
return_value=False,
):
with pytest.raises(
ValueError,
match="No acceleration framework package found."
ValueError, match="No acceleration framework package found."
):
sft_trainer.train(
model_args,
data_args,
training_args,
tune_config,
quantized_lora_config=quantized_lora_config
quantized_lora_config=quantized_lora_config,
)


@pytest.mark.skipif(
not is_fms_accelerate_available(plugins="peft"),
reason="Only runs if fms-accelerate is installed along with accelerated-peft plugin",
@@ -206,7 +215,7 @@ def test_framework_intialized_properly():
training_args,
tune_config,
# acceleration_framework_args=framework_args,
quantized_lora_config=quantized_lora_config
quantized_lora_config=quantized_lora_config,
)

# spy to ensure that the plugin functions were called.
4 changes: 2 additions & 2 deletions tuning/config/acceleration_configs/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Local
from .acceleration_framework_config import AccelerationFrameworkConfig

from .fused_ops_and_kernels import FusedOpsAndKernelsConfig
from .quantized_lora_config import QuantizedLoraConfig
from .fused_ops_and_kernels import FusedOpsAndKernelsConfig