diff --git a/tests/acceleration/test_acceleration_framework.py b/tests/acceleration/test_acceleration_framework.py index b5454ba8e..38d7d6bf9 100644 --- a/tests/acceleration/test_acceleration_framework.py +++ b/tests/acceleration/test_acceleration_framework.py @@ -24,7 +24,7 @@ # First Party from tests.helpers import causal_lm_train_kwargs -from tests.test_sft_trainer import BASE_LORA_KWARGS +from tests.test_sft_trainer import BASE_FT_KWARGS, BASE_LORA_KWARGS # Local from .spying_utils import create_mock_plugin_class_and_spy @@ -55,8 +55,10 @@ if is_fms_accelerate_available(plugins="peft"): # Third Party - from fms_acceleration_peft import AutoGPTQAccelerationPlugin - from fms_acceleration_peft import BNBAccelerationPlugin + from fms_acceleration_peft import ( + AutoGPTQAccelerationPlugin, + BNBAccelerationPlugin, + ) if is_fms_accelerate_available(plugins="foak"): # Third Party @@ -242,6 +244,65 @@ def test_framework_raises_if_used_with_missing_package(): quantized_lora_config=quantized_lora_config, ) + +invalid_kwargs_map = [ + ( + { + **BASE_LORA_KWARGS, + "model_name_or_path": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", + }, + AssertionError, + "need to run in fp16 mixed precision or load model in fp16", + ), + ( + { + **BASE_FT_KWARGS, + "model_name_or_path": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", + "torch_dtype": torch.float16, + }, + AssertionError, + "need peft_config to install PEFT adapters", + ), +] + + +@pytest.mark.skipif( + not is_fms_accelerate_available(plugins="peft"), + reason="Only runs if fms-accelerate is installed along with accelerated-peft plugin", +) +@pytest.mark.parametrize( + "train_kwargs,exception", + invalid_kwargs_map, + ids=["triton_v2 requires fp16", "accelerated peft requires peft config"], +) +def test_framework_raises_due_to_invalid_arguments( + bad_train_kwargs, exception, exception_msg +): + """Ensure that invalid arguments will be checked by the activated framework + plugin. + """ + with tempfile.TemporaryDirectory() as tempdir: + TRAIN_KWARGS = { + **bad_train_kwargs, + "output_dir": tempdir, + } + model_args, data_args, training_args, tune_config = causal_lm_train_kwargs( + TRAIN_KWARGS + ) + quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig()) + + # 1. activate the accelerated peft plugin + # 2. demonstrate that the invalid arguments will be checked + with pytest.raises(exception, match=exception_msg): + sft_trainer.train( + model_args, + data_args, + training_args, + tune_config, + quantized_lora_config=quantized_lora_config, + ) + + if is_fms_accelerate_available(plugins="peft"): acceleration_configs_map = [ ( @@ -249,9 +310,7 @@ def test_framework_raises_if_used_with_missing_package(): "TinyLlama/TinyLlama-1.1B-Chat-v1.0", ( "peft.quantization.bitsandbytes", - create_mock_plugin_class_and_spy( - "PluginMock", BNBAccelerationPlugin - ) + create_mock_plugin_class_and_spy("PluginMock", BNBAccelerationPlugin), ), ), ( @@ -261,19 +320,20 @@ def test_framework_raises_if_used_with_missing_package(): "peft.quantization.auto_gptq", create_mock_plugin_class_and_spy( "PluginMock", AutoGPTQAccelerationPlugin - ) + ), ), ), ] + @pytest.mark.skipif( not is_fms_accelerate_available(plugins="peft"), reason="Only runs if fms-accelerate is installed along with accelerated-peft plugin", ) @pytest.mark.parametrize( - "quantized_lora_config,model_name_or_path,mock_and_spy", + "quantized_lora_config,model_name_or_path,mock_and_spy", acceleration_configs_map, - ids=['bitsandbytes', 'auto_gptq'], + ids=["bitsandbytes", "auto_gptq"], ) def test_framework_intialized_properly_peft( quantized_lora_config, model_name_or_path, mock_and_spy @@ -320,6 +380,7 @@ def test_framework_intialized_properly_peft( assert spy["augmentation_calls"] == 1 assert spy["get_ready_for_train_calls"] == 1 + @pytest.mark.skipif( not is_fms_accelerate_available(plugins=["peft", "foak"]), reason=( diff --git a/tests/helpers.py b/tests/helpers.py index a88ae3ef8..221697973 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -43,5 +43,5 @@ def causal_lm_train_kwargs(train_kwargs): training_args, lora_config if train_kwargs.get("peft_method") == "lora" - else prompt_tuning_config, + else (None if train_kwargs.get("peft_method") == "" else prompt_tuning_config,), ) diff --git a/tuning/config/acceleration_configs/acceleration_framework_config.py b/tuning/config/acceleration_configs/acceleration_framework_config.py index d3ac90145..64575276b 100644 --- a/tuning/config/acceleration_configs/acceleration_framework_config.py +++ b/tuning/config/acceleration_configs/acceleration_framework_config.py @@ -123,7 +123,7 @@ def from_dataclasses(*dataclasses: Type): for fi in fields(dc): attr = getattr(dc, fi.name) if attr is None: - continue # skip the None attributes + continue # skip the None attributes if not is_dataclass(attr): raise ValueError(