fix merge errors and other issues
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed Jun 20, 2024
1 parent 428b4d9 commit 3e2b7c5
Showing 2 changed files with 69 additions and 49 deletions.
47 changes: 27 additions & 20 deletions tests/acceleration/test_acceleration_framework.py
@@ -13,18 +13,18 @@
# limitations under the License.

# Standard
-import copy
from dataclasses import dataclass, replace
from typing import Annotated
from unittest.mock import patch
+import copy
import tempfile

# Third Party
import pytest
import torch

# First Party
-from tests.test_sft_trainer import MODEL_ARGS, DATA_ARGS, TRAIN_ARGS, PEFT_LORA_ARGS
+from tests.test_sft_trainer import DATA_ARGS, MODEL_ARGS, PEFT_LORA_ARGS, TRAIN_ARGS

# Local
from .spying_utils import create_mock_plugin_class_and_spy
@@ -279,7 +279,7 @@ def test_framework_raises_due_to_invalid_arguments(
model_args = replace(model_args, **bad_kwargs)
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

quantized_lora_config = QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig())

# 1. activate the accelerated peft plugin
@@ -294,27 +294,34 @@
)


-if is_fms_accelerate_available(plugins="peft"):
-    acceleration_configs_map = [
-        (
-            QuantizedLoraConfig(bnb_qlora=BNBQLoraConfig()),
-            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-            (
-                "peft.quantization.bitsandbytes",
-                create_mock_plugin_class_and_spy("PluginMock", BNBAccelerationPlugin),
-            ),
-        ),
-        (
-            QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig()),
-            "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
-            (
-                "peft.quantization.auto_gptq",
-                create_mock_plugin_class_and_spy(
-                    "PluginMock", AutoGPTQAccelerationPlugin
-                ),
-            ),
-        ),
-    ]
+acceleration_configs_map = [
+    (
+        QuantizedLoraConfig(bnb_qlora=BNBQLoraConfig()),
+        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+        (
+            "peft.quantization.bitsandbytes",
+            create_mock_plugin_class_and_spy(
+                "PluginMock",
+                BNBAccelerationPlugin
+                if is_fms_accelerate_available(plugins='peft')
+                else object
+            ),
+        ),
+    ),
+    (
+        QuantizedLoraConfig(auto_gptq=AutoGPTQLoraConfig()),
+        "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
+        (
+            "peft.quantization.auto_gptq",
+            create_mock_plugin_class_and_spy(
+                "PluginMock",
+                AutoGPTQAccelerationPlugin
+                if is_fms_accelerate_available(plugins='peft')
+                else object
+            ),
+        ),
+    ),
+]


@pytest.mark.skipif(
@@ -338,7 +345,7 @@ def test_framework_intialized_properly_peft(
model_args.torch_dtype = torch.float16
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
-train_args.save_strategy = 'no'
+train_args.save_strategy = "no"
train_args.fp16 = True

installation_path, (MockedPlugin, spy) = mock_and_spy
@@ -386,7 +393,7 @@ def test_framework_intialized_properly_foak():
model_args.torch_dtype = torch.float16
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
-train_args.save_strategy = 'no'
+train_args.save_strategy = "no"
train_args.fp16 = True

# setup default quantized lora args dataclass
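The reworked acceleration_configs_map above is now defined unconditionally, with each mocked plugin class falling back to object when is_fms_accelerate_available(plugins="peft") is false; this appears to let pytest collect the module (and skip the tests via the @pytest.mark.skipif guards) even when fms-acceleration is not installed. Below is a minimal illustrative sketch of how such a map is typically consumed; the test name and body are hypothetical and not part of this commit, the real consumer is test_framework_intialized_properly_peft shown in the diff.

# Illustrative sketch (not part of this commit): feeding acceleration_configs_map
# to pytest.mark.parametrize. The 3-tuple unpacking mirrors the unpacking used by
# test_framework_intialized_properly_peft above. Assumes the module-level
# acceleration_configs_map defined in this test file.
import pytest

@pytest.mark.parametrize(
    "quantized_lora_config,model_name,mock_and_spy",
    acceleration_configs_map,
)
def test_plugin_mock_wiring(quantized_lora_config, model_name, mock_and_spy):
    installation_path, (MockedPlugin, spy) = mock_and_spy
    # both entries register their mock plugin under the peft.quantization namespace
    assert installation_path.startswith("peft.quantization.")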
71 changes: 42 additions & 29 deletions tuning/sft_trainer.py
@@ -61,17 +61,6 @@
USER_ERROR_EXIT_CODE,
write_termination_log,
)
-from tuning.utils.import_utils import is_fms_accelerate_available
-from tuning.config.acceleration_configs import (
-    AccelerationFrameworkConfig,
-    QuantizedLoraConfig,
-    FusedOpsAndKernelsConfig
-)
-
-if is_fms_accelerate_available():
-    # Third Party
-    from fms_acceleration import AccelerationFramework  # pylint: disable=import-error
-

def train(
model_args: configs.ModelArguments,
@@ -443,7 +432,7 @@ def parse_arguments(parser, json_config=None):
file_logger_config,
aim_config,
quantized_lora_config,
-fusedops_kernels_config,
+fusedops_kernels_config,
) = parser.parse_dict(json_config, allow_extra_keys=True)
peft_method = json_config.get("peft_method")
exp_metadata = json_config.get("exp_metadata")
@@ -458,7 +447,7 @@
file_logger_config,
aim_config,
quantized_lora_config,
-fusedops_kernels_config,
+fusedops_kernels_config,
additional,
_,
) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
@@ -482,7 +471,7 @@
file_logger_config,
aim_config,
quantized_lora_config,
-fusedops_kernels_config,
+fusedops_kernels_config,
exp_metadata,
)

@@ -504,7 +493,7 @@ def main(**kwargs): # pylint: disable=unused-argument
file_logger_config,
aim_config,
quantized_lora_config,
-fusedops_kernels_config,
+fusedops_kernels_config,
exp_metadata,
) = parse_arguments(parser, job_config)
logger.debug(
@@ -521,7 +510,7 @@ def main(**kwargs): # pylint: disable=unused-argument
file_logger_config,
aim_config,
quantized_lora_config,
-fusedops_kernels_config,
+fusedops_kernels_config,
exp_metadata,
)
except Exception as e: # pylint: disable=broad-except
@@ -551,19 +540,43 @@ def main(**kwargs): # pylint: disable=unused-argument
combined_tracker_configs.file_logger_config = file_logger_config
combined_tracker_configs.aim_config = aim_config

-    train(
-        model_args=model_args,
-        data_args=data_args,
-        train_args=training_args,
-        peft_config=tune_config,
-        trainer_controller_args=trainer_controller_args,
-        tracker_configs=combined_tracker_configs,
-        additional_callbacks=None,
-        exp_metadata=metadata,
-        quantized_lora_config=quantized_lora_config,
-        fusedops_kernels_config=fusedops_kernels_config,
-    )
-
+    try:
+        train(
+            model_args=model_args,
+            data_args=data_args,
+            train_args=training_args,
+            peft_config=tune_config,
+            trainer_controller_args=trainer_controller_args,
+            tracker_configs=combined_tracker_configs,
+            additional_callbacks=None,
+            exp_metadata=metadata,
+            quantized_lora_config=quantized_lora_config,
+            fusedops_kernels_config=fusedops_kernels_config,
+        )
+    except (MemoryError, OutOfMemoryError) as e:
+        logger.error(traceback.format_exc())
+        write_termination_log(f"OOM error during training. {e}")
+        sys.exit(INTERNAL_ERROR_EXIT_CODE)
+    except FileNotFoundError as e:
+        logger.error(traceback.format_exc())
+        write_termination_log("Unable to load file: {}".format(e))
+        sys.exit(USER_ERROR_EXIT_CODE)
+    except HFValidationError as e:
+        logger.error(traceback.format_exc())
+        write_termination_log(
+            f"There may be a problem with loading the model. Exception: {e}"
+        )
+        sys.exit(USER_ERROR_EXIT_CODE)
+    except (TypeError, ValueError, EnvironmentError) as e:
+        logger.error(traceback.format_exc())
+        write_termination_log(
+            f"Exception raised during training. This may be a problem with your input: {e}"
+        )
+        sys.exit(USER_ERROR_EXIT_CODE)
+    except Exception as e:  # pylint: disable=broad-except
+        logger.error(traceback.format_exc())
+        write_termination_log(f"Unhandled exception during training: {e}")
+        sys.exit(INTERNAL_ERROR_EXIT_CODE)

if __name__ == "__main__":
fire.Fire(main)
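
The new top-level exception handling in main() maps user-facing failures (missing files, invalid Hugging Face model references, bad input values) to USER_ERROR_EXIT_CODE and everything else, including CUDA out-of-memory, to INTERNAL_ERROR_EXIT_CODE, writing a termination log in each case. The handlers rely on OutOfMemoryError and HFValidationError being imported elsewhere in sft_trainer.py, which is not visible in this diff; a hedged sketch of plausible imports is below, and the exact paths in the repository may differ.

# Assumed imports (not shown in this diff) backing the new except clauses.
from torch.cuda import OutOfMemoryError  # raised by PyTorch on CUDA out-of-memory
from huggingface_hub.utils import HFValidationError  # raised for malformed hub repo ids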
