diff --git a/README.md b/README.md index 7a6bbd7d2..43fa9c216 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ pip install fms-hf-tuning[aim] If you wish to use [fms-acceleration](https://github.com/foundation-model-stack/fms-acceleration), you need to install it. ``` -pip install git+https://github.com/foundation-model-stack/fms-acceleration.git#subdirectory=plugins/framework +pip install fms-hf-tuning[fms-accel] ``` `fms-acceleration` is a collection of plugins that packages that accelerate fine-tuning / training of large models, as part of the `fms-hf-tuning` suite. For more details on see [this section below](#fms-acceleration). @@ -389,7 +389,7 @@ Equally you can pass in a JSON configuration for running tuning. See [build doc] To access `fms-acceleration` features the `[fms-accel]` dependency must first be installed: ``` - $ pip install https://github.com/foundation-model-stack/fms-acceleration.git#subdirectory=plugins/framework + $ pip install fms-hf-tuning[fms-accel] ``` Furthermore, the required `fms-acceleration` plugin must be installed. This is done via the command line utility `fms_acceleration.cli`. To show available plugins: diff --git a/pyproject.toml b/pyproject.toml index 1a8c7ba75..a869fc9f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,8 @@ dependencies = [ dev = ["wheel>=0.42.0,<1.0", "packaging>=23.2,<24", "ninja>=1.11.1.1,<2.0", "scikit-learn>=1.0, <2.0", "boto3>=1.34, <2.0"] flash-attn = ["flash-attn>=2.5.3,<3.0"] aim = ["aim>=3.19.0,<4.0"] +fms-accel = ["fms-acceleration>=0.1"] + [tool.setuptools.packages.find] exclude = ["tests", "tests.*"] diff --git a/tuning/config/acceleration_configs/acceleration_framework_config.py b/tuning/config/acceleration_configs/acceleration_framework_config.py index 2d666a0be..ed0f54d1e 100644 --- a/tuning/config/acceleration_configs/acceleration_framework_config.py +++ b/tuning/config/acceleration_configs/acceleration_framework_config.py @@ -172,16 +172,29 @@ def get_framework(self): NamedTemporaryFile, ) - with NamedTemporaryFile("w") as f: - self.to_yaml(f.name) - return AccelerationFramework(f.name) + try: + with NamedTemporaryFile("w") as f: + self.to_yaml(f.name) + return AccelerationFramework(f.name) + except ValueError as e: + (msg,) = e.args + + # AcceleratorFramework raises ValueError if it + # fails to configure any plugin + if self.is_empty() and msg.startswith("No plugins could be configured"): + # in the case when the error was thrown when + # the acceleration framework config was empty + # then this is expected. + return None + + raise e else: if not self.is_empty(): raise ValueError( - "No acceleration framework package found. To use, first ensure that " - "'pip install git+https://github.com/foundation-model-stack/fms-acceleration.git#subdirectory=plugins/framework' " # pylint: disable=line-too-long - "is done first to obtain the acceleration framework dependency. Additional " - "acceleration plugins make be required depending on the requested " + "No acceleration framework package found. To use, first " + "ensure that 'pip install fms-hf-tuning[fms-accel]' is done first to " + "obtain the acceleration framework dependency. Additional " + "acceleration plugins make be required depending on the requsted " "acceleration. See README.md for instructions." ) @@ -244,7 +257,7 @@ def _descend_and_set(path: List[str], d: Dict): "to be installed. Please do:\n" + "\n".join( [ - "- python -m fms_acceleration install " + "- python -m fms_acceleration.cli install " f"{AccelerationFrameworkConfig.PACKAGE_PREFIX + x}" for x in annotate.required_packages ] diff --git a/tuning/config/acceleration_configs/fused_ops_and_kernels.py b/tuning/config/acceleration_configs/fused_ops_and_kernels.py index 91df8c9dc..ded51415e 100644 --- a/tuning/config/acceleration_configs/fused_ops_and_kernels.py +++ b/tuning/config/acceleration_configs/fused_ops_and_kernels.py @@ -18,20 +18,13 @@ from typing import List # Local -from .utils import ( - EnsureTypes, - ensure_nested_dataclasses_initialized, - parsable_dataclass, -) +from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass @parsable_dataclass @dataclass class FusedLoraConfig(List): - # to help the HfArgumentParser arrive at correct types - __args__ = [EnsureTypes(str, bool)] - # load unsloth optimizations for these 4bit base layer weights. # currently only support "auto_gptq" and "bitsandbytes" base_layer: str = None @@ -41,9 +34,6 @@ class FusedLoraConfig(List): def __post_init__(self): - # reset for another parse - self.__args__[0].reset() - if self.base_layer is not None and self.base_layer not in { "auto_gptq", "bitsandbytes", @@ -60,9 +50,6 @@ def __post_init__(self): @dataclass class FastKernelsConfig(List): - # to help the HfArgumentParser arrive at correct types - __args__ = [EnsureTypes(bool, bool, bool)] - # fast loss triton kernels fast_loss: bool = False @@ -74,9 +61,6 @@ class FastKernelsConfig(List): def __post_init__(self): - # reset for another parse - self.__args__[0].reset() - if not self.fast_loss == self.fast_rsm_layernorm == self.fast_rope_embeddings: raise ValueError( "fast_loss, fast_rms_layernorm and fast_rope_embedding must be enabled " diff --git a/tuning/config/acceleration_configs/quantized_lora_config.py b/tuning/config/acceleration_configs/quantized_lora_config.py index d8174438c..a55ac55d6 100644 --- a/tuning/config/acceleration_configs/quantized_lora_config.py +++ b/tuning/config/acceleration_configs/quantized_lora_config.py @@ -18,11 +18,7 @@ from typing import List # Local -from .utils import ( - EnsureTypes, - ensure_nested_dataclasses_initialized, - parsable_dataclass, -) +from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass @parsable_dataclass @@ -49,9 +45,6 @@ def __post_init__(self): @dataclass class BNBQLoraConfig(List): - # to help the HfArgumentParser arrive at correct types - __args__ = [EnsureTypes(str, bool)] - # type of quantization applied quant_type: str = "nf4" @@ -61,9 +54,6 @@ class BNBQLoraConfig(List): def __post_init__(self): - # reset for another parse - self.__args__[0].reset() - if self.quant_type not in ["nf4", "fp4"]: raise ValueError("quant_type can only be either 'nf4' or 'fp4.")