Commit

add missing peft config test
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed Jun 19, 2024
1 parent 2597241 commit 03ae17e
Showing 3 changed files with 72 additions and 11 deletions.
79 changes: 70 additions & 9 deletions tests/acceleration/test_acceleration_framework.py
@@ -24,7 +24,7 @@

# First Party
from tests.helpers import causal_lm_train_kwargs
-from tests.test_sft_trainer import BASE_LORA_KWARGS
+from tests.test_sft_trainer import BASE_FT_KWARGS, BASE_LORA_KWARGS

# Local
from .spying_utils import create_mock_plugin_class_and_spy
@@ -55,8 +55,10 @@

if is_fms_accelerate_available(plugins="peft"):
# Third Party
-from fms_acceleration_peft import AutoGPTQAccelerationPlugin
-from fms_acceleration_peft import BNBAccelerationPlugin
+from fms_acceleration_peft import (
+    AutoGPTQAccelerationPlugin,
+    BNBAccelerationPlugin,
+)

if is_fms_accelerate_available(plugins="foak"):
# Third Party
@@ -242,16 +244,73 @@ def test_framework_raises_if_used_with_missing_package():
quantized_lora_config=quantized_lora_config,
)


invalid_kwargs_map = [
(
{
**BASE_LORA_KWARGS,
"model_name_or_path": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
},
AssertionError,
"need to run in fp16 mixed precision or load model in fp16",
),
(
{
**BASE_FT_KWARGS,
"model_name_or_path": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
"torch_dtype": torch.float16,
},
AssertionError,
"need peft_config to install PEFT adapters",
),
]


@pytest.mark.skipif(
not is_fms_accelerate_available(plugins="peft"),
reason="Only runs if fms-accelerate is installed along with accelerated-peft plugin",
)
@pytest.mark.parametrize(
"train_kwargs,exception",
invalid_kwargs_map,
ids=["triton_v2 requires fp16", "accelerated peft requires peft config"],
)
def test_framework_raises_due_to_invalid_arguments(
bad_train_kwargs, exception, exception_msg
):
"""Ensure that invalid arguments will be checked by the activated framework
plugin.
"""
with tempfile.TemporaryDirectory() as tempdir:
TRAIN_KWARGS = {
**bad_train_kwargs,
"output_dir": tempdir,
}
model_args, data_args, training_args, tune_config = causal_lm_train_kwargs(
TRAIN_KWARGS
)
quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig())

# 1. activate the accelerated peft plugin
# 2. demonstrate that the invalid arguments will be checked
with pytest.raises(exception, match=exception_msg):
sft_trainer.train(
model_args,
data_args,
training_args,
tune_config,
quantized_lora_config=quantized_lora_config,
)


if is_fms_accelerate_available(plugins="peft"):
acceleration_configs_map = [
(
QuantizedLoraConfig(bnb_qlora=BNBQLoraConfig()),
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
(
"peft.quantization.bitsandbytes",
-create_mock_plugin_class_and_spy(
-    "PluginMock", BNBAccelerationPlugin
-)
+create_mock_plugin_class_and_spy("PluginMock", BNBAccelerationPlugin),
),
),
(
@@ -261,19 +320,20 @@ def test_framework_raises_if_used_with_missing_package():
"peft.quantization.auto_gptq",
create_mock_plugin_class_and_spy(
"PluginMock", AutoGPTQAccelerationPlugin
-)
+),
),
),
]


@pytest.mark.skipif(
not is_fms_accelerate_available(plugins="peft"),
reason="Only runs if fms-accelerate is installed along with accelerated-peft plugin",
)
@pytest.mark.parametrize(
"quantized_lora_config,model_name_or_path,mock_and_spy",
"quantized_lora_config,model_name_or_path,mock_and_spy",
acceleration_configs_map,
-ids=['bitsandbytes', 'auto_gptq'],
+ids=["bitsandbytes", "auto_gptq"],
)
def test_framework_intialized_properly_peft(
quantized_lora_config, model_name_or_path, mock_and_spy
@@ -320,6 +380,7 @@ def test_framework_intialized_properly_peft(
assert spy["augmentation_calls"] == 1
assert spy["get_ready_for_train_calls"] == 1


@pytest.mark.skipif(
not is_fms_accelerate_available(plugins=["peft", "foak"]),
reason=(
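For context, create_mock_plugin_class_and_spy (imported above from tests/acceleration/spying_utils.py) is not shown in this diff. Below is a minimal sketch of what such a helper might look like; the spy keys match the assertions in the test above, but the subclassing approach and the hook names (augmentation, get_callbacks_and_ready_for_train) are assumptions about the fms-acceleration plugin interface, not taken from this commit.

    # Hypothetical sketch only -- the real helper lives in spying_utils.py and may differ.
    def create_mock_plugin_class_and_spy(name, plugin_class):
        spy = {"augmentation_calls": 0, "get_ready_for_train_calls": 0}

        class MockPlugin(plugin_class):
            # Assumed hook names; each simply forwards to the real plugin
            # while counting how often the framework invokes it.
            def augmentation(self, *args, **kwargs):
                spy["augmentation_calls"] += 1
                return super().augmentation(*args, **kwargs)

            def get_callbacks_and_ready_for_train(self, *args, **kwargs):
                spy["get_ready_for_train_calls"] += 1
                return super().get_callbacks_and_ready_for_train(*args, **kwargs)

        MockPlugin.__name__ = name
        return MockPlugin, spy

The parametrized tests presumably register the mock class in place of the real plugin (keyed by the path strings such as "peft.quantization.bitsandbytes") and then assert on the spy counters after training, as in the assertions shown above.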
2 changes: 1 addition & 1 deletion tests/helpers.py
@@ -43,5 +43,5 @@ def causal_lm_train_kwargs(train_kwargs):
training_args,
lora_config
if train_kwargs.get("peft_method") == "lora"
-else prompt_tuning_config,
+else (None if train_kwargs.get("peft_method") == "" else prompt_tuning_config),
)
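The one-line change above makes causal_lm_train_kwargs return None for the tuning config when peft_method is the empty string (i.e. full fine-tuning) instead of always falling back to a prompt-tuning config; this is presumably what lets the new BASE_FT_KWARGS test case reach the "need peft_config to install PEFT adapters" check. A minimal standalone sketch of the selection logic, using placeholder names (select_tune_config is not a real helper in this repo):

    # Sketch of the config selection after this change (placeholder function and args).
    def select_tune_config(peft_method, lora_config, prompt_tuning_config):
        if peft_method == "lora":
            return lora_config           # LoRA adapter config
        if peft_method == "":
            return None                  # full fine-tuning: no PEFT config at all
        return prompt_tuning_config      # otherwise fall back to prompt tuning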
Original file line number Diff line number Diff line change
@@ -123,7 +123,7 @@ def from_dataclasses(*dataclasses: Type):
for fi in fields(dc):
attr = getattr(dc, fi.name)
if attr is None:
-continue # skip the None attributes
+continue  # skip the None attributes

if not is_dataclass(attr):
raise ValueError(
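For context, the loop touched by this last hunk walks the fields of each acceleration config dataclass, skips attributes that are unset (None), and rejects anything that is not itself a dataclass. A small self-contained sketch of that pattern is below; it illustrates only the visible fragment (the real from_dataclasses does more), and the example dataclass names are hypothetical.

    from dataclasses import dataclass, fields, is_dataclass
    from typing import Optional

    @dataclass
    class InnerConfig:          # stand-in for a plugin config dataclass
        value: int = 1

    @dataclass
    class OuterConfig:          # stand-in for a top-level acceleration config
        inner: Optional[InnerConfig] = None

    def collect_set_fields(dc):
        # Mirror of the diffed fragment: iterate fields, skip None attributes,
        # and require that every set attribute is itself a dataclass.
        collected = []
        for fi in fields(dc):
            attr = getattr(dc, fi.name)
            if attr is None:
                continue  # skip the None attributes
            if not is_dataclass(attr):
                raise ValueError(f"field '{fi.name}' must be a dataclass or None")
            collected.append(attr)
        return collected

    # collect_set_fields(OuterConfig())                    -> []
    # collect_set_fields(OuterConfig(inner=InnerConfig())) -> [InnerConfig(value=1)]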
