Commit

replace yaml with dataclass args
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed Jun 13, 2024
1 parent f05580e commit 9580b99
Showing 9 changed files with 570 additions and 123 deletions.
24 changes: 0 additions & 24 deletions fixtures/accelerated-peft-autogptq-sample-configuration.yaml

This file was deleted.
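The deleted fixture's contents are not shown in this view. Below is a rough sketch of the shape it likely had, expressed as the Python dict that yaml.safe_load would return. The nesting is inferred from the 'plugins' body (KEY_PLUGINS) and the 'peft.quantization' path exercised in the tests below; the exact fields are assumptions.

    # hypothetical sketch, not the verbatim deleted file
    configuration = {
        "plugins": {
            "peft": {
                "quantization": {
                    "auto_gptq": {
                        # the only mode AutoGPTQLoraConfig currently supports
                        "from_quantized": True,
                    },
                },
            },
        },
    }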

97 changes: 40 additions & 57 deletions tests/acceleration/test_acceleration_framework.py
@@ -14,7 +14,6 @@

 # Standard
 from unittest.mock import patch
-import os
 import tempfile

 # Third Party
@@ -29,6 +28,12 @@
 from tuning import sft_trainer
 from tuning.utils.import_utils import is_fms_accelerate_available
 import tuning.config.configs as config
+from tuning.config.acceleration_configs import (
+    AccelerationFrameworkConfig, QuantizedLoraConfig
+)
+from tuning.config.acceleration_configs.quantized_lora_config import (
+    AutoGPTQLoraConfig, BNBQLoraConfig
+)

 # pylint: disable=import-error
 if is_fms_accelerate_available():
@@ -47,11 +52,6 @@
 # repository.
 # - see plugins/framework/tests/test_framework.py

-CONFIG_PATH_AUTO_GPTQ = os.path.join(
-    os.path.dirname(__file__),
-    "../../fixtures/accelerated-peft-autogptq-sample-configuration.yaml",
-)
-
 # helper function
 def create_mock_plugin_class(plugin_cls):
     "Create a mocked acceleration framework class that can be used to spy"
@@ -88,45 +88,51 @@ def get_callbacks_and_ready_for_train(self, *args, **kwargs):
     return MockPlugin


 @pytest.mark.skipif(
     not is_fms_accelerate_available(),
     reason="Only runs if fms-accelerate is installed",
 )
-def test_construct_framework_with_empty_file():
+def test_construct_framework_config_with_incorrect_configurations():
     "Ensure that framework configuration cannot have empty body"

-    with pytest.raises(ValueError) as e:
-        with tempfile.NamedTemporaryFile("w") as f:
-            yaml.dump({KEY_PLUGINS: None}, f)
-            AccelerationFramework(f.name)
-
-    e.match(f"Configuration file must contain a '{KEY_PLUGINS}' body")
+    with pytest.raises(
+        ValueError, match="AccelerationFrameworkConfig construction requires at least one dataclass"
+    ):
+        AccelerationFrameworkConfig.from_dataclasses()
+
+    # test a currently not supported config
+    with pytest.raises(
+        ValueError, match="only 'from_quantized' == True currently supported."
+    ):
+        AutoGPTQLoraConfig(from_quantized=False)
+
+    # test an invalid activation of two standalone configs.
+    quantized_lora_config = QuantizedLoraConfig(
+        auto_gptq=AutoGPTQLoraConfig(), bnb_qlora=BNBQLoraConfig()
+    )
+    with pytest.raises(
+        ValueError, match="Configuration path 'peft.quantization' already has one standalone config."
+    ):
+        AccelerationFrameworkConfig.from_dataclasses(quantized_lora_config).get_framework()


 @pytest.mark.skipif(
     not is_fms_accelerate_available(plugins="peft"),
     reason="Only runs if fms-accelerate is installed along with accelerated-peft plugin",
 )
 def test_construct_framework_with_auto_gptq_peft():
-    "Ensure that framework object initializes correctly with the sample config"
+    "Ensure that framework object is correctly configured."

-    # the test util below requires to read the file first
-    with open(CONFIG_PATH_AUTO_GPTQ, encoding="utf-8") as f:
-        configuration = yaml.safe_load(f)
+    quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig())
+    acceleration_config = AccelerationFrameworkConfig.from_dataclasses(quantized_lora_config)

     # for this test we skip the require package check as second order package
     # dependencies of accelerated_peft is not required
     with build_framework_and_maybe_instantiate(
         [],
-        configuration["plugins"],
+        acceleration_config.to_dict(),
         reset_registrations=False,
         require_packages_check=False,
     ) as framework:

         # the configuration file should successfully activate the plugin
         assert len(framework.active_plugins) == 1


 @pytest.mark.skipif(
     not is_fms_accelerate_available(),
     reason="Only runs if fms-accelerate is installed",
@@ -145,46 +151,24 @@ def test_framework_not_installed_or_initalized_properly():
     model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
         TRAIN_KWARGS
     )
-    framework_args = config.AccelerationFrameworkArguments(
-        acceleration_framework_config_file=CONFIG_PATH_AUTO_GPTQ,
-    )
+    quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig())

     # patch is_fms_accelerate_available to return False inside sft_trainer
     # to simulate fms_acceleration not installed
     with patch(
-        "tuning.sft_trainer.is_fms_accelerate_available", return_value=False
+        "tuning.config.acceleration_configs.acceleration_framework_config.is_fms_accelerate_available", return_value=False
     ):
-        with pytest.raises(ValueError) as e:
+        with pytest.raises(
+            ValueError,
+            match="No acceleration framework package found."
+        ):
             sft_trainer.train(
                 model_args,
                 data_args,
                 training_args,
                 tune_config,
-                acceleration_framework_args=framework_args,
+                quantized_lora_config=quantized_lora_config
             )
-        e.match("Specified acceleration framework config")
-
-    # test with a dummy configuration file that will fail to activate any
-    # framework plugin
-    with tempfile.NamedTemporaryFile("w") as f:
-        yaml.dump({KEY_PLUGINS: {"dummy": 1}}, f)
-
-        framework_args_dummy_file = config.AccelerationFrameworkArguments(
-            acceleration_framework_config_file=f.name,
-        )
-
-        # patch is_fms_accelerate_available to return False inside sft_trainer
-        # to simulate fms_acceleration not installed
-        with pytest.raises(ValueError) as e:
-            sft_trainer.train(
-                model_args,
-                data_args,
-                training_args,
-                tune_config,
-                acceleration_framework_args=framework_args_dummy_file,
-            )
-        e.match("No plugins could be configured.")


 @pytest.mark.skipif(
     not is_fms_accelerate_available(plugins="peft"),
@@ -204,9 +188,7 @@ def test_framework_intialized_properly():
     model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
         TRAIN_KWARGS
     )
-    framework_args = config.AccelerationFrameworkArguments(
-        acceleration_framework_config_file=CONFIG_PATH_AUTO_GPTQ,
-    )
+    quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig())

     # create mocked plugin class for spying
     MockedClass = create_mock_plugin_class(AutoGPTQAccelerationPlugin)
@@ -223,7 +205,8 @@ def test_framework_intialized_properly():
             data_args,
             training_args,
             tune_config,
-            acceleration_framework_args=framework_args,
+            # acceleration_framework_args=framework_args,
+            quantized_lora_config=quantized_lora_config
         )

     # spy to ensure that the plugin functions were called.
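Taken together, the test changes above illustrate the new calling convention: rather than pointing the trainer at a YAML file through config.AccelerationFrameworkArguments, callers construct the dataclasses and pass them straight to sft_trainer.train. A minimal sketch assembled only from names visible in this diff (the remaining train arguments are elided):

    from tuning.config.acceleration_configs import (
        AccelerationFrameworkConfig,
        QuantizedLoraConfig,
    )
    from tuning.config.acceleration_configs.quantized_lora_config import AutoGPTQLoraConfig

    # the dataclass config replaces the deleted YAML fixture
    quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig())

    # the collected dataclasses can still be rendered to the dict form the
    # framework consumes, as test_construct_framework_with_auto_gptq_peft
    # does via to_dict()
    acceleration_config = AccelerationFrameworkConfig.from_dataclasses(quantized_lora_config)
    plugin_config = acceleration_config.to_dict()

    # sft_trainer.train(model_args, data_args, training_args, tune_config,
    #                   quantized_lora_config=quantized_lora_config)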
18 changes: 18 additions & 0 deletions tuning/config/acceleration_configs/__init__.py
@@ -0,0 +1,18 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .acceleration_framework_config import AccelerationFrameworkConfig
+
+from .quantized_lora_config import QuantizedLoraConfig
+from .fused_ops_and_kernels import FusedOpsAndKernelsConfig
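The new package's __init__ provides a single import surface for the acceleration dataclasses; for example (FusedOpsAndKernelsConfig is exported here, though its definition is not part of the files shown):

    from tuning.config.acceleration_configs import (
        AccelerationFrameworkConfig,
        FusedOpsAndKernelsConfig,
        QuantizedLoraConfig,
    )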
