Skip to content

Commit

Permalink
add one more check in get_framework and other fixes.
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
  • Loading branch information
fabianlim committed Jun 20, 2024
1 parent 3e2b7c5 commit 8335f1d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 15 deletions.
20 changes: 12 additions & 8 deletions tests/acceleration/test_acceleration_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@
# - see plugins/framework/tests/test_framework.py


@pytest.mark.skipif(
not is_fms_accelerate_available(plugins="peft"),
reason="Only runs if fms-accelerate is installed along with accelerated-peft plugin",
)
def test_acceleration_framework_fail_construction():
"""Ensure that construct of the framework will fail if rules regarding
the dataclasess are violated.
Expand Down Expand Up @@ -301,10 +305,10 @@ def test_framework_raises_due_to_invalid_arguments(
(
"peft.quantization.bitsandbytes",
create_mock_plugin_class_and_spy(
"PluginMock",
BNBAccelerationPlugin
if is_fms_accelerate_available(plugins='peft')
else object
"PluginMock",
BNBAccelerationPlugin
if is_fms_accelerate_available(plugins="peft")
else object,
),
),
),
Expand All @@ -314,10 +318,10 @@ def test_framework_raises_due_to_invalid_arguments(
(
"peft.quantization.auto_gptq",
create_mock_plugin_class_and_spy(
"PluginMock",
"PluginMock",
AutoGPTQAccelerationPlugin
if is_fms_accelerate_available(plugins='peft')
else object
if is_fms_accelerate_available(plugins="peft")
else object,
),
),
),
Expand Down Expand Up @@ -379,7 +383,7 @@ def test_framework_intialized_properly_peft(
not is_fms_accelerate_available(plugins=["peft", "foak"]),
reason=(
"Only runs if fms-accelerate is installed along with accelerated-peft "
"and foak plugins",
"and foak plugins"
),
)
def test_framework_intialized_properly_foak():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,14 @@ def get_framework(self):
self.to_yaml(f.name)
return AccelerationFramework(f.name)
else:
raise ValueError(
"No acceleration framework package found. To use, first "
"ensure that 'pip install -e.[fms-accel]' is done first to "
"obtain the acceleration framework dependency. Additional "
"acceleration plugins make be required depending on the requsted "
"acceleration. See README.md for instructions."
)
if len(self.to_dict()) > 0:
raise ValueError(
"No acceleration framework package found. To use, first "
"ensure that 'pip install -e.[fms-accel]' is done first to "
"obtain the acceleration framework dependency. Additional "
"acceleration plugins make be required depending on the requsted "
"acceleration. See README.md for instructions."
)

def to_dict(self):
"""convert a valid AccelerationFrameworkConfig dataclass into a schema-less dictionary
Expand Down
2 changes: 2 additions & 0 deletions tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
write_termination_log,
)


def train(
model_args: configs.ModelArguments,
data_args: configs.DataArguments,
Expand Down Expand Up @@ -578,5 +579,6 @@ def main(**kwargs): # pylint: disable=unused-argument
write_termination_log(f"Unhandled exception during training: {e}")
sys.exit(INTERNAL_ERROR_EXIT_CODE)


if __name__ == "__main__":
fire.Fire(main)

0 comments on commit 8335f1d

Please sign in to comment.