Install Acceleration Framework into Training Script #157

Merged
merged 24 commits from accel-pr into foundation-model-stack:main on Jun 20, 2024

Conversation

@fabianlim (Collaborator) commented May 15, 2024

Description of the change

This PR installs the Acceleration Framework into sft_trainer.py. It is a follow-up to #119, which proposes a lightweight integration; the implementation of the Acceleration Framework is kept separate in the fms-acceleration repo under the same foundation-model-stack org.

  • introduce AccelerationFrameworkArguments, which accepts an --acceleration_framework_config_file argument to configure the framework.

  • update pyproject.toml with an optional dependency [fms-accel] that installs the fms-acceleration framework.

  • update README.md to include basic usage of fms-acceleration, using it to perform accelerated PEFT with 4-bit GPTQ-LoRA.

  • ensure that the integration within sft_trainer.py is optional; the .get_framework call below silently disables the framework if the fms-accel dependency is not installed.

    framework = AccelerationFrameworkConfig.from_dataclasses(*dataclass_configs).get_framework()
    
  • restrict to only three integration points within the sft_trainer.py script (a rough sketch follows this list):

    1. framework.model_loader: load the model if framework.requires_custom_loading == True
    2. framework.augmentation: augment the model (and modifiable arguments such as the peft_config) if framework.requires_augmentation == True
    3. framework.get_callbacks_and_ready_for_train: get callbacks and do final prep on the model and trainer.accelerator if required.
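The following is a rough sketch of how these three integration points could be wired into sft_trainer.py. It is illustrative only: the import path, the model loader's keyword arguments, and the modifiable_args tuple passed to augmentation are assumptions, not the PR's verbatim code; names such as dataclass_configs, model_name, train_args, peft_config, and trainer come from the surrounding training script.

    import torch
    from transformers import AutoModelForCausalLM

    # the import path of the config wrapper added by this PR is an assumption
    from tuning.config.acceleration_configs import AccelerationFrameworkConfig

    framework = AccelerationFrameworkConfig.from_dataclasses(*dataclass_configs).get_framework()

    # 1. custom model loading, only if the framework requires it
    if framework is not None and framework.requires_custom_loading:
        model = framework.model_loader(model_name, torch_dtype=torch.float16)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name)

    # 2. augmentation of the model and modifiable arguments (e.g. installing accelerated PEFT)
    if framework is not None and framework.requires_augmentation:
        model, (peft_config,) = framework.augmentation(
            model, train_args, modifiable_args=(peft_config,)
        )

    # ... build the SFTTrainer as usual ...

    # 3. after the trainer is built: fetch callbacks and do final prep on model and accelerator
    if framework is not None:
        for callback in framework.get_callbacks_and_ready_for_train(model, trainer.accelerator):
            trainer.add_callback(callback)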

Update: the picture has been changed to reflect the new dataclass arguments flow.

[image: dataclass arguments flow]

Related issue number

Related to the merged PR #119 for the ADR of Training Acceleration.

How to verify the PR

This PR can be verified in the following ways:

  1. run the tests in the new folder tests/acceleration; see the note on unit tests below.
  2. [Update: this benchmark needs to be reworked because YAML arguments have been disabled; right now it works with some patching] run the provided benchmark utility and check against the results in the PR and the a100_80gb reference.

Note for unit tests
For tests/acceleration, note that the test environment will not install the [fms-accel] dependencies, so some of the newly added tests will be skipped. To run all the new tests:

pip install ".[fms-accel]"
python -m fms_acceleration.cli install peft
pytest tests/acceleration

Was the PR tested

  • I have added >=1 unit test(s) for every new method I have added.
  • I have ensured all unit tests pass

@fabianlim fabianlim marked this pull request as draft May 15, 2024 15:50
@fabianlim fabianlim force-pushed the accel-pr branch 3 times, most recently from 7b5e354 to 6095b09, May 16, 2024 02:21
@fabianlim fabianlim changed the title from "DO NOT REVIEW: Acceleration framework" to "Install Acceleration Framework into Training Script" May 16, 2024
@fabianlim fabianlim marked this pull request as ready for review May 16, 2024 02:26
@Ssukriti (Collaborator) commented:

From a DM:

It would be great if there were a way to pass a Python dataclass instead of a YAML file:
python sft_trainer.py --acceleration_framework_config_file framework.yaml

We can create a new GPTQ-LoRA config here: https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/config/peft_config.py#L21
so that users of fms-tuning have exactly the same experience as with LoRA.

Inside fms-tuning, we know that for GPTQ-LoRA we need to call AccelerationFramework(); can we convert to a Python dataclass object accepted by AccelerationFramework, e.g. AccelerationFramework(qlora_config)?

You know what dataclass instance it is, so you can parse params from that.

We can implement this in steps, though, instead of requiring too much change to your framework:
Step 1: users of fms-tuning continue to use dataclasses; we internally convert the params passed into the YAML file you want (we won't expose YAML to users of tuning, but keep it inside the tuning library), then call the acceleration framework the same way with YAML.
Step 2: the acceleration framework adds support for Python dataclasses for each plugin as well, besides YAML, so we don't have to create YAML but can just pass the params needed for a certain plugin.

If we do step 1, we can merge the code and get it working, while working on step 2 as an enhancement.

We are discussing whether we need one dataclass or multiple, but we can start with step 1, with dataclasses only in tuning, and see how it goes.
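For concreteness, here is a rough sketch of what step 1 could look like; the dataclass fields and the YAML key layout are illustrative assumptions, not the framework's actual schema.

    import tempfile
    from dataclasses import dataclass, asdict

    import yaml  # pyyaml

    @dataclass
    class GPTQLoraConfig:
        # hypothetical user-facing dataclass, mirroring the LoRA experience
        kernel: str = "triton_v2"
        from_quantized: bool = True

    def to_framework_yaml(config: GPTQLoraConfig) -> str:
        """Convert the dataclass into the YAML config the acceleration framework
        expects; kept internal to the tuning library, never exposed to users."""
        content = {"peft": {"quantization": {"auto_gptq": asdict(config)}}}
        with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
            yaml.safe_dump(content, f)
            return f.name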

@fabianlim (Collaborator, Author) commented Jun 13, 2024

@Ssukriti @alex-jw-brooks I have made the requested changes:

  • now with dataclass parsing, it should be easy for the user to figure out the required arguments by inspecting the dataclasses. There is no longer any YAML to specify.
  • see for example quantized_lora_config.
  • I have put unit tests in tests. They are almost complete, but I might add one or two more.

Notes:

  • it was not that straightforward to get transformers.HfArgumentParser to work with nested dataclasses. I had to implement two additional utilities here; I can try to further simplify this implementation (a rough sketch of the problem is included after these notes).
    • made some attempts to simplify this with a decorator parsable_dataclass
  • The argument parsing works as expected; see the tests.
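For context, a minimal sketch of why nested dataclasses need extra handling; the dataclass names and the flatten-then-reassemble workaround shown here are assumptions for illustration, not the utilities implemented in this PR.

    from dataclasses import dataclass
    from transformers import HfArgumentParser

    @dataclass
    class AutoGPTQConfig:
        # hypothetical leaf dataclass holding the GPTQ-LoRA plugin arguments
        kernel: str = "triton_v2"
        from_quantized: bool = True

    @dataclass
    class QuantizedLoraConfig:
        # hypothetical nested dataclass; HfArgumentParser cannot parse this directly
        auto_gptq: AutoGPTQConfig = None

    # one workaround: parse the flat leaf dataclass, then rebuild the nested structure
    parser = HfArgumentParser((AutoGPTQConfig,))
    (gptq_args,) = parser.parse_args_into_dataclasses(args=["--kernel", "triton_v2"])
    quantized_lora_config = QuantizedLoraConfig(auto_gptq=gptq_args)
    print(quantized_lora_config)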

@fabianlim fabianlim force-pushed the accel-pr branch 3 times, most recently from cf23771 to 96ad8bf, June 13, 2024 12:48
@Ssukriti (Collaborator) commented Jun 18, 2024

@fabianlim thank you for the redesign and the acceleration framework unit tests. The design looks good, so no major changes are needed. Just a few comments:

  1. the question above on when GPTQ-LoRA is applicable

  2. The tests for the acceleration framework check the integration in detail, thank you. But I think, if possible, it would be beneficial to add some top-level tests to test_sft_trainer as well, just to ensure that:
    a. if a quantized_lora_config is passed, tuning still succeeds and the model after tuning can still be loaded and inferred on (like the rest of the tuning unit tests). I understand you may need a quantized base model for these unit tests, so let me know whether it is feasible to add them or not.
    b. similarly with kernels: if passed, tuning still succeeds.
    Either way, it might be good to add tests for the failure case as well in test_sft_trainer. What if a user passes a GPTQLoraConfig to an unsupported model that is not quantized? What happens then; is an error raised and caught?

  3. Documentation may need some more updates, but we can do that in subsequent PRs. Mainly, if there is any limitation on which model types we can apply QLoRA to (it needs a 4-bit quantized model), that should be documented and highlighted in the README.

@fabianlim (Collaborator, Author) commented Jun 19, 2024

@Ssukriti thank you for reviewing!

since this requires augmentation and peft_config, is quantized_lora_config only expected to work with LoRA tuning and LoraConfig, or can one also apply it with fine tuning and prompt tuning?

The AccelerationFramework has logic inside its plugins that checks whether the peft_config is properly set. The peft_config is passed to the framework via the augmentation step. Maybe I can add a unit test to demonstrate this.

The tests for the acceleration framework check the integration in detail, thank you. But I think, if possible, it would be beneficial to add some top-level tests to test_sft_trainer as well, just to ensure that:
a. if a quantized_lora_config is passed, tuning still succeeds and the model after tuning can still be loaded and inferred on (like the rest of the tuning unit tests). I understand you may need a quantized base model for these unit tests, so let me know whether it is feasible to add them or not.
b. similarly with kernels: if passed, tuning still succeeds.
Either way, it might be good to add tests for the failure case as well in test_sft_trainer. What if a user passes a GPTQLoraConfig to an unsupported model that is not quantized? What happens then; is an error raised and caught?

Actually I think you are referring to this kind of test. We do already have one.

  • I can add two more such tests: i) for fused ops and kernels, and ii) a negative test on an unsupported (non-quantized) model.

Documentation may need some more updates, but we can do that in subsequent PRs. Mainly, if there is any limitation on which model types we can apply QLoRA to (it needs a 4-bit quantized model), that should be documented and highlighted in the README.

ok np!

@fabianlim (Collaborator, Author) commented:

@Ssukriti I have added a set of extra tests

These tests involve calling sft_trainer.train, so I think these are the tests you are requesting. They can be considered integration tests, ensuring the framework works correctly when integrated with the trainer.

Therefore, I have added the following integration tests (a rough sketch of the first is included after the list):

  1. test_framework_raises_due_to_invalid_arguments: this test demonstrates that the framework plugins also check the arguments and throw if invalid arguments are passed in. For example, if we attempt to use accelerated PEFT and no peft_config is passed, then it throws.
  2. test_framework_intialized_properly_peft: this was refactored from an older test, but now it also demonstrates BNB QLoRA loading properly. Previously it covered only the GPTQ-LoRA happy path.
  3. test_framework_intialized_properly_foak: this one demonstrates that fused ops and kernels are also integrated properly.
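A rough sketch of the shape of the first test; the fixture names, the train keyword arguments, and the expected error type below are hypothetical, not the test's verbatim code.

    import pytest

    from tuning import sft_trainer

    def test_framework_raises_due_to_invalid_arguments(
        model_args, data_args, train_args, quantized_lora_config  # hypothetical fixtures
    ):
        # accelerated PEFT is requested but no LoRA peft_config is supplied,
        # so the framework plugin is expected to reject the arguments
        with pytest.raises(ValueError):
            sft_trainer.train(
                model_args,
                data_args,
                train_args,
                peft_config=None,
                quantized_lora_config=quantized_lora_config,
            )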

@Ssukriti (Collaborator) left a review comment


Will approve after conflicts are merged and all checks pass. I suggested minor edits to make it clear that GPTQ-LoRA needs a peft_config.LoraConfig passed; we have to make that clear in our documentation, as users of tuning will not know that.

Remaining work in subsequent PRs after this PR is merged:

  1. we need to ensure that in CI/CD all the tests run regularly and are not skipped. That means all dependencies should be installed so that our tests run regularly. The purpose is to ensure that, with every release, all tests pass.
  2. Unit tests: the additional unit tests added are good, thank you. I did want to ensure the model produced after GPTQ-LoRA tuning is in the correct format and can be loaded and inferred on correctly. We have had issues in the past when something would change and the model format produced was no longer correct; we should have tests to capture that, to have full confidence (will DM about this). A rough sketch of such a check is included below.
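A rough sketch, with hypothetical paths, of the kind of post-tuning format check described in (2): load the saved adapter and run a short generation to confirm the produced model is still valid.

    import torch
    from peft import AutoPeftModelForCausalLM
    from transformers import AutoTokenizer

    checkpoint_dir = "outputs/checkpoint-final"  # hypothetical output path
    model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_dir, torch_dtype=torch.float16)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

    inputs = tokenizer("### Text: hello\n\n### Label:", return_tensors="pt")
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=10)
    print(tokenizer.decode(output[0], skip_special_tokens=True))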

Resolved review threads: README.md (2), tuning/sft_trainer.py
fabianlim added 14 commits June 20, 2024 10:16
@fabianlim (Collaborator, Author) commented Jun 20, 2024

@Ssukriti I have rebased the changes and also created issue #205 to track the remaining work items. Working on making the tests pass. Update: all checks are passing now.

@Ssukriti merged commit fc8938d into foundation-model-stack:main on Jun 20, 2024 (7 checks passed)