From 891cef86180eeb2a911f3f00a75574a97e4ebfae Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Thu, 13 Jun 2024 06:40:33 +0000
Subject: [PATCH] fmt + lint

Signed-off-by: Yu Chin Fabian Lim
---
 README.md                                    |  44 +++--
 .../test_acceleration_framework.py           |  41 +++--
 .../config/acceleration_configs/__init__.py  |   4 +-
 .../acceleration_framework_config.py         | 164 ++++++++----------
 .../fused_ops_and_kernels.py                 |  37 ++--
 .../quantized_lora_config.py                 |  27 +--
 tuning/config/acceleration_configs/utils.py  |  19 +-
 tuning/sft_trainer.py                        |  14 +-
 8 files changed, 180 insertions(+), 170 deletions(-)

diff --git a/README.md b/README.md
index ea37775bb..6d8eb575d 100644
--- a/README.md
+++ b/README.md
@@ -336,29 +336,37 @@ tuning/sft_trainer.py \
 `fms-acceleration` is a fuss-free approach to access a curated collection of acceleration plugins that accelerate your `tuning/sft-trainer.py` experience. Accelerations that apply to a variety of use-cases, e.g., PeFT / full-finetuning, are being planned for. As such, the accelerations are grouped into *plugins*; only install the plugins needed for the acceleration of interest. The plugins are housed in the [separate repository found here](https://github.com/foundation-model-stack/fms-acceleration).

-Basic usage includes these steps:
+To access `fms-acceleration` features, the `[fms-accel]` dependency must first be installed:
+  ```
+  $ pip install -e .[fms-accel]
+  ```

-1. Install the `[fms-accel]` dependency:
-   ```
-   $ pip install -e .[fms-accel]
-   ```
+Furthermore, the required `fms-acceleration` plugin must be installed. This is done via the command line utility `fms_acceleration.cli`. To show the available plugins:
+  ```
+  $ python -m fms_acceleration.cli plugins
+  ```
+and to install the `fms_acceleration_peft` plugin:

-   The installs the command line utility `fms_acceleration.cli`, used to install plugins.
-3. `install` the required framework plugins; we install the `fms-acceleration-peft` plugin for GPTQ-LoRA tuning with triton v2 as:
-   ```
-   python -m fms_acceleration.cli install fms_acceleration_peft
-   ```
+  ```
+  $ python -m fms_acceleration.cli install fms_acceleration_peft
+  ```

-5. Run `sft_trainer.py` providing the acceleration configuration and arguments; given the basic flow assumption that we simply re-use the same `sft_trainer.py` arguments as we had without using the `fms_acceleration` package:
-   ```
-   python sft_trainer.py \
-   --acceleration_framework_config_file framework_config.yaml \
-   ... # arguments
-   ```
+If you do not know (or forget) which plugin to install, the framework will remind you:

-  See [this sample configuration for GPTQ-LoRA with triton v2](./fixtures/accelerated-peft-autogptq-sample-configuration.yaml) to be passed into `--acceleration_framework_config_file` above.
+```
+An acceleration feature is requested by specifying the '--auto_gptq' argument, but the this requires acceleration packages to be installed. Please do:
+- python -m fms_acceleration install fms_acceleration_peft
+```
+
+The list of configurations for the various `fms_acceleration` plugins:
+- [quantized_lora_config](./tuning/config/acceleration_configs/quantized_lora_config.py): for quantized 4bit LoRA training
+  - `--auto_gptq`: 4bit GPTQ-LoRA with AutoGPTQ
+  - `--bnb_qlora`: 4bit QLoRA with bitsandbytes
+- [fused_ops_and_kernels](./tuning/config/acceleration_configs/fused_ops_and_kernels.py) (experimental):
+  - `--fused_lora`: fused LoRA for more efficient LoRA training.
+  - `--fast_kernels`: fast cross-entropy, RoPE and RMS layernorm kernels.

-Thats it! 
Activate `TRANSFORMERS_VERBOSITY=info` to see the huggingface trainer printouts and verify that `AccelerationFramework` is activated! +Activate `TRANSFORMERS_VERBOSITY=info` to see the huggingface trainer printouts and verify that `AccelerationFramework` is activated! ``` # this printout will be seen in huggingface trainer logs if acceleration is activated diff --git a/tests/acceleration/test_acceleration_framework.py b/tests/acceleration/test_acceleration_framework.py index 192b8c544..7ff09b3dd 100644 --- a/tests/acceleration/test_acceleration_framework.py +++ b/tests/acceleration/test_acceleration_framework.py @@ -18,7 +18,6 @@ # Third Party import pytest -import yaml # First Party from tests.helpers import causal_lm_train_kwargs @@ -26,20 +25,20 @@ # Local from tuning import sft_trainer -from tuning.utils.import_utils import is_fms_accelerate_available -import tuning.config.configs as config from tuning.config.acceleration_configs import ( - AccelerationFrameworkConfig, QuantizedLoraConfig + AccelerationFrameworkConfig, + QuantizedLoraConfig, ) from tuning.config.acceleration_configs.quantized_lora_config import ( - AutoGPTQLoraConfig, BNBQLoraConfig + AutoGPTQLoraConfig, + BNBQLoraConfig, ) +from tuning.utils.import_utils import is_fms_accelerate_available # pylint: disable=import-error if is_fms_accelerate_available(): # Third Party - from fms_acceleration.framework import KEY_PLUGINS, AccelerationFramework from fms_acceleration.utils.test_utils import build_framework_and_maybe_instantiate if is_fms_accelerate_available(plugins="peft"): @@ -92,7 +91,8 @@ def test_construct_framework_config_with_incorrect_configurations(): "Ensure that framework configuration cannot have empty body" with pytest.raises( - ValueError, match="AccelerationFrameworkConfig construction requires at least one dataclass" + ValueError, + match="AccelerationFrameworkConfig construction requires at least one dataclass", ): AccelerationFrameworkConfig.from_dataclasses() @@ -102,14 +102,18 @@ def test_construct_framework_config_with_incorrect_configurations(): ): AutoGPTQLoraConfig(from_quantized=False) - # test an invalid activation of two standalone configs. + # test an invalid activation of two standalone configs. quantized_lora_config = QuantizedLoraConfig( auto_gptq=AutoGPTQLoraConfig(), bnb_qlora=BNBQLoraConfig() ) with pytest.raises( - ValueError, match="Configuration path 'peft.quantization' already has one standalone config." + ValueError, + match="Configuration path 'peft.quantization' already has one standalone config.", ): - AccelerationFrameworkConfig.from_dataclasses(quantized_lora_config).get_framework() + AccelerationFrameworkConfig.from_dataclasses( + quantized_lora_config + ).get_framework() + @pytest.mark.skipif( not is_fms_accelerate_available(plugins="peft"), @@ -119,7 +123,9 @@ def test_construct_framework_with_auto_gptq_peft(): "Ensure that framework object is correctly configured." 
quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig()) - acceleration_config = AccelerationFrameworkConfig.from_dataclasses(quantized_lora_config) + acceleration_config = AccelerationFrameworkConfig.from_dataclasses( + quantized_lora_config + ) # for this test we skip the require package check as second order package # dependencies of accelerated_peft is not required @@ -133,6 +139,7 @@ def test_construct_framework_with_auto_gptq_peft(): # the configuration file should successfully activate the plugin assert len(framework.active_plugins) == 1 + @pytest.mark.skipif( not is_fms_accelerate_available(), reason="Only runs if fms-accelerate is installed", @@ -156,20 +163,22 @@ def test_framework_not_installed_or_initalized_properly(): # patch is_fms_accelerate_available to return False inside sft_trainer # to simulate fms_acceleration not installed with patch( - "tuning.config.acceleration_configs.acceleration_framework_config.is_fms_accelerate_available", return_value=False + "tuning.config.acceleration_configs.acceleration_framework_config." + "is_fms_accelerate_available", + return_value=False, ): with pytest.raises( - ValueError, - match="No acceleration framework package found." + ValueError, match="No acceleration framework package found." ): sft_trainer.train( model_args, data_args, training_args, tune_config, - quantized_lora_config=quantized_lora_config + quantized_lora_config=quantized_lora_config, ) + @pytest.mark.skipif( not is_fms_accelerate_available(plugins="peft"), reason="Only runs if fms-accelerate is installed along with accelerated-peft plugin", @@ -206,7 +215,7 @@ def test_framework_intialized_properly(): training_args, tune_config, # acceleration_framework_args=framework_args, - quantized_lora_config=quantized_lora_config + quantized_lora_config=quantized_lora_config, ) # spy to ensure that the plugin functions were called. diff --git a/tuning/config/acceleration_configs/__init__.py b/tuning/config/acceleration_configs/__init__.py index f361ae530..f971e2108 100644 --- a/tuning/config/acceleration_configs/__init__.py +++ b/tuning/config/acceleration_configs/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Local from .acceleration_framework_config import AccelerationFrameworkConfig - +from .fused_ops_and_kernels import FusedOpsAndKernelsConfig from .quantized_lora_config import QuantizedLoraConfig -from .fused_ops_and_kernels import FusedOpsAndKernelsConfig \ No newline at end of file diff --git a/tuning/config/acceleration_configs/acceleration_framework_config.py b/tuning/config/acceleration_configs/acceleration_framework_config.py index 0927beca8..32e524e6a 100644 --- a/tuning/config/acceleration_configs/acceleration_framework_config.py +++ b/tuning/config/acceleration_configs/acceleration_framework_config.py @@ -12,113 +12,92 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dataclasses import dataclass, fields, asdict, is_dataclass -from typing import Annotated, List, Dict, Type -from tuning.utils.import_utils import is_fms_accelerate_available -from .quantized_lora_config import AutoGPTQLoraConfig, BNBQLoraConfig -from .fused_ops_and_kernels import FusedLoraConfig, FastKernelsConfig -import yaml +# Standard +from dataclasses import asdict, dataclass, fields, is_dataclass +from typing import Annotated, Dict, List, Type import warnings +# Third Party +import yaml + +# Local +from .fused_ops_and_kernels import FastKernelsConfig, FusedLoraConfig +from .quantized_lora_config import AutoGPTQLoraConfig, BNBQLoraConfig +from tuning.utils.import_utils import is_fms_accelerate_available + if is_fms_accelerate_available(): # Third Party from fms_acceleration import AccelerationFramework # pylint: disable=import-error - from fms_acceleration.framework import KEY_PLUGINS - -# DESIGN OF FMS CONFIGS: -# - FMS will have differnt configs (probably one (or more) / plugin). -# - e,g. QuantizedLoraConfig will be for the accelerated_peft plugin -# - e.g, FusedOpsAndKernelsConfig will be for fused_ops_and_kernels plugin -# - FMS users will understand that to use thse configs, they will need -# to install the plugin that corresponds to that config -# - each FMS config will nest multiple dataclasses in a single level -# - typically each nested dataclass corresponds to one use case -# - e.g. for the QuantizedLoraConfig, two use cases of auto_gptq and bnb_qlora - -# - the HF dataclass argument parser will create position arguments from the -# FMS config -# - in the usal way, the keys of the FMS config will correspond to a --key -# - then the use case dataclass will be passed its attributes by position -# - hence, this is the reason why we enforce the FMS config to be -# single-level nested dataclasses. - -# DESIGN OF ACCELERATION CONFIGS -# - An ACCELERATION CONFIG is a monolothic config passed to AccelerationFramework -# - it is NOT meant to be user facing. Users will only configure -# use case dataclasses within. -# - however, uses can consult the annotations (see below) to understand -# which use-case config can be active at the same time. -# - it is a collection of use-case dataclasses (see above) -# - every use-case dataclass is annotated with a header -# - any two use-case dataclasses that are annotated with the -# same header, cannot be active at the same time. -# - An Acceleration Config is valid only if it does not have any -# use-case dataclass that violates these rules. 
+ from fms_acceleration.framework import KEY_PLUGINS # pylint: disable=import-error # these are optional annotations that describe different behavior @dataclass class ConfigAnnotation: # AccelerationFramework configuration path - path: str + path: str # if omitted, will take the field name - key: str = None + key: str = None # only one that has single=True may exist under its path - # - this is used to indicate conflicting configurations - # - we do not allow two configurations that load the model to be + # - this is used to indicate conflicting configurations + # - we do not allow two configurations that load the model to be # activated at the same time - standalone: bool = False + standalone: bool = False # set to true to throw a user warning experimental: bool = False - # set to indicate what acceeleration packages are needed + # set to indicate what acceeleration packages are needed required_packages: List[str] = None + @dataclass class AccelerationFrameworkConfig: "Dataclass that manages configuration of AccelerationFramework" - PACKAGE_PREFIX = 'fms_acceleration_' + PACKAGE_PREFIX = "fms_acceleration_" # each field will a single-level use case dataclass auto_gptq: Annotated[ - AutoGPTQLoraConfig, ConfigAnnotation( - path="peft.quantization", standalone=True, - required_packages=['peft'] - ) + AutoGPTQLoraConfig, + ConfigAnnotation( + path="peft.quantization", standalone=True, required_packages=["peft"] + ), ] = None bitsandbytes: Annotated[ - BNBQLoraConfig, ConfigAnnotation( - path="peft.quantization", standalone=True, - required_packages=['peft'] - ) + BNBQLoraConfig, + ConfigAnnotation( + path="peft.quantization", standalone=True, required_packages=["peft"] + ), ] = None - + fused_lora: Annotated[ - FusedLoraConfig, ConfigAnnotation( - path="peft.quantization", key='fused_ops_and_kernels', + FusedLoraConfig, + ConfigAnnotation( + path="peft.quantization", + key="fused_ops_and_kernels", experimental=True, - required_packages=['foak'] - ) + required_packages=["foak"], + ), ] = None fast_kernels: Annotated[ - FastKernelsConfig, ConfigAnnotation( - path="peft.quantization", key='fused_ops_and_kernels', + FastKernelsConfig, + ConfigAnnotation( + path="peft.quantization", + key="fused_ops_and_kernels", experimental=True, - required_packages=['foak'] - ) + required_packages=["foak"], + ), ] = None @staticmethod def from_dataclasses(*dataclasses: Type): "Convert one or many FMS config dataclasses to a monolithic AccelerationConfig" - # Assumption: AccelerationFrameworkConfig only has fields that are # single level dataclasses # Assumption: dataclasses is a list of nested dataclasses @@ -138,19 +117,21 @@ def from_dataclasses(*dataclasses: Type): # make sure that it every field is a dataclass for fi in fields(dc): - attr = getattr(dc, fi.name) + attr = getattr(dc, fi.name) if attr is None: - break # skip the None attributes + break # skip the None attributes - if not is_dataclass(attr): - raise ValueError(f"field '{fi.name}' is specified but not a dataclass") + if not is_dataclass(attr): + raise ValueError( + f"field '{fi.name}' is specified but not a dataclass" + ) - # NOTE: should we also check that these are non-nested + # NOTE: should we also check that these are non-nested # dataclasses? 
nested_dataclasses.append(attr) config = AccelerationFrameworkConfig() - rem_fields = {fi.name: fi for fi in fields(config)} # these need to be parsed + rem_fields = {fi.name: fi for fi in fields(config)} # these need to be parsed # process the dataclasses that were nested # by assumption these are non-nested dataclasses @@ -172,7 +153,7 @@ def from_dataclasses(*dataclasses: Type): # assign the dataclass setattr(config, fi.name, dc) - del rem_fields[fi.name] # remove the field + del rem_fields[fi.name] # remove the field return config @@ -182,8 +163,12 @@ def get_framework(self): # to be eventually be made to be passed as a dict to Acceleration # Framework - from tempfile import NamedTemporaryFile - with NamedTemporaryFile('w') as f: + # Standard + from tempfile import ( # pylint: disable=import-outside-toplevel + NamedTemporaryFile, + ) + + with NamedTemporaryFile("w") as f: self.to_yaml(f.name) return AccelerationFramework(f.name) else: @@ -208,11 +193,11 @@ def _descend_and_set(path: List[str], d: Dict): r = configuration_contents for p in path[:-1]: if p not in r: - r[p] = {} # new branch + r[p] = {} # new branch r = r[p] p = path[-1] - r[p] = {**r.get(p, {}), **d} # merge dict if exists + r[p] = {**r.get(p, {}), **d} # merge dict if exists # parse each field already_set = set() @@ -222,19 +207,20 @@ def _descend_and_set(path: List[str], d: Dict): # this is the documented way to get annotations # https://docs.python.org/3/library/typing.html#typing.Annotated annotate: ConfigAnnotation - annotate, = fi.type.__metadata__ - prefix_path = tuple(annotate.path.split('.')) - if ( - annotate.standalone and - prefix_path in already_set - ): - raise ValueError(f"Configuration path '{'.'.join(prefix_path)}' already has one standalone config.") + (annotate,) = fi.type.__metadata__ + prefix_path = tuple(annotate.path.split(".")) + if annotate.standalone and prefix_path in already_set: + raise ValueError( + f"Configuration path '{'.'.join(prefix_path)}' " + "already has one standalone config." + ) if annotate.experimental: warnings.warn( "An experimental acceleration feature is requested by specifying the " f"'--{fi.name}' argument. Please note this feature may not support certain " - "edge cases at this juncture. When the feature matures this message will be turned off." + "edge cases at this juncture. When the feature matures this " + "message will be turned off." ) if not all( @@ -243,12 +229,14 @@ def _descend_and_set(path: List[str], d: Dict): raise ValueError( "An acceleration feature is requested by specifying the " f"'--{fi.name}' argument, but the this requires acceleration packages " - "to be installed. Please do:\n" + - "\n".join([ - '- python -m fms_acceleration install ' - f'{AccelerationFrameworkConfig.PACKAGE_PREFIX + x}' - for x in annotate.required_packages - ]) + "to be installed. 
Please do:\n" + + "\n".join( + [ + "- python -m fms_acceleration install " + f"{AccelerationFrameworkConfig.PACKAGE_PREFIX + x}" + for x in annotate.required_packages + ] + ) ) key = annotate.key if annotate.key is not None else fi.name @@ -261,5 +249,5 @@ def _descend_and_set(path: List[str], d: Dict): def to_yaml(self, filename: str): "convert a valid AccelerationConfig dataclass into a yaml" configuration_contents = self.to_dict() - with open(filename, "w") as f: + with open(filename, "w", encoding="utf-8") as f: yaml.dump({KEY_PLUGINS: configuration_contents}, f) diff --git a/tuning/config/acceleration_configs/fused_ops_and_kernels.py b/tuning/config/acceleration_configs/fused_ops_and_kernels.py index 9c0747ed9..3142da184 100644 --- a/tuning/config/acceleration_configs/fused_ops_and_kernels.py +++ b/tuning/config/acceleration_configs/fused_ops_and_kernels.py @@ -13,8 +13,11 @@ # limitations under the License. +# Standard from dataclasses import dataclass from typing import List + +# Local from .utils import EnsureTypes, ensure_nested_dataclasses_initialized @@ -36,19 +39,18 @@ def __post_init__(self): # reset for another parse self.__args__[0].reset() - if ( - self.base_layer is not None and - self.base_layer not in {'auto_gptq', 'bitsandbytes'} - ): - raise ValueError( - f"base_layer set to invalid value '{self.base_layer}'" - ) + if self.base_layer is not None and self.base_layer not in { + "auto_gptq", + "bitsandbytes", + }: + raise ValueError(f"base_layer set to invalid value '{self.base_layer}'") if self.base_layer is not None and not self.fused_lora: raise ValueError( f"base_layer set to '{self.base_layer}' so fused_lora must be set to True" ) + @dataclass class FastKernelsConfig(List): @@ -69,14 +71,13 @@ def __post_init__(self): # reset for another parse self.__args__[0].reset() - if not ( - self.fast_loss == self.fast_rsm_layernorm == self.fast_rope_embeddings - ): + if not self.fast_loss == self.fast_rsm_layernorm == self.fast_rope_embeddings: raise ValueError( - 'fast_loss, fast_rms_layernorm and fast_rope_embedding must be enabled ' - 'together. This restriction may be relaxed in the future.' + "fast_loss, fast_rms_layernorm and fast_rope_embedding must be enabled " + "together. This restriction may be relaxed in the future." ) + @dataclass class FusedOpsAndKernelsConfig: @@ -87,15 +88,13 @@ class FusedOpsAndKernelsConfig: fast_kernels: FastKernelsConfig = None def __post_init__(self): - if ( - (self.fused_lora is not None and self.fast_kernels is None) - or - (self.fused_lora is None and self.fast_kernels is not None) + if (self.fused_lora is not None and self.fast_kernels is None) or ( + self.fused_lora is None and self.fast_kernels is not None ): raise ValueError( - 'fused lora and fast_kernels must be used together. ' - 'This restriction may be relaxed in the future.' + "fused lora and fast_kernels must be used together. " + "This restriction may be relaxed in the future." ) # ensure nested dataclasses initialized - ensure_nested_dataclasses_initialized(self) \ No newline at end of file + ensure_nested_dataclasses_initialized(self) diff --git a/tuning/config/acceleration_configs/quantized_lora_config.py b/tuning/config/acceleration_configs/quantized_lora_config.py index 690563c89..08025433b 100644 --- a/tuning/config/acceleration_configs/quantized_lora_config.py +++ b/tuning/config/acceleration_configs/quantized_lora_config.py @@ -13,19 +13,22 @@ # limitations under the License. 
+# Standard from dataclasses import dataclass from typing import List -from .utils import EnsureTypes, ensure_nested_dataclasses_initialized + +# Local +from .utils import EnsureTypes, ensure_nested_dataclasses_initialized @dataclass class AutoGPTQLoraConfig(List): - + # to help the HfArgumentParser arrive at correct types __args__ = [EnsureTypes(str, bool)] # auto_gptq supports various kernels, to select the kernel to use. - kernel: str = 'triton_v2' + kernel: str = "triton_v2" # allow auto_gptq to quantize a model before training commences. # NOTE: currently this is not allowed. @@ -35,13 +38,14 @@ def __post_init__(self): # reset for another parse self.__args__[0].reset() - - if self.kernel != 'triton_v2': + + if self.kernel != "triton_v2": raise ValueError("only 'triton_v2' kernel currently supported.") if not self.from_quantized: raise ValueError("only 'from_quantized' == True currently supported.") + @dataclass class BNBQLoraConfig(List): @@ -49,9 +53,9 @@ class BNBQLoraConfig(List): __args__ = [EnsureTypes(str, bool)] # type of quantization applied - quant_type: str = 'nf4' + quant_type: str = "nf4" - # if we only want to quantize the base layer, and defer to the + # if we only want to quantize the base layer, and defer to the # huggingface to prepare the peft (i.e. lora) model no_peft_model: bool = False @@ -60,9 +64,10 @@ def __post_init__(self): # reset for another parse self.__args__[0].reset() - if self.quant_type not in ['nf4', 'fp4']: + if self.quant_type not in ["nf4", "fp4"]: raise ValueError("quant_type can only be either 'nf4' or 'fp4.") + @dataclass class QuantizedLoraConfig: @@ -74,7 +79,7 @@ class QuantizedLoraConfig: def __post_init__(self): if self.auto_gptq is None and self.bnb_qlora is None: - raise ValueError('at least one quantized config has to be specified.') - + raise ValueError("at least one quantized config has to be specified.") + # ensure nested dataclasses initialized - ensure_nested_dataclasses_initialized(self) \ No newline at end of file + ensure_nested_dataclasses_initialized(self) diff --git a/tuning/config/acceleration_configs/utils.py b/tuning/config/acceleration_configs/utils.py index aca618994..12d79508c 100644 --- a/tuning/config/acceleration_configs/utils.py +++ b/tuning/config/acceleration_configs/utils.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Type, Dict, get_type_hints - +# Standard from dataclasses import fields -from transformers.hf_argparser import string_to_bool, DataClass +from typing import Dict, Type, get_type_hints + +# Third Party +from transformers.hf_argparser import DataClass, string_to_bool + def ensure_nested_dataclasses_initialized(dataclass: DataClass): type_hints: Dict[str, type] = get_type_hints(dataclass) @@ -26,11 +29,11 @@ def ensure_nested_dataclasses_initialized(dataclass: DataClass): values = nested_type(*values) setattr(dataclass, f.name, values) -class EnsureTypes: +class EnsureTypes: def __init__(self, *types: Type): - map = {bool: string_to_bool} - self.types = [map.get(t, t) for t in types] + _map = {bool: string_to_bool} + self.types = [_map.get(t, t) for t in types] self.reset() def reset(self): @@ -38,9 +41,7 @@ def reset(self): def __call__(self, val): if self.cnt >= len(self.types): - raise ValueError( - "EnsureTypes require 'reset' to be called to be re-used." 
- ) + raise ValueError("EnsureTypes require 'reset' to be called to be re-used.") t = self.types[self.cnt] self.cnt += 1 diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 77ad2dac6..c0f642f0f 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -36,6 +36,11 @@ # Local from tuning.config import configs, peft_config +from tuning.config.acceleration_configs import ( + AccelerationFrameworkConfig, + FusedOpsAndKernelsConfig, + QuantizedLoraConfig, +) from tuning.config.tracker_configs import ( AimConfig, FileLoggingTrackerConfig, @@ -46,11 +51,6 @@ from tuning.trainercontroller import TrainerControllerCallback from tuning.utils.config_utils import get_hf_peft_config from tuning.utils.data_type_utils import get_torch_dtype -from tuning.config.acceleration_configs import ( - AccelerationFrameworkConfig, - QuantizedLoraConfig, - FusedOpsAndKernelsConfig -) def train( @@ -325,7 +325,7 @@ def main(**kwargs): # pylint: disable=unused-argument FileLoggingTrackerConfig, AimConfig, QuantizedLoraConfig, - FusedOpsAndKernelsConfig + FusedOpsAndKernelsConfig, ) ) parser.add_argument( @@ -351,7 +351,7 @@ def main(**kwargs): # pylint: disable=unused-argument file_logger_config, aim_config, quantized_lora_config, - fusedops_kernels_config, + fusedops_kernels_config, additional, _, ) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
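For illustration, a minimal sketch (not part of the commit itself) of how the dataclasses touched by this patch fit together; it mirrors `tests/acceleration/test_acceleration_framework.py` above and assumes the `fms-acceleration` package and the `fms_acceleration_peft` plugin, together with their own dependencies, are installed:

```python
# Sketch only: compose the user-facing QuantizedLoraConfig into the monolithic
# AccelerationFrameworkConfig, the same way the tests in this patch do.
from tuning.config.acceleration_configs import (
    AccelerationFrameworkConfig,
    QuantizedLoraConfig,
)
from tuning.config.acceleration_configs.quantized_lora_config import AutoGPTQLoraConfig

# GPTQ-LoRA via AutoGPTQ; the defaults (kernel="triton_v2", from_quantized=True)
# are currently the only values the dataclass accepts.
quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig())

acceleration_config = AccelerationFrameworkConfig.from_dataclasses(
    quantized_lora_config
)

# get_framework() dumps the merged configuration to a temporary YAML and builds
# fms_acceleration.AccelerationFramework from it; it raises ValueError
# ("No acceleration framework package found.") if fms-acceleration is absent.
framework = acceleration_config.get_framework()
assert len(framework.active_plugins) == 1
```

In the basic flow these dataclasses are not built by hand: `sft_trainer.py` parses them from the `--auto_gptq`, `--bnb_qlora`, `--fused_lora` and `--fast_kernels` arguments via `HfArgumentParser` and hands them to `train()` (e.g. `quantized_lora_config=quantized_lora_config`, as in `test_framework_intialized_properly` above). Note that the `auto_gptq` and `bnb_qlora` options both map to standalone configs under the `peft.quantization` path, so requesting both in a single run raises a `ValueError`.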