Commit

Merge remote-tracking branch 'upstream/main'
dchourasia committed Jul 10, 2024
2 parents 3386a1c + bf22a2f commit 13fd496
Showing 5 changed files with 27 additions and 38 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -25,7 +25,7 @@ pip install fms-hf-tuning[aim]

If you wish to use [fms-acceleration](https://github.com/foundation-model-stack/fms-acceleration), you need to install it.
```
- pip install git+https://github.com/foundation-model-stack/fms-acceleration.git#subdirectory=plugins/framework
+ pip install fms-hf-tuning[fms-accel]
```
`fms-acceleration` is a collection of plugin packages that accelerate fine-tuning / training of large models, as part of the `fms-hf-tuning` suite. For more details, see [this section below](#fms-acceleration).

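As a quick sanity check before requesting acceleration features, the presence of the framework dependency can be verified from Python. This is a minimal editor sketch, not part of the README; the module name `fms_acceleration` is assumed from the `python -m fms_acceleration.cli` commands shown later in this commit.

```python
# Minimal sketch: confirm the fms-acceleration framework package is importable.
# The module name "fms_acceleration" is an assumption taken from the CLI
# invocations elsewhere in this commit.
import importlib.util

if importlib.util.find_spec("fms_acceleration") is None:
    raise SystemExit(
        "fms-acceleration not found; install it with: pip install fms-hf-tuning[fms-accel]"
    )
print("fms-acceleration framework is available")
```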
@@ -389,7 +389,7 @@ Equally you can pass in a JSON configuration for running tuning. See [build doc]

To access `fms-acceleration` features, the `[fms-accel]` dependency must first be installed:
```
- $ pip install https://github.com/foundation-model-stack/fms-acceleration.git#subdirectory=plugins/framework
+ $ pip install fms-hf-tuning[fms-accel]
```

Furthermore, the required `fms-acceleration` plugin must be installed. This is done via the command line utility `fms_acceleration.cli`. To show available plugins:
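The README hunk above is truncated before the plugin-listing example. As a hedged illustration of that step, installed plugin packages can also be discovered from Python by the `fms_acceleration_` prefix that appears in this commit's install hints; the snippet below is an editor sketch, not the output of `fms_acceleration.cli`, and the module-naming convention is an assumption.

```python
# Hedged sketch: enumerate installed fms-acceleration plugin packages by the
# "fms_acceleration_" module prefix used in this commit's error-message change.
import pkgutil

plugins = sorted(
    m.name for m in pkgutil.iter_modules() if m.name.startswith("fms_acceleration_")
)
print("installed fms-acceleration plugins:", plugins or "none")
```

Missing plugins are then installed with the command shown later in this diff, e.g. `python -m fms_acceleration.cli install fms_acceleration_<plugin>` (the plugin suffix is a placeholder).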
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -44,6 +44,8 @@ dependencies = [
dev = ["wheel>=0.42.0,<1.0", "packaging>=23.2,<24", "ninja>=1.11.1.1,<2.0", "scikit-learn>=1.0, <2.0", "boto3>=1.34, <2.0"]
flash-attn = ["flash-attn>=2.5.3,<3.0"]
aim = ["aim>=3.19.0,<4.0"]
+ fms-accel = ["fms-acceleration>=0.1"]


[tool.setuptools.packages.find]
exclude = ["tests", "tests.*"]
@@ -172,16 +172,29 @@ def get_framework(self):
NamedTemporaryFile,
)

- with NamedTemporaryFile("w") as f:
-     self.to_yaml(f.name)
-     return AccelerationFramework(f.name)
+ try:
+     with NamedTemporaryFile("w") as f:
+         self.to_yaml(f.name)
+         return AccelerationFramework(f.name)
+ except ValueError as e:
+     (msg,) = e.args
+
+     # AccelerationFramework raises ValueError if it
+     # fails to configure any plugin
+     if self.is_empty() and msg.startswith("No plugins could be configured"):
+         # in the case when the error was thrown when
+         # the acceleration framework config was empty
+         # then this is expected.
+         return None
+
+     raise e
else:
if not self.is_empty():
raise ValueError(
"No acceleration framework package found. To use, first ensure that "
"'pip install git+https://github.com/foundation-model-stack/fms-acceleration.git#subdirectory=plugins/framework' " # pylint: disable=line-too-long
"is done first to obtain the acceleration framework dependency. Additional "
"acceleration plugins make be required depending on the requested "
"No acceleration framework package found. To use, first "
"ensure that 'pip install fms-hf-tuning[fms-accel]' is done first to "
"obtain the acceleration framework dependency. Additional "
"acceleration plugins make be required depending on the requsted "
"acceleration. See README.md for instructions."
)

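To illustrate what the new behaviour means for calling code: an empty acceleration config now yields `None` instead of raising, so callers can branch rather than catch. The stub below is an editor sketch; only the None-for-empty-config contract of `get_framework()` comes from the hunk above.

```python
# Sketch: branch on get_framework() returning None for an empty config.
class EmptyAccelConfigStub:
    """Illustrative stand-in for an AccelerationFrameworkConfig with nothing set."""

    def is_empty(self) -> bool:
        return True

    def get_framework(self):
        # Mirrors the new early return: nothing to configure -> None.
        return None


framework = EmptyAccelConfigStub().get_framework()
if framework is None:
    print("no acceleration plugins configured; continuing with plain fine-tuning")
```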
@@ -244,7 +257,7 @@ def _descend_and_set(path: List[str], d: Dict):
"to be installed. Please do:\n"
+ "\n".join(
[
"- python -m fms_acceleration install "
"- python -m fms_acceleration.cli install "
f"{AccelerationFrameworkConfig.PACKAGE_PREFIX + x}"
for x in annotate.required_packages
]
18 changes: 1 addition & 17 deletions tuning/config/acceleration_configs/fused_ops_and_kernels.py
@@ -18,20 +18,13 @@
from typing import List

# Local
- from .utils import (
-     EnsureTypes,
-     ensure_nested_dataclasses_initialized,
-     parsable_dataclass,
- )
+ from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass


@parsable_dataclass
@dataclass
class FusedLoraConfig(List):

- # to help the HfArgumentParser arrive at correct types
- __args__ = [EnsureTypes(str, bool)]
-
# load unsloth optimizations for these 4bit base layer weights.
# currently only support "auto_gptq" and "bitsandbytes"
base_layer: str = None
@@ -41,9 +34,6 @@ class FusedLoraConfig(List):

def __post_init__(self):

- # reset for another parse
- self.__args__[0].reset()
-
if self.base_layer is not None and self.base_layer not in {
"auto_gptq",
"bitsandbytes",
@@ -60,9 +50,6 @@ def __post_init__(self):
@dataclass
class FastKernelsConfig(List):

- # to help the HfArgumentParser arrive at correct types
- __args__ = [EnsureTypes(bool, bool, bool)]
-
# fast loss triton kernels
fast_loss: bool = False

@@ -74,9 +61,6 @@ class FastKernelsConfig(List):

def __post_init__(self):

- # reset for another parse
- self.__args__[0].reset()
-
if not self.fast_loss == self.fast_rsm_layernorm == self.fast_rope_embeddings:
raise ValueError(
"fast_loss, fast_rms_layernorm and fast_rope_embedding must be enabled "
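The `__post_init__` check kept by the last hunk enforces an all-or-nothing rule on the fast-kernel flags. Below is a standalone editor sketch of that pattern; the demo class and field names are illustrative, not the repo's `FastKernelsConfig`.

```python
# Standalone illustration of the all-or-nothing validation retained above:
# the three fast-kernel flags must all be True or all be False.
from dataclasses import dataclass


@dataclass
class DemoFastKernels:  # illustrative stand-in, not the repo's FastKernelsConfig
    fast_loss: bool = False
    fast_rms_layernorm: bool = False
    fast_rope_embeddings: bool = False

    def __post_init__(self):
        # Chained equality is True only when all three flags hold the same value.
        if not (self.fast_loss == self.fast_rms_layernorm == self.fast_rope_embeddings):
            raise ValueError("the fast-kernel flags must be enabled together")


DemoFastKernels(fast_loss=True, fast_rms_layernorm=True, fast_rope_embeddings=True)  # ok
# DemoFastKernels(fast_loss=True)  # raises ValueError
```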
12 changes: 1 addition & 11 deletions tuning/config/acceleration_configs/quantized_lora_config.py
@@ -18,11 +18,7 @@
from typing import List

# Local
- from .utils import (
-     EnsureTypes,
-     ensure_nested_dataclasses_initialized,
-     parsable_dataclass,
- )
+ from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass


@parsable_dataclass
@@ -49,9 +45,6 @@ def __post_init__(self):
@dataclass
class BNBQLoraConfig(List):

- # to help the HfArgumentParser arrive at correct types
- __args__ = [EnsureTypes(str, bool)]
-
# type of quantization applied
quant_type: str = "nf4"

@@ -61,9 +54,6 @@ class BNBQLoraConfig(List):

def __post_init__(self):

- # reset for another parse
- self.__args__[0].reset()
-
if self.quant_type not in ["nf4", "fp4"]:
raise ValueError("quant_type can only be either 'nf4' or 'fp4.")

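For context on the removed `EnsureTypes` helper (commented in the old code as a type-coercion aid for `HfArgumentParser`): the parser in `transformers` already coerces command-line strings to the annotated dataclass types. The following is a hedged, standalone sketch; the demo dataclass and its field names are illustrative and are not one of this repo's config classes.

```python
# Hedged sketch: HfArgumentParser coerces CLI strings to annotated types on its
# own. The dataclass here is a stand-in, not the repo's BNBQLoraConfig.
from dataclasses import dataclass

from transformers import HfArgumentParser


@dataclass
class DemoQuantArgs:
    quant_type: str = "nf4"
    double_quant: bool = False


(parsed,) = HfArgumentParser(DemoQuantArgs).parse_args_into_dataclasses(
    args=["--quant_type", "fp4", "--double_quant", "true"]
)
print(parsed.quant_type, parsed.double_quant)  # fp4 True
```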
