From 8335f1d07a0373ba76ccf2d9135b7812a6c61775 Mon Sep 17 00:00:00 2001 From: Yu Chin Fabian Lim Date: Thu, 20 Jun 2024 04:18:19 +0000 Subject: [PATCH] add one more check in get_framework and other fixes. Signed-off-by: Yu Chin Fabian Lim --- .../test_acceleration_framework.py | 20 +++++++++++-------- .../acceleration_framework_config.py | 15 +++++++------- tuning/sft_trainer.py | 2 ++ 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/tests/acceleration/test_acceleration_framework.py b/tests/acceleration/test_acceleration_framework.py index 8437e7c24..c907b8f06 100644 --- a/tests/acceleration/test_acceleration_framework.py +++ b/tests/acceleration/test_acceleration_framework.py @@ -71,6 +71,10 @@ # - see plugins/framework/tests/test_framework.py +@pytest.mark.skipif( + not is_fms_accelerate_available(plugins="peft"), + reason="Only runs if fms-accelerate is installed along with accelerated-peft plugin", +) def test_acceleration_framework_fail_construction(): """Ensure that construct of the framework will fail if rules regarding the dataclasess are violated. @@ -301,10 +305,10 @@ def test_framework_raises_due_to_invalid_arguments( ( "peft.quantization.bitsandbytes", create_mock_plugin_class_and_spy( - "PluginMock", - BNBAccelerationPlugin - if is_fms_accelerate_available(plugins='peft') - else object + "PluginMock", + BNBAccelerationPlugin + if is_fms_accelerate_available(plugins="peft") + else object, ), ), ), @@ -314,10 +318,10 @@ def test_framework_raises_due_to_invalid_arguments( ( "peft.quantization.auto_gptq", create_mock_plugin_class_and_spy( - "PluginMock", + "PluginMock", AutoGPTQAccelerationPlugin - if is_fms_accelerate_available(plugins='peft') - else object + if is_fms_accelerate_available(plugins="peft") + else object, ), ), ), @@ -379,7 +383,7 @@ def test_framework_intialized_properly_peft( not is_fms_accelerate_available(plugins=["peft", "foak"]), reason=( "Only runs if fms-accelerate is installed along with accelerated-peft " - "and foak plugins", + "and foak plugins" ), ) def test_framework_intialized_properly_foak(): diff --git a/tuning/config/acceleration_configs/acceleration_framework_config.py b/tuning/config/acceleration_configs/acceleration_framework_config.py index 64575276b..91547c3f6 100644 --- a/tuning/config/acceleration_configs/acceleration_framework_config.py +++ b/tuning/config/acceleration_configs/acceleration_framework_config.py @@ -176,13 +176,14 @@ def get_framework(self): self.to_yaml(f.name) return AccelerationFramework(f.name) else: - raise ValueError( - "No acceleration framework package found. To use, first " - "ensure that 'pip install -e.[fms-accel]' is done first to " - "obtain the acceleration framework dependency. Additional " - "acceleration plugins make be required depending on the requsted " - "acceleration. See README.md for instructions." - ) + if len(self.to_dict()) > 0: + raise ValueError( + "No acceleration framework package found. To use, first " + "ensure that 'pip install -e.[fms-accel]' is done first to " + "obtain the acceleration framework dependency. Additional " + "acceleration plugins make be required depending on the requsted " + "acceleration. See README.md for instructions." + ) def to_dict(self): """convert a valid AccelerationFrameworkConfig dataclass into a schema-less dictionary diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 1fcf1b231..de616fd22 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -62,6 +62,7 @@ write_termination_log, ) + def train( model_args: configs.ModelArguments, data_args: configs.DataArguments, @@ -578,5 +579,6 @@ def main(**kwargs): # pylint: disable=unused-argument write_termination_log(f"Unhandled exception during training: {e}") sys.exit(INTERNAL_ERROR_EXIT_CODE) + if __name__ == "__main__": fire.Fire(main)