From a9bb7a3d5c8e2dc4545e9c9b7c6255a8b3fa1ef5 Mon Sep 17 00:00:00 2001 From: YeZhengMao Date: Fri, 28 Jun 2024 15:48:45 +0800 Subject: [PATCH] [improvement] add code formatter - black and isort, static typing (#227) checker mypy. --- .github/workflows/pr-clean-code-test.yml | 11 +- .github/workflows/pre-commit | 8 +- README.md | 8 +- mlora/cli/__init__.py | 8 +- mlora/cli/adapter.py | 83 ++--- mlora/cli/dataset.py | 75 ++-- mlora/cli/dispatcher.py | 7 +- mlora/cli/file.py | 61 ++-- mlora/cli/task.py | 42 +-- mlora/config/__init__.py | 16 +- mlora/config/adapter.py | 63 ++-- mlora/config/config.py | 4 +- mlora/config/dataset.py | 10 +- mlora/config/dispatcher.py | 6 +- mlora/config/lr_scheduler.py | 27 +- mlora/config/mlora.py | 33 +- mlora/config/optimizer.py | 33 +- mlora/config/task.py | 87 +++-- mlora/evaluator/evaluator.py | 13 - mlora/evaluator/evaluator_factory.py | 23 -- mlora/evaluator/mmlu_evaluator.py | 188 ---------- mlora/executor/__init__.py | 4 +- mlora/executor/context/__init__.py | 15 +- mlora/executor/context/context.py | 42 +-- mlora/executor/context/inference.py | 14 +- mlora/executor/context/lora.py | 79 +++-- mlora/executor/context/loraplus.py | 34 +- mlora/executor/context/train.py | 53 ++- mlora/executor/dispatcher/__init__.py | 13 +- .../executor/dispatcher/backend_dispatcher.py | 7 +- mlora/executor/dispatcher/dispatcher.py | 68 ++-- .../dispatcher/pipeline_dispatcher.py | 55 --- mlora/executor/executor.py | 50 ++- mlora/executor/task/__init__.py | 19 +- mlora/executor/task/cpo_task.py | 62 ++-- mlora/executor/task/dpo_task.py | 108 +++--- mlora/executor/task/task.py | 80 ++--- mlora/executor/task/train_task.py | 86 +++-- mlora/model/args.py | 140 +++++--- mlora/model/checkpoint/checkpoint.py | 2 +- mlora/model/llm/__init__.py | 7 +- mlora/model/llm/model_llama.py | 162 +++++---- mlora/model/llm/model_llm.py | 38 +- mlora/model/modules/__init__.py | 12 +- mlora/model/modules/adapter.py | 13 +- mlora/model/modules/attention.py | 97 ++--- mlora/model/modules/decoder.py | 39 +- mlora/model/modules/embedding.py | 3 +- mlora/model/modules/linear.py | 37 +- mlora/model/modules/lora.py | 128 ++++--- mlora/model/modules/mlp.py | 29 +- mlora/model/modules/output_layer.py | 13 +- mlora/model/tokenizer/__init__.py | 4 +- mlora/model/tokenizer/tokenizer.py | 10 +- mlora/pipeline/function.py | 72 ---- mlora/pipeline/messages.py | 24 -- mlora/pipeline/pipe.py | 332 ------------------ mlora/pipeline/queue.py | 99 ------ mlora/pipeline/stream.py | 16 - mlora/pipeline/transport.py | 179 ---------- mlora/profiler/__init__.py | 11 +- mlora/profiler/profiler.py | 63 ++-- mlora/profiler/traceviz.py | 80 ++--- mlora/prompter/__init__.py | 7 +- mlora/prompter/instruction_data_prompter.py | 5 +- mlora/prompter/preference_data_prompter.py | 11 +- mlora/prompter/prompter.py | 15 +- mlora/server/__init__.py | 36 +- mlora/server/adapter.py | 5 +- mlora/server/dataset.py | 26 +- mlora/server/file.py | 25 +- mlora/server/storage.py | 13 +- mlora/server/task.py | 15 +- mlora/utils/__init__.py | 9 +- mlora/utils/cmd.py | 77 ++-- mlora/utils/loader.py | 36 +- mlora/utils/setup.py | 28 +- pyproject.toml | 2 +- tests/lora_op_test.py | 13 +- 79 files changed, 1401 insertions(+), 2127 deletions(-) delete mode 100644 mlora/evaluator/evaluator.py delete mode 100644 mlora/evaluator/evaluator_factory.py delete mode 100644 mlora/evaluator/mmlu_evaluator.py delete mode 100644 mlora/executor/dispatcher/pipeline_dispatcher.py delete mode 100644 mlora/pipeline/function.py delete mode 100644 
mlora/pipeline/messages.py delete mode 100644 mlora/pipeline/pipe.py delete mode 100644 mlora/pipeline/queue.py delete mode 100644 mlora/pipeline/stream.py delete mode 100644 mlora/pipeline/transport.py diff --git a/.github/workflows/pr-clean-code-test.yml b/.github/workflows/pr-clean-code-test.yml index f8e159d1..64c37dd6 100644 --- a/.github/workflows/pr-clean-code-test.yml +++ b/.github/workflows/pr-clean-code-test.yml @@ -32,7 +32,16 @@ jobs: lizard -l python ./mlora -C 12 - name: Lint with flake8 run: | - flake8 . --count --show-source --statistics --max-line-length=127 --max-complexity 15 --ignore=E722,W504 + flake8 ./mlora --count --show-source --statistics --max-line-length=88 --max-complexity 15 --ignore=E203,W503,E704 + - name: Lint with black + run: | + black --check ./mlora + - name: Lint with isort + run: | + isort ./mlora --check --profile black + - name: Static code check with mypy + run: | + mypy ./mlora --ignore-missing-imports --non-interactive --install-types --check-untyped-defs - name: Test with pytest run: | pytest diff --git a/.github/workflows/pre-commit b/.github/workflows/pre-commit index 28384a28..49f04528 100755 --- a/.github/workflows/pre-commit +++ b/.github/workflows/pre-commit @@ -2,6 +2,12 @@ lizard -l python ./mlora -C 12 -flake8 . --count --show-source --statistics --max-line-length=127 --max-complexity 15 --ignore=E722,W504 +black --check ./mlora + +isort ./mlora --check --profile black + +flake8 ./mlora --count --show-source --statistics --max-line-length=88 --max-complexity 15 --ignore=E203,W503,E704 + +mypy ./mlora --ignore-missing-imports --non-interactive --install-types --check-untyped-defs pytest diff --git a/README.md b/README.md index 5fe44451..5a8bab72 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ Firstly, you should clone this repository and install dependencies: # Clone Repository git clone https://github.com/TUDB-Labs/mLoRA cd mLoRA -# Install requirements +# Install requirements need the Python >= 3.12 pip install . ``` @@ -123,8 +123,14 @@ Submit a pull request with a detailed explanation of your changes. You can use the pre-commit to check your code. ```bash +# Install requirements +pip install .[ci_test] ln -s ../../.github/workflows/pre-commit .git/hooks/pre-commit ``` +Or just call the script to check your code +```bash +.github/workflows/pre-commit +``` ## Citation Please cite the repo if you use the code in this repo. 
diff --git a/mlora/cli/__init__.py b/mlora/cli/__init__.py index 18a6ee19..c50eaec2 100644 --- a/mlora/cli/__init__.py +++ b/mlora/cli/__init__.py @@ -1,8 +1,8 @@ -from .setting import G_HOST, G_PORT +from .adapter import do_adapter, help_adapter +from .dataset import do_dataset, help_dataset from .dispatcher import do_dispatcher, help_dispatcher from .file import do_file, help_file -from .dataset import do_dataset, help_dataset -from .adapter import do_adapter, help_adapter +from .setting import G_HOST, G_PORT from .task import do_task, help_task __all__ = [ @@ -17,5 +17,5 @@ "help_adapter", "do_adapter", "help_task", - "do_task" + "do_task", ] diff --git a/mlora/cli/adapter.py b/mlora/cli/adapter.py index 5e50f956..2b663b88 100644 --- a/mlora/cli/adapter.py +++ b/mlora/cli/adapter.py @@ -1,19 +1,19 @@ import json +from typing import Any, Dict + import requests -from InquirerPy import inquirer -from InquirerPy import validator +from InquirerPy import inquirer, validator from InquirerPy.base import Choice from rich import print -from rich.table import Table from rich.box import ASCII -from typing import Dict +from rich.table import Table from .setting import url def list_adapter(obj): ret = requests.get(url() + "/adapter") - ret = json.loads(ret.text) + ret_items = json.loads(ret.text) table = Table(show_header=True, show_lines=True, box=ASCII) table.add_column("name", justify="center") @@ -23,63 +23,58 @@ def list_adapter(obj): obj.ret_ = [] - for item in ret: - item = json.loads(item) + for ret_item in ret_items: + item = json.loads(ret_item) table.add_row(item["name"], item["type"], item["path"], item["state"]) obj.ret_.append(item["name"]) obj.pret_ = table -def adapter_type_set(adapter_conf: Dict[str, any]): +def adapter_type_set(adapter_conf: Dict[str, Any]): adapter_type = inquirer.select( - message="type:", choices=["lora", "loraplus"]).execute() + message="type:", choices=["lora", "loraplus"] + ).execute() adapter_conf["type"] = adapter_type if adapter_type == "loraplus": lr_ratio = inquirer.number( - message="lr_ratio:", - float_allowed=True, - default=8.0, - replace_mode=True + message="lr_ratio:", float_allowed=True, default=8.0, replace_mode=True ).execute() adapter_conf["lr_ratio"] = lr_ratio return adapter_conf -def adapter_optimizer_set(adapter_conf: Dict[str, any]): +def adapter_optimizer_set(adapter_conf: Dict[str, Any]): optimizer = inquirer.select( - message="optimizer:", choices=["adamw", "sgd"]).execute() + message="optimizer:", choices=["adamw", "sgd"] + ).execute() adapter_conf["optimizer"] = optimizer lr = inquirer.number( - message="learning rate:", - float_allowed=True, - default=3e-4, - replace_mode=True + message="learning rate:", float_allowed=True, default=3e-4, replace_mode=True ).execute() adapter_conf["lr"] = lr if optimizer == "sgd": momentum = inquirer.number( - message="momentum:", - float_allowed=True, - default=0.0, - replace_mode=True + message="momentum:", float_allowed=True, default=0.0, replace_mode=True ).execute() adapter_conf["momentum"] = momentum return adapter_conf -def adapter_lr_scheduler_set(adapter_conf: Dict[str, any]): +def adapter_lr_scheduler_set(adapter_conf: Dict[str, Any]): need_lr_scheduler = inquirer.confirm( - message="Need learning rate scheduler:", default=False).execute() + message="Need learning rate scheduler:", default=False + ).execute() if not need_lr_scheduler: return adapter_conf lr_scheduler_type = inquirer.select( - message="optimizer:", choices=["cosine"]).execute() + message="optimizer:", choices=["cosine"] + 
).execute() adapter_conf["lrscheduler"] = lr_scheduler_type if lr_scheduler_type == "cosine": @@ -101,24 +96,15 @@ def adapter_lr_scheduler_set(adapter_conf: Dict[str, any]): return adapter_conf -def adapter_set(adapter_conf: Dict[str, any]): - r = inquirer.number( - message="rank:", - default=32 - ).execute() +def adapter_set(adapter_conf: Dict[str, Any]): + r = inquirer.number(message="rank:", default=32).execute() adapter_conf["r"] = r - alpha = inquirer.number( - message="alpha:", - default=64 - ).execute() + alpha = inquirer.number(message="alpha:", default=64).execute() adapter_conf["alpha"] = alpha dropout = inquirer.number( - message="dropout:", - float_allowed=True, - replace_mode=True, - default=0.05 + message="dropout:", float_allowed=True, replace_mode=True, default=0.05 ).execute() adapter_conf["dropout"] = dropout @@ -133,13 +119,15 @@ def adapter_set(adapter_conf: Dict[str, any]): } target_modules = inquirer.checkbox( message="target_modules:", - choices=[Choice("q_proj", enabled=True), - Choice("k_proj", enabled=True), - Choice("v_proj", enabled=True), - Choice("o_proj", enabled=True), - Choice("gate_proj", enabled=False), - Choice("down_proj", enabled=False), - Choice("up_proj", enabled=False)] + choices=[ + Choice("q_proj", enabled=True), + Choice("k_proj", enabled=True), + Choice("v_proj", enabled=True), + Choice("o_proj", enabled=True), + Choice("gate_proj", enabled=False), + Choice("down_proj", enabled=False), + Choice("up_proj", enabled=False), + ], ).execute() for target in target_modules: adapter_conf["target_modules"][target] = True @@ -152,7 +140,8 @@ def create_adapter(): name = inquirer.text( message="name:", - validate=validator.EmptyInputValidator("Input should not be empty")).execute() + validate=validator.EmptyInputValidator("Input should not be empty"), + ).execute() adapter_conf["name"] = name adapter_conf = adapter_type_set(adapter_conf) diff --git a/mlora/cli/dataset.py b/mlora/cli/dataset.py index 108d8740..09fe3096 100644 --- a/mlora/cli/dataset.py +++ b/mlora/cli/dataset.py @@ -1,17 +1,18 @@ import json + import requests from InquirerPy import inquirer, separator from rich import print -from rich.table import Table from rich.box import ASCII +from rich.table import Table -from .setting import url from .file import list_file +from .setting import url def list_dataset(obj): ret = requests.get(url() + "/dataset") - ret = json.loads(ret.text) + ret_items = json.loads(ret.text) table = Table(show_header=True, show_lines=True, box=ASCII) table.add_column("name", justify="center") @@ -22,21 +23,22 @@ def list_dataset(obj): obj.ret_ = [] - for item in ret: - item = json.loads(item) - table.add_row(item["name"], - item["data_name"], - item["prompt_name"], - item["prompt_type"], - item["preprocess"]) + for ret_item in ret_items: + item = json.loads(ret_item) + table.add_row( + item["name"], + item["data_name"], + item["prompt_name"], + item["prompt_type"], + item["preprocess"], + ) obj.ret_.append(item["name"]) obj.pret_ = table def create_dataset(obj): - name = inquirer.text( - message="name:").execute() + name = inquirer.text(message="name:").execute() list_file(obj, "data") all_train_data = [item["name"] for item in obj.ret_] @@ -51,26 +53,33 @@ def create_dataset(obj): return use_train = inquirer.select( - message="train data file:", choices=[separator.Separator(), - *all_train_data, - separator.Separator()]).execute() + message="train data file:", + choices=[separator.Separator(), *all_train_data, separator.Separator()], + ).execute() use_prompt = 
inquirer.select( - message="prompt template file:", choices=[separator.Separator(), - *all_prompt, - separator.Separator()]).execute() + message="prompt template file:", + choices=[separator.Separator(), *all_prompt, separator.Separator()], + ).execute() use_preprocess = inquirer.select( - message="data preprocessing:", choices=[separator.Separator(), - "default", - "shuffle", - "sort", - separator.Separator()]).execute() - - ret = requests.post(url() + "/dataset", json={ - "name": name, - "data_name": use_train, - "prompt_name": use_prompt, - "preprocess": use_preprocess - }) + message="data preprocessing:", + choices=[ + separator.Separator(), + "default", + "shuffle", + "sort", + separator.Separator(), + ], + ).execute() + + ret = requests.post( + url() + "/dataset", + json={ + "name": name, + "data_name": use_train, + "prompt_name": use_prompt, + "preprocess": use_preprocess, + }, + ) print(json.loads(ret.text)) @@ -84,9 +93,9 @@ def showcase_dataset(obj): return use_dataset = inquirer.select( - message="dataset name:", choices=[separator.Separator(), - *all_dataset, - separator.Separator()]).execute() + message="dataset name:", + choices=[separator.Separator(), *all_dataset, separator.Separator()], + ).execute() ret = requests.get(url() + f"/showcase?name={use_dataset}") ret = json.loads(ret.text) diff --git a/mlora/cli/dispatcher.py b/mlora/cli/dispatcher.py index 529bb833..fe5341a8 100644 --- a/mlora/cli/dispatcher.py +++ b/mlora/cli/dispatcher.py @@ -1,8 +1,9 @@ import json + import requests from rich import print -from rich.table import Table from rich.box import ASCII +from rich.table import Table from .setting import url @@ -13,12 +14,12 @@ def help_dispatcher(_): def do_dispatcher(*_): ret = requests.get(url() + "/dispatcher") - ret = json.loads(ret.text) + ret_text = json.loads(ret.text) table = Table(show_header=True, show_lines=True, box=ASCII) table.add_column("Item", justify="center") table.add_column("Value", justify="center") - for item, value in ret.items(): + for item, value in ret_text.items(): table.add_row(item, str(value)) print(table) diff --git a/mlora/cli/file.py b/mlora/cli/file.py index 93a2bbc5..ba5fba55 100644 --- a/mlora/cli/file.py +++ b/mlora/cli/file.py @@ -1,22 +1,19 @@ import json + import requests -from InquirerPy import inquirer, validator, separator +from InquirerPy import inquirer, separator, validator from rich import print -from rich.table import Table from rich.box import ASCII - +from rich.table import Table from .setting import url -g_file_type_map = { - "train data": "data", - "prompt data": "prompt" -} +g_file_type_map = {"train data": "data", "prompt data": "prompt"} def list_file(obj, file_type: str): ret = requests.get(url() + f"/{file_type}") - ret = json.loads(ret.text) + ret_items = json.loads(ret.text) table = Table(show_header=True, show_lines=True, box=ASCII) table.add_column("name", justify="center") @@ -24,42 +21,48 @@ def list_file(obj, file_type: str): if file_type == "prompt": table.add_column("prompter", justify="center") - for item in ret: + for item in ret_items: row_data = [item["name"], item["file"]["file_path"]] if file_type == "prompt": row_data.append(item["file"]["prompt_type"]) table.add_row(*row_data) - obj.ret_ = ret + obj.ret_ = ret_items obj.pret_ = table def upload_file(): name = inquirer.text( message="name:", - validate=validator.EmptyInputValidator("name should not be empty")).execute() + validate=validator.EmptyInputValidator("name should not be empty"), + ).execute() - file_type = 
inquirer.select(message="file type:", - choices=[separator.Separator(), - *g_file_type_map.keys(), - separator.Separator()]).execute() + file_type = inquirer.select( + message="file type:", + choices=[separator.Separator(), *g_file_type_map.keys(), separator.Separator()], + ).execute() file_type = g_file_type_map[file_type] post_url = url() + f"/{file_type}?name={name}" if file_type == "prompt": - prompt_type = inquirer.select(message="prompter type:", - choices=[separator.Separator(), - "instruction", - "preference", - separator.Separator()]).execute() + prompt_type = inquirer.select( + message="prompter type:", + choices=[ + separator.Separator(), + "instruction", + "preference", + separator.Separator(), + ], + ).execute() post_url += f"&prompt_type={prompt_type}" - path = inquirer.filepath(message="file path:", - default="/", - validate=validator.PathValidator(is_file=True, - message="input is not a file"), - only_files=True).execute() + path = inquirer.filepath( + message="file path:", + default="/", + validate=validator.PathValidator(is_file=True, message="input is not a file"), + only_files=True, + ).execute() ret = requests.post(post_url, files={"data_file": open(path, "rb")}) @@ -81,9 +84,11 @@ def do_file(obj, args): # to chose file type file_type = inquirer.select( message="type:", - choices=[separator.Separator(), - *g_file_type_map.keys(), - separator.Separator()] + choices=[ + separator.Separator(), + *g_file_type_map.keys(), + separator.Separator(), + ], ).execute() file_type = g_file_type_map[file_type] list_file(obj, file_type) diff --git a/mlora/cli/task.py b/mlora/cli/task.py index c9162f01..f0825dd2 100644 --- a/mlora/cli/task.py +++ b/mlora/cli/task.py @@ -1,9 +1,10 @@ import json + import requests from InquirerPy import inquirer from rich import print -from rich.table import Table from rich.box import ASCII +from rich.table import Table from .adapter import list_adapter from .dataset import list_dataset @@ -12,7 +13,7 @@ def list_task(obj): ret = requests.get(url() + "/task") - ret = json.loads(ret.text) + ret_items = json.loads(ret.text) table = Table(show_header=True, show_lines=True, box=ASCII) table.add_column("name", justify="center") @@ -21,10 +22,11 @@ def list_task(obj): table.add_column("adapter", justify="center") table.add_column("state", justify="center") - for item in ret: - item = json.loads(item) - table.add_row(item["name"], item["type"], - item["dataset"], item["adapter"], item["state"]) + for ret_item in ret_items: + item = json.loads(ret_item) + table.add_row( + item["name"], item["type"], item["dataset"], item["adapter"], item["state"] + ) obj.pret_ = table @@ -32,10 +34,7 @@ def list_task(obj): def task_type_set(task_conf, all_adapters): if task_conf["type"] == "dpo" or task_conf["type"] == "cpo": beta = inquirer.number( - message="beta:", - float_allowed=True, - default=0.1, - replace_mode=True + message="beta:", float_allowed=True, default=0.1, replace_mode=True ).execute() task_conf["beta"] = beta @@ -43,23 +42,26 @@ def task_type_set(task_conf, all_adapters): message="label_smoothing:", float_allowed=True, default=0.0, - replace_mode=True + replace_mode=True, ).execute() task_conf["label_smoothing"] = label_smoothing if task_conf["type"] == "cpo": loss_type = inquirer.select( - message="loss_type:", choices=["sigmoid", "hinge"]).execute() + message="loss_type:", choices=["sigmoid", "hinge"] + ).execute() task_conf["loss_type"] = loss_type if task_conf["type"] == "dpo": loss_type = inquirer.select( - message="loss_type:", choices=["sigmoid", 
"ipo"]).execute() + message="loss_type:", choices=["sigmoid", "ipo"] + ).execute() task_conf["loss_type"] = loss_type all_adapters.append("base") reference = inquirer.select( - message="reference model:", choices=all_adapters).execute() + message="reference model:", choices=all_adapters + ).execute() task_conf["reference"] = reference return task_conf @@ -108,11 +110,11 @@ def create_task(obj): task_conf = {} task_type = inquirer.select( - message="type:", choices=["train", "dpo", "cpo"]).execute() + message="type:", choices=["train", "dpo", "cpo"] + ).execute() task_conf["type"] = task_type - name = inquirer.text( - message="name:").execute() + name = inquirer.text(message="name:").execute() task_conf["name"] = name list_dataset(obj) @@ -122,8 +124,7 @@ def create_task(obj): print("no dataset, please create one") return - dataset = inquirer.select( - message="dataset:", choices=all_dataset).execute() + dataset = inquirer.select(message="dataset:", choices=all_dataset).execute() task_conf["dataset"] = dataset list_adapter(obj) @@ -133,8 +134,7 @@ def create_task(obj): print("no adapter can be train, please create one") return - adapter = inquirer.select( - message="train adapter:", choices=all_adapters).execute() + adapter = inquirer.select(message="train adapter:", choices=all_adapters).execute() task_conf["adapter"] = adapter task_conf = task_type_set(task_conf, all_adapters.copy()) diff --git a/mlora/config/__init__.py b/mlora/config/__init__.py index 4b3eb6fa..770b6bf1 100644 --- a/mlora/config/__init__.py +++ b/mlora/config/__init__.py @@ -1,9 +1,15 @@ -from .mlora import MLoRAConfig, MLoRAServerConfig +from .adapter import ADAPTERCONFIG_CLASS, AdapterConfig, LoRAConfig, LoRAPlusConfig from .dataset import DatasetConfig -from .adapter import AdapterConfig, LoRAConfig, LoRAPlusConfig, ADAPTERCONFIG_CLASS -from .task import TaskConfig, TrainTaskConfig, DPOTaskConfig, CPOTaskConfig, TASKCONFIG_CLASS -from .optimizer import OptimizerConfig from .lr_scheduler import LRSchedulerConfig +from .mlora import MLoRAConfig, MLoRAServerConfig +from .optimizer import OptimizerConfig +from .task import ( + TASKCONFIG_CLASS, + CPOTaskConfig, + DPOTaskConfig, + TaskConfig, + TrainTaskConfig, +) __all__ = [ "MLoRAConfig", @@ -19,5 +25,5 @@ "LoRAPlusConfig", "ADAPTERCONFIG_CLASS", "OptimizerConfig", - "LRSchedulerConfig" + "LRSchedulerConfig", ] diff --git a/mlora/config/adapter.py b/mlora/config/adapter.py index e9db056d..576a2db1 100644 --- a/mlora/config/adapter.py +++ b/mlora/config/adapter.py @@ -1,68 +1,64 @@ import logging -from typing import Dict, Any, override from abc import abstractmethod +from typing import Any, Dict, Optional, override from .config import DictConfig -from .optimizer import OptimizerConfig, OPTIMIZERCONFIG_CLASS -from .lr_scheduler import LRSchedulerConfig, LRSCHEDULERCONFIG_CLASS +from .lr_scheduler import LRSCHEDULERCONFIG_CLASS, LRSchedulerConfig +from .optimizer import OPTIMIZERCONFIG_CLASS, OptimizerConfig class AdapterConfig(DictConfig): - type_: str = "" - name_: str = "" - path_: str = "" + type_: str + name_: str + path_: str - optimizer_config_: OptimizerConfig = None - lr_scheduler_config_: LRSchedulerConfig = None + optimizer_config_: Optional[OptimizerConfig] + lr_scheduler_config_: Optional[LRSchedulerConfig] - __params_map: Dict[str, str] = { - "type_": "type", - "name_": "name", - "path_": "path" - } + __params_map: Dict[str, str] = {"type_": "type", "name_": "name", "path_": "path"} def __init_optim(self, config: Dict[str, str]): if config["optimizer"] not in 
OPTIMIZERCONFIG_CLASS: raise NotImplementedError - self.optimizer_config_ = OPTIMIZERCONFIG_CLASS[config["optimizer"]]( - config) + self.optimizer_config_ = OPTIMIZERCONFIG_CLASS[config["optimizer"]](config) def __init_lr_scheduler(self, config: Dict[str, str]): if config["lrscheduler"] not in LRSCHEDULERCONFIG_CLASS: raise NotImplementedError self.lr_scheduler_config_ = LRSCHEDULERCONFIG_CLASS[config["lrscheduler"]]( - config) + config + ) def __init__(self, config: Dict[str, str]): super().__init__(config) self.init(self.__params_map, config) + self.lr_scheduler_config_ = None + self.optimizer_config_ = None + if "optimizer" not in config: - logging.info( - f"Adapter {self.name_} without optimizer, only for inference") + logging.info(f"Adapter {self.name_} without optimizer, only for inference") return self.__init_optim(config) if "lrscheduler" not in config: - logging.info( - f"Adapter {self.name_} without lr_scheduler.") + logging.info(f"Adapter {self.name_} without lr_scheduler.") return self.__init_lr_scheduler(config) @abstractmethod - def export(self) -> Dict[str, str]: - ... + def export(self) -> Dict[str, str]: ... class LoRAConfig(AdapterConfig): - r_: int = -1 - alpha_: int = 0 - dropout_: float = 0.05 - target_: Dict[str, bool] = {} + r_: int + alpha_: int + dropout_: float + target_: Dict[str, bool] __params_map: Dict[str, str] = { "r_": "r", @@ -83,7 +79,7 @@ def __init__(self, config: Dict[str, Any]): self.target_[key] = bool(value) @override - def export(self) -> Dict[str, str]: + def export(self) -> Dict[str, Any]: return { "lora_alpha": self.alpha_, "lora_dropout": self.dropout_, @@ -91,16 +87,14 @@ def export(self) -> Dict[str, str]: "peft_type": "LORA", "task_type": "CAUSAL_LM", "bias": "none", - "target_modules": [key for key in self.target_ if self.target_[key]] + "target_modules": [key for key in self.target_ if self.target_[key]], } class LoRAPlusConfig(LoRAConfig): - lr_ratio_: float = 8.0 + lr_ratio_: float - __params_map: Dict[str, str] = { - "lr_ratio_": "lr_ratio" - } + __params_map: Dict[str, str] = {"lr_ratio_": "lr_ratio"} def __init__(self, config: Dict[str, Any]): super().__init__(config) @@ -109,7 +103,4 @@ def __init__(self, config: Dict[str, Any]): self.lr_ratio_ = float(self.lr_ratio_) -ADAPTERCONFIG_CLASS = { - "lora": LoRAConfig, - "loraplus": LoRAPlusConfig -} +ADAPTERCONFIG_CLASS = {"lora": LoRAConfig, "loraplus": LoRAPlusConfig} diff --git a/mlora/config/config.py b/mlora/config/config.py index 2c9974a8..312d361f 100644 --- a/mlora/config/config.py +++ b/mlora/config/config.py @@ -7,8 +7,6 @@ class DictConfig: def __init__(self, config: Dict[str, str]) -> None: self.init(self.__params_map, config) - def init(self, - params_map: Dict[str, str], - config: Dict[str, str]): + def init(self, params_map: Dict[str, str], config: Dict[str, str]): for key, value in params_map.items(): setattr(self, key, config[value]) diff --git a/mlora/config/dataset.py b/mlora/config/dataset.py index a452b53c..d9eec6b3 100644 --- a/mlora/config/dataset.py +++ b/mlora/config/dataset.py @@ -4,11 +4,11 @@ class DatasetConfig(DictConfig): - name_: str = "" - data_path_: str = "" - prompt_path_: str = "" - prompt_type_: str = "" - preprocess_: str = "shuffle" + name_: str + data_path_: str + prompt_path_: str + prompt_type_: str + preprocess_: str __params_map: Dict[str, str] = { "name_": "name", diff --git a/mlora/config/dispatcher.py b/mlora/config/dispatcher.py index 9103977c..093cace6 100644 --- a/mlora/config/dispatcher.py +++ b/mlora/config/dispatcher.py @@ -4,12 +4,12 
@@ class DispatcherConfig(DictConfig): - name_: str = "default" - concurrency_num_: int = 2 + name_: str + concurrency_num_: int __params_map: Dict[str, str] = { "name_": "name", - "concurrency_num_": "concurrency_num" + "concurrency_num_": "concurrency_num", } def __init__(self, config: Dict[str, str]): diff --git a/mlora/config/lr_scheduler.py b/mlora/config/lr_scheduler.py index 1317c52f..c4c24659 100644 --- a/mlora/config/lr_scheduler.py +++ b/mlora/config/lr_scheduler.py @@ -1,33 +1,27 @@ -from typing import Dict, override from abc import abstractmethod +from typing import Any, Dict, override from .config import DictConfig class LRSchedulerConfig(DictConfig): - lr_scheduler_: str = "" + lr_scheduler_: str - __params_map: Dict[str, str] = { - "lr_scheduler_": "lrscheduler" - } + __params_map: Dict[str, str] = {"lr_scheduler_": "lrscheduler"} def __init__(self, config: Dict[str, str]) -> None: super().__init__(config) self.init(self.__params_map, config) @abstractmethod - def to_fn_parameters(self) -> Dict[str, str]: - ... + def to_fn_parameters(self) -> Dict[str, str]: ... class CosineLRSchedulerConfig(LRSchedulerConfig): - t_max_: int = -1 - eta_min_: int = 0 + t_max_: int + eta_min_: int - __params_map: Dict[str, str] = { - "t_max_": "t_max", - "eta_min_": "eta_min" - } + __params_map: Dict[str, str] = {"t_max_": "t_max", "eta_min_": "eta_min"} def __init__(self, config: Dict[str, str]) -> None: super().__init__(config) @@ -37,11 +31,8 @@ def __init__(self, config: Dict[str, str]) -> None: self.eta_min_ = int(self.eta_min_) @override - def to_fn_parameters(self) -> Dict[str, str]: - return { - "T_max": float(self.t_max_), - "eta_min": float(self.eta_min_) - } + def to_fn_parameters(self) -> Dict[str, Any]: + return {"T_max": float(self.t_max_), "eta_min": float(self.eta_min_)} LRSCHEDULERCONFIG_CLASS = { diff --git a/mlora/config/mlora.py b/mlora/config/mlora.py index 2e8713ce..6dcccf85 100644 --- a/mlora/config/mlora.py +++ b/mlora/config/mlora.py @@ -1,34 +1,36 @@ +from typing import Any, Dict, List + import yaml -from typing import Dict, List -from .dispatcher import DispatcherConfig +from .adapter import ADAPTERCONFIG_CLASS, AdapterConfig from .dataset import DatasetConfig -from .adapter import AdapterConfig, ADAPTERCONFIG_CLASS -from .task import TaskConfig, TASKCONFIG_CLASS +from .dispatcher import DispatcherConfig +from .task import TASKCONFIG_CLASS, TaskConfig class MLoRAConfig: - dispatcher_: DispatcherConfig = None - tasks_: List[TaskConfig] = [] - __datasets_: Dict[str, DatasetConfig] = {} - __adapters_: Dict[str, AdapterConfig] = {} + dispatcher_: DispatcherConfig + tasks_: List[TaskConfig] + __datasets_: Dict[str, DatasetConfig] + __adapters_: Dict[str, AdapterConfig] - def __init_datasets(self, config: List[Dict[str, any]]): + def __init_datasets(self, config: List[Dict[str, Any]]): for item in config: name = item["name"] self.__datasets_[name] = DatasetConfig(item) - def __init_adapters(self, config: List[Dict[str, any]]): + def __init_adapters(self, config: List[Dict[str, Any]]): for item in config: name = item["name"] atype = item["type"] self.__adapters_[name] = ADAPTERCONFIG_CLASS[atype](item) - def __init_tasks(self, config: List[Dict[str, any]]): + def __init_tasks(self, config: List[Dict[str, Any]]): for item in config: assert item["type"] in TASKCONFIG_CLASS - self.tasks_.append(TASKCONFIG_CLASS[item["type"]]( - item, self.__adapters_, self.__datasets_)) + self.tasks_.append( + TASKCONFIG_CLASS[item["type"]](item, self.__adapters_, self.__datasets_) + ) def 
__init__(self, path: str): with open(path) as fp: @@ -36,6 +38,10 @@ def __init__(self, path: str): self.dispatcher_ = DispatcherConfig(config["dispatcher"]) + self.__adapters_ = {} + self.__datasets_ = {} + self.tasks_ = [] + # must ensure the adapter and datasets init before the task self.__init_datasets(config["datasets"]) self.__init_adapters(config["adapters"]) @@ -44,6 +50,5 @@ def __init__(self, path: str): class MLoRAServerConfig(MLoRAConfig): - def __init__(self, config: Dict[str, str]) -> None: self.dispatcher_ = DispatcherConfig(config) diff --git a/mlora/config/optimizer.py b/mlora/config/optimizer.py index 3f5916b8..79016145 100644 --- a/mlora/config/optimizer.py +++ b/mlora/config/optimizer.py @@ -1,12 +1,12 @@ -from typing import Dict, override from abc import abstractmethod +from typing import Any, Dict, Type, override from .config import DictConfig class OptimizerConfig(DictConfig): - lr_: float = 0.0 - optimizer_: str = "" + lr_: float + optimizer_: str __params_map: Dict[str, str] = { "lr_": "lr", @@ -20,16 +20,13 @@ def __init__(self, config: Dict[str, str]) -> None: self.lr_ = float(self.lr_) @abstractmethod - def to_fn_parameters(self) -> Dict[str, str]: - ... + def to_fn_parameters(self) -> Dict[str, str]: ... class SGDOptimizerConfig(OptimizerConfig): - momentum_: float = 0.0 + momentum_: float - __params_map: Dict[str, str] = { - "momentum_": "momentum" - } + __params_map: Dict[str, str] = {"momentum_": "momentum"} def __init__(self, config: Dict[str, str]) -> None: super().__init__(config) @@ -38,29 +35,25 @@ def __init__(self, config: Dict[str, str]) -> None: self.momentum_ = float(self.momentum_) @override - def to_fn_parameters(self) -> Dict[str, str]: - return { - "lr": float(self.lr_), - "motentum": float(self.momentum_) - } + def to_fn_parameters(self) -> Dict[str, Any]: + return {"lr": float(self.lr_), "motentum": float(self.momentum_)} class AdamWOptimizerConfig(OptimizerConfig): - __params_map: Dict[str, str] = { - } + __params_map: Dict[str, str] = {} def __init__(self, config: Dict[str, str]) -> None: super().__init__(config) self.init(self.__params_map, config) @override - def to_fn_parameters(self) -> Dict[str, str]: + def to_fn_parameters(self) -> Dict[str, Any]: return { - "lr": self.lr_, + "lr": float(self.lr_), } -OPTIMIZERCONFIG_CLASS = { +OPTIMIZERCONFIG_CLASS: Dict[str, Type[OptimizerConfig]] = { "sgd": SGDOptimizerConfig, - "adamw": AdamWOptimizerConfig + "adamw": AdamWOptimizerConfig, } diff --git a/mlora/config/task.py b/mlora/config/task.py index 5834d316..74ecd518 100644 --- a/mlora/config/task.py +++ b/mlora/config/task.py @@ -1,26 +1,26 @@ import logging -from typing import Dict +from typing import Dict, Mapping, Optional, Type +from .adapter import AdapterConfig from .config import DictConfig from .dataset import DatasetConfig -from .adapter import AdapterConfig class TaskConfig(DictConfig): - name_: str = "" - type_: str = "" - - adapter_: AdapterConfig = None - dataset_: DatasetConfig = None + name_: str + type_: str __params_map: Dict[str, str] = { "name_": "name", "type_": "type", } - def __init__(self, config: Dict[str, str], - adapters: Dict[str, AdapterConfig], - datasets: Dict[str, DatasetConfig]): + def __init__( + self, + config: Dict[str, str], + adapters: Mapping[str, AdapterConfig], + datasets: Mapping[str, DatasetConfig], + ): super().__init__(config) self.init(self.__params_map, config) @@ -29,23 +29,26 @@ def __init__(self, config: Dict[str, str], class TrainTaskConfig(TaskConfig): - batch_size_: int = -1 - 
mini_batch_size_: int = -1 - num_epochs_: int = -1 - cutoff_len_: int = 256 - save_step_: int = 2000 + batch_size_: int + mini_batch_size_: int + num_epochs_: int + cutoff_len_: int + save_step_: int __params_map: Dict[str, str] = { "batch_size_": "batch_size", "mini_batch_size_": "mini_batch_size", "num_epochs_": "num_epochs", "cutoff_len_": "cutoff_len", - "save_step_": "save_step" + "save_step_": "save_step", } - def __init__(self, config: Dict[str, str], - adapters: Dict[str, AdapterConfig], - datasets: Dict[str, DatasetConfig]): + def __init__( + self, + config: Dict[str, str], + adapters: Mapping[str, AdapterConfig], + datasets: Mapping[str, DatasetConfig], + ): super().__init__(config, adapters, datasets) self.init(self.__params_map, config) @@ -64,21 +67,25 @@ def accumulate_step_(self) -> int: class DPOTaskConfig(TrainTaskConfig): - loss_type_: str = "sigmoid" - beta_: float = 0.2 - label_smoothing_: float = 0.0 + loss_type_: str + beta_: float + label_smoothing_: float - reference_: AdapterConfig = None + # is reference is None, use the base model + reference_: Optional[AdapterConfig] __params_map: Dict[str, str] = { "loss_type_": "loss_type", "beta_": "beta", - "label_smoothing_": "label_smoothing" + "label_smoothing_": "label_smoothing", } - def __init__(self, config: Dict[str, str], - adapters: Dict[str, AdapterConfig], - datasets: Dict[str, DatasetConfig]): + def __init__( + self, + config: Dict[str, str], + adapters: Mapping[str, AdapterConfig], + datasets: Mapping[str, DatasetConfig], + ): super().__init__(config, adapters, datasets) self.init(self.__params_map, config) @@ -88,32 +95,34 @@ def __init__(self, config: Dict[str, str], if config["reference"] not in adapters: self.reference_ = None logging.info( - f"DPOTask - {self.adapter_.name_} use the base model as reference model.") + f"DPOTask - {self.adapter_.name_} " + + "use the base model as reference model." + ) return self.reference_ = adapters[config["reference"]] class CPOTaskConfig(TrainTaskConfig): - loss_type_: str = "sigmoid" - beta_: float = 0.2 + loss_type_: str + beta_: float - __params_map: Dict[str, str] = { - "loss_type_": "loss_type", - "beta_": "beta" - } + __params_map: Dict[str, str] = {"loss_type_": "loss_type", "beta_": "beta"} - def __init__(self, config: Dict[str, str], - adapters: Dict[str, AdapterConfig], - datasets: Dict[str, DatasetConfig]): + def __init__( + self, + config: Dict[str, str], + adapters: Mapping[str, AdapterConfig], + datasets: Mapping[str, DatasetConfig], + ): super().__init__(config, adapters, datasets) self.init(self.__params_map, config) self.beta_ = float(self.beta_) -TASKCONFIG_CLASS = { +TASKCONFIG_CLASS: Dict[str, Type[TaskConfig]] = { "train": TrainTaskConfig, "dpo": DPOTaskConfig, - "cpo": CPOTaskConfig + "cpo": CPOTaskConfig, } diff --git a/mlora/evaluator/evaluator.py b/mlora/evaluator/evaluator.py deleted file mode 100644 index 96805085..00000000 --- a/mlora/evaluator/evaluator.py +++ /dev/null @@ -1,13 +0,0 @@ -from mlora.model.llm.model_llm import LLMModel -from mlora.model.tokenizer.tokenizer import Tokenizer - -from abc import ABCMeta, abstractmethod - - -class Evaluator(metaclass=ABCMeta): - model_: LLMModel = None - tokenizer_: Tokenizer = None - - @abstractmethod - def evaluate(self) -> float: - ... 
diff --git a/mlora/evaluator/evaluator_factory.py b/mlora/evaluator/evaluator_factory.py deleted file mode 100644 index 7f2948ca..00000000 --- a/mlora/evaluator/evaluator_factory.py +++ /dev/null @@ -1,23 +0,0 @@ -from mlora.model.llm.model_llm import LLMModel -from mlora.model.tokenizer.tokenizer import Tokenizer -from mlora.evaluator.evaluator import Evaluator -from mlora.evaluator.mmlu_evaluator import MMLUEvaluator - - -class EvaluatorFactory(): - @staticmethod - def create(model: LLMModel, - tokenizer: Tokenizer, - evaluator_type: str, - data: str) -> Evaluator: - type_args = evaluator_type.split(":") - - assert len(type_args) >= 1, f"error args {type_args}" - - # the first is the evaluator_class - evaluator_class = type_args[0] - - if evaluator_class == "mmlu": - return MMLUEvaluator(model, tokenizer, data, type_args[1:]) - else: - raise f"Not support: {evaluator_class}" diff --git a/mlora/evaluator/mmlu_evaluator.py b/mlora/evaluator/mmlu_evaluator.py deleted file mode 100644 index 5b329390..00000000 --- a/mlora/evaluator/mmlu_evaluator.py +++ /dev/null @@ -1,188 +0,0 @@ -from mlora.model.llm.model_llm import LLMModel -from mlora.model.args import Tokens, MLoRABatchData -from mlora.model.tokenizer.tokenizer import Tokenizer -from mlora.evaluator.evaluator_factory import Evaluator - -import math -import torch -import datasets -import logging - -from typing import List, Tuple - - -class MMLUEvaluator(Evaluator): - data_: str = "" - choices_map_: List[str] = None - choices_map_tokens_: List[int] = None - subject_: str = "all" - kshot_: int = 5 - max_seq_: int = 2048 - batch_size_: int = 2 - - def parse_arguments(self, args: List[str]): - # get arguments from the args mmlu::[kshot]:[batch_size]:[max_seq] - IDX_SUBJECT = 0 - IDX_KSHOT = 1 - IDX_BATCH_SIZE = 2 - IDX_MAX_SEQ_LEN = 3 - - if len(args) >= 1 and args[IDX_SUBJECT] != "": - self.subject_ = args[IDX_SUBJECT] - - if len(args) >= 2 and args[IDX_KSHOT] != "": - assert args[IDX_KSHOT].isdigit( - ), f"argument error {args[IDX_KSHOT]} must digit." - self.kshot_ = int(args[IDX_KSHOT]) - - if len(args) >= 3 and args[IDX_BATCH_SIZE] != "": - assert args[IDX_BATCH_SIZE].isdigit( - ), f"argument error {args[IDX_BATCH_SIZE]} must digit." - self.batch_size_ = int(args[IDX_BATCH_SIZE]) - - if len(args) >= 4 and args[IDX_MAX_SEQ_LEN] != "": - assert args[IDX_MAX_SEQ_LEN].isdigit( - ), f"argument error {args[IDX_MAX_SEQ_LEN]} must digit." - self.max_seq_ = int(args[IDX_MAX_SEQ_LEN]) - - def __init__(self, - model: LLMModel, - tokenizer: Tokenizer, - data: str, - args: List[str]): - super().__init__() - # data_: the path or name of mmlu datasets - self.parse_arguments(args) - - self.model_ = model - self.tokenizer_ = tokenizer - self.data_ = data - - self.choices_map_ = ["A", "B", "C", "D"] - self.choices_map_tokens_ = [self.tokenizer_.encode( - choice, bos=False, eos=False)[0] for choice in self.choices_map_] - - def prepare_evaluate_data(self, subject: str) -> Tuple[List[Tokens], List[str]]: - # return val: the tokens and the labels - mmlu_data_set = datasets.load_dataset(self.data_, subject) - dev_data = mmlu_data_set["dev"] - test_data = mmlu_data_set["test"] - - def format_subject(subject: str): - return subject.replace("_", " ") - - def format_prompt(data_point, with_answer=True): - # get the question and choices like: - # - # A. - # B. - # .... - # Answer: - choices_list = [f"{key}. 
{choice}\n" for key, - choice in zip(self.choices_map_, data_point["choices"])] - - question = data_point["question"].strip() - choices = "".join(choices_list) - - prompt = f"{question}\n{choices}Answer:" - - if with_answer: - prompt += " {}\n\n".format( - self.choices_map_[data_point["answer"]]) - - return prompt - - all_tokens: List[Tokens] = [] - all_labels: List[str] = [] - - for test_data_point in test_data: - kshot_prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format( - format_subject(subject)) - test_prompt = format_prompt(test_data_point, False) - - tokens: Tokens = [] - - # prepare kshot case - # kshot_prompt + [dev_prompt] + test_prompt, see format_prompt function - for (dev_shot_idx, dev_data_point) in enumerate(dev_data): - if dev_shot_idx >= self.kshot_: - break - # prepare the prompt, use k-show - dev_prompt = format_prompt(dev_data_point, True) - tmp_tokens = self.tokenizer_.encode( - kshot_prompt + dev_prompt + test_prompt, bos=True, eos=False) - if len(tmp_tokens) > self.max_seq_: - break - # to ensure the last one is eos token - if len(tmp_tokens) == self.max_seq_: - tmp_tokens[-1] = self.tokenizer_.eos_id_ - tokens = tmp_tokens - kshot_prompt += dev_prompt - - assert len(tokens) > 0 - all_tokens.append(tokens) - all_labels.append(self.choices_map_[test_data_point["answer"]]) - - return all_tokens, all_labels - - def get_choice_from_output(self, prob: torch.Tensor) -> str: - choice_prob = torch.tensor([prob[choice_token] - for choice_token in self.choices_map_tokens_]) - max_prob = torch.argmax(choice_prob).item() - return self.choices_map_[max_prob] - - def model_inference(self, tokens: List[Tokens]) -> List[str]: - - def pad_to_len(wait_to_pad: Tokens, seq_len: int): - while len(wait_to_pad) < seq_len: - wait_to_pad.append(self.tokenizer_.pad_id_) - return wait_to_pad - - choice_probs: List[str] = [] - - for start_pos in range(0, len(tokens), self.batch_size_): - batch_data = tokens[start_pos:start_pos + self.batch_size_] - # bd short for batch data - bd_tokens_len = [len(token) for token in batch_data] - # get max seq len and align with 8 - to_pad_len = max(bd_tokens_len) - to_pad_len = math.ceil(to_pad_len / 8) * 8 - # pad to it - aligned_batch_data = [pad_to_len(token, to_pad_len) - for token in batch_data] - # generate the pad - attention_mask = [self.tokenizer_.mask_from( - token) for token in aligned_batch_data] - - # TODO lora model - output: torch.Tensor = self.model_.forward(MLoRABatchData(batch_tokens_=aligned_batch_data, - batch_mask_=attention_mask, - lora_batch_data_config_=None, - inference_model_=True)) - # only get the last predict value - output = [output[idx][len - 1] - for idx, len in enumerate(bd_tokens_len)] - # get the choice - choice_probs.extend([self.get_choice_from_output( - each_output)for each_output in output]) - - del output - - return choice_probs - - def evaluate_subject(self, subject: str) -> List[bool]: - tokens, labels = self.prepare_evaluate_data(subject) - choices: List[str] = self.model_inference(tokens) - assert len(choices) == len(labels) - - result: List[bool] = [] - for idx in range(0, len(choices)): - result.append(labels[idx] == choices[idx]) - - return result - - @torch.inference_mode() - def evaluate(self) -> float: - logging.info(f"Performing MMLU/{self.subject_} Benchmark.") - result = self.evaluate_subject(self.subject_) - logging.info(f"Accuracy: {sum(result) / len(result)}") diff --git a/mlora/executor/__init__.py b/mlora/executor/__init__.py index 574cd2aa..69514566 100644 --- 
a/mlora/executor/__init__.py +++ b/mlora/executor/__init__.py @@ -1,5 +1,3 @@ from .executor import Executor -__all__ = [ - "Executor" -] +__all__ = ["Executor"] diff --git a/mlora/executor/context/__init__.py b/mlora/executor/context/__init__.py index 1a682c06..33309282 100644 --- a/mlora/executor/context/__init__.py +++ b/mlora/executor/context/__init__.py @@ -1,15 +1,19 @@ +from typing import Dict, Type + from .context import TaskContext -from .lora import TrainLoRAContext, InferenceLoRAContext +from .inference import InferenceTaskContext +from .lora import InferenceLoRAContext, TrainLoRAContext from .loraplus import TrainLoRAPlusContext +from .train import TrainTaskContext -TRAINCONTEXT_CLASS = { +TRAINCONTEXT_CLASS: Dict[str, Type[TrainTaskContext]] = { "lora": TrainLoRAContext, - "loraplus": TrainLoRAPlusContext + "loraplus": TrainLoRAPlusContext, } -INFERENCECONTEXT_CLASS = { +INFERENCECONTEXT_CLASS: Dict[str, Type[InferenceTaskContext]] = { "lora": InferenceLoRAContext, - "loraplus": InferenceLoRAContext + "loraplus": InferenceLoRAContext, } @@ -17,6 +21,7 @@ "TRAINCONTEXT_CLASS", "INFERENCECONTEXT_CLASS", "TaskContext", + "TrainTaskContext", "TrainLoRAContext", "InferenceLoRAContext", "TrainLoRAPlusContext", diff --git a/mlora/executor/context/context.py b/mlora/executor/context/context.py index 2fc571aa..a452d889 100644 --- a/mlora/executor/context/context.py +++ b/mlora/executor/context/context.py @@ -1,41 +1,43 @@ -from mlora.model.modules import AdapterModel -from mlora.model.args import LinearInfo -from mlora.config import AdapterConfig +from abc import ABCMeta, abstractmethod +from typing import Dict, List, OrderedDict import torch.optim -from typing import Dict, List, OrderedDict -from abc import ABCMeta, abstractmethod + +from mlora.config import AdapterConfig +from mlora.model.args import LinearInfo +from mlora.model.modules import AdapterModel class TaskContext(metaclass=ABCMeta): - type_: str = "" - name_: str = "" - path_: str = "" + type_: str + name_: str + path_: str + + config_: AdapterConfig + + device_: str - device_: str = "" + adapter_model_: AdapterModel - adapter_model_: AdapterModel = {} + def __init__(self, config: AdapterConfig) -> None: + self.type_ = config.type_ + self.name_ = config.name_ + self.path_ = config.path_ - def __init__(self, context_type: str, context_name: str, context_path: str) -> None: - self.type_ = context_type - self.name_ = context_name - self.path_ = context_path + self.config_ = config self.device_ = "cpu" self.adapter_model_ = {} @abstractmethod - def switch_device(self, device: str) -> None: - ... + def switch_device(self, device: str) -> None: ... @abstractmethod - def step(self) -> None: - ... + def step(self) -> None: ... @abstractmethod - def load_weight(self, config: AdapterConfig, linears_info: OrderedDict[str, LinearInfo]): - ... + def load_weight(self, linears_info: OrderedDict[str, LinearInfo]): ... 
def adapter_model(self) -> AdapterModel: return self.adapter_model_ diff --git a/mlora/executor/context/inference.py b/mlora/executor/context/inference.py index 9916ffe6..0f1eb715 100644 --- a/mlora/executor/context/inference.py +++ b/mlora/executor/context/inference.py @@ -1,18 +1,19 @@ +from collections import OrderedDict from mlora.config import AdapterConfig from mlora.model.args import LinearInfo -from collections import OrderedDict - from .context import TaskContext class InferenceTaskContext(TaskContext): - def __init__(self, config: AdapterConfig, linears_info: OrderedDict[str, LinearInfo]) -> None: - super().__init__(config.type_, config.name_, config.path_) + def __init__( + self, config: AdapterConfig, linears_info: OrderedDict[str, LinearInfo] + ) -> None: + super().__init__(config) # load the adapter's weight - self.load_weight(config, linears_info) + self.load_weight(linears_info) # disable all the weights' grad for module in self.adapter_model_.values(): @@ -27,5 +28,4 @@ def switch_device(self, device: str) -> None: self.device_ = device - def step(self) -> None: - ... + def step(self) -> None: ... diff --git a/mlora/executor/context/lora.py b/mlora/executor/context/lora.py index 3abbe10a..5cfa0a5b 100644 --- a/mlora/executor/context/lora.py +++ b/mlora/executor/context/lora.py @@ -1,68 +1,89 @@ -from mlora.config import LoRAConfig -from mlora.config.adapter import AdapterConfig -from mlora.model.modules import LoRA -from mlora.model.args import LinearInfo - -import os -import torch import logging -from typing import Dict +import os from collections import OrderedDict +from typing import Dict, override + +import torch + +from mlora.config import LoRAConfig +from mlora.model.args import LinearInfo +from mlora.model.modules import LoRA from .context import TaskContext -from .train import TrainTaskContext from .inference import InferenceTaskContext +from .train import TrainTaskContext -def _load_lora_weight(obj: TaskContext, - config: LoRAConfig, - linears_info: OrderedDict[str, LinearInfo]): +def _load_lora_weight( + obj: TaskContext, config: LoRAConfig, linears_info: OrderedDict[str, LinearInfo] +): # init the weight for linear_name, linear_info in linears_info.items(): - target_name = linear_name.split('.')[3] + target_name = linear_name.split(".")[3] if target_name not in config.target_: continue if config.target_[target_name] is not True: continue - obj.adapter_model_[linear_name] = LoRA(config.name_, - linear_info.in_dim_, linear_info.out_dim_, - config.r_, config.alpha_, config.dropout_) + obj.adapter_model_[linear_name] = LoRA( + config.name_, + linear_info.in_dim_, + linear_info.out_dim_, + config.r_, + config.alpha_, + config.dropout_, + ) weight_dict = None if os.path.isdir(obj.path_): - logging.info( - f"Adapter {obj.name_}:{obj.path_} weight exist, load from file.") + logging.info(f"Adapter {obj.name_}:{obj.path_} weight exist, load from file.") weight_dict = torch.load(f"{obj.path_}{os.sep}adapter_model.bin") prefix_name = "base_model.model.model." else: logging.info( - f"Adapter {obj.name_}:{obj.path_} weight not exist, use the default weight.") + f"Adapter {obj.name_}:{obj.path_} weight not exist, use the default weight." 
+ ) for name, module in obj.adapter_model_.items(): - lora_a = None if weight_dict is None else weight_dict[prefix_name + - name + ".lora_A.weight"] - lora_b = None if weight_dict is None else weight_dict[prefix_name + - name + ".lora_B.weight"] + lora_a = ( + None + if weight_dict is None + else weight_dict[prefix_name + name + ".lora_A.weight"] + ) + lora_b = ( + None + if weight_dict is None + else weight_dict[prefix_name + name + ".lora_B.weight"] + ) module.init_weight(lora_a, lora_b) class InferenceLoRAContext(InferenceTaskContext): - def __init__(self, config: AdapterConfig, linears_info: OrderedDict[str, LinearInfo]) -> None: + config_: LoRAConfig + + def __init__( + self, config: LoRAConfig, linears_info: OrderedDict[str, LinearInfo] + ) -> None: super().__init__(config, linears_info) - def load_weight(self, config: LoRAConfig, linears_info: OrderedDict[str, LinearInfo]): - _load_lora_weight(self, config, linears_info) + @override + def load_weight(self, linears_info: OrderedDict[str, LinearInfo]): + _load_lora_weight(self, self.config_, linears_info) class TrainLoRAContext(TrainTaskContext): - def __init__(self, config: LoRAConfig, linears_info: OrderedDict[str, LinearInfo]) -> None: + config_: LoRAConfig + + def __init__( + self, config: LoRAConfig, linears_info: OrderedDict[str, LinearInfo] + ) -> None: super().__init__(config, linears_info) self.loss_fn_ = torch.nn.CrossEntropyLoss() - def load_weight(self, config: LoRAConfig, linears_info: OrderedDict[str, LinearInfo]): - _load_lora_weight(self, config, linears_info) + @override + def load_weight(self, linears_info: OrderedDict[str, LinearInfo]): + _load_lora_weight(self, self.config_, linears_info) def weight_dict(self) -> Dict[str, torch.Tensor]: # base_model.model.model.layers.{0}.self_attn.{q_proj}.{lora_A}.weight diff --git a/mlora/executor/context/loraplus.py b/mlora/executor/context/loraplus.py index 3c89fbe4..54171f3c 100644 --- a/mlora/executor/context/loraplus.py +++ b/mlora/executor/context/loraplus.py @@ -1,25 +1,30 @@ -from mlora.config import LoRAPlusConfig, OptimizerConfig -from mlora.model.args import LinearInfo - -import torch from collections import OrderedDict from typing import List, override -from .train import OPTIMIZER_CLASS +import torch + +from mlora.config import LoRAPlusConfig, OptimizerConfig +from mlora.model.args import LinearInfo + from .lora import TrainLoRAContext +from .train import OPTIMIZER_CLASS class TrainLoRAPlusContext(TrainLoRAContext): - lr_ratio_: float = 8.0 + lr_ratio_: float - def __init__(self, config: LoRAPlusConfig, linears_info: OrderedDict[str, LinearInfo]) -> None: - super().__init__(config, linears_info) + def __init__( + self, config: LoRAPlusConfig, linears_info: OrderedDict[str, LinearInfo] + ) -> None: self.lr_ratio_ = float(config.lr_ratio_) + super().__init__(config, linears_info) + @override - def create_optimizer(self, optim_config: OptimizerConfig): - optimizer_type_ = optim_config.optimizer_ + def create_optimizer(self, optim_config: OptimizerConfig | None): + assert optim_config is not None + optimizer_type_ = optim_config.optimizer_ assert optimizer_type_ in OPTIMIZER_CLASS lora_a_parameters: List[torch.Tensor] = [] @@ -31,9 +36,12 @@ def create_optimizer(self, optim_config: OptimizerConfig): parameters = [ {"params": lora_a_parameters}, - {"params": lora_b_parameters, - "lr": self.lr_ratio_ * float(optim_config.lr_)} + { + "params": lora_b_parameters, + "lr": self.lr_ratio_ * float(optim_config.lr_), + }, ] self.optimizer_ = OPTIMIZER_CLASS[optimizer_type_]( - 
parameters, **optim_config.to_fn_parameters()) + parameters, **optim_config.to_fn_parameters() + ) diff --git a/mlora/executor/context/train.py b/mlora/executor/context/train.py index 39988cb2..c336c85f 100644 --- a/mlora/executor/context/train.py +++ b/mlora/executor/context/train.py @@ -1,50 +1,46 @@ -from mlora.config import AdapterConfig, OptimizerConfig, LRSchedulerConfig -from mlora.model.args import LinearInfo - -import torch from abc import abstractmethod -from typing import List, Dict, Callable, Optional from collections import OrderedDict +from typing import Callable, Dict, List, Type + +import torch + +from mlora.config import AdapterConfig, LRSchedulerConfig, OptimizerConfig +from mlora.model.args import LinearInfo from .context import TaskContext -OPTIMIZER_CLASS = { - "sgd": torch.optim.SGD, - "adamw": torch.optim.AdamW -} +OPTIMIZER_CLASS = {"sgd": torch.optim.SGD, "adamw": torch.optim.AdamW} -LR_SCHEDULER_CLASS = { +LR_SCHEDULER_CLASS: Dict[str, Type[torch.optim.lr_scheduler.LRScheduler]] = { "cosine": torch.optim.lr_scheduler.CosineAnnealingLR, } class TrainTaskContext(TaskContext): - loss_fn_: Callable = None - optimizer_: torch.optim.Optimizer = None - lr_scheduler_: torch.optim.lr_scheduler.LRScheduler = None + loss_fn_: Callable + optimizer_: torch.optim.Optimizer + lr_scheduler_: torch.optim.lr_scheduler.LRScheduler | None - def __init__(self, config: AdapterConfig, linears_info: OrderedDict[str, LinearInfo]) -> None: - super().__init__(config.type_, config.name_, config.path_) + def __init__( + self, config: AdapterConfig, linears_info: OrderedDict[str, LinearInfo] + ) -> None: + super().__init__(config) # load the adapter's weight - self.load_weight(config, linears_info) + self.load_weight(linears_info) for module in self.adapter_model_.values(): module.enable_grad() - # init the optimizer - self.loss_fn_ = None - self.optimizer_ = None - self.lr_scheduler_ = None - self.create_optimizer(config.optimizer_config_) self.create_lr_scheduler(config.lr_scheduler_config_) @abstractmethod - def weight_dict(self) -> Dict[str, torch.Tensor]: - ... + def weight_dict(self) -> Dict[str, torch.Tensor]: ... 
+ + def create_optimizer(self, optim_config: OptimizerConfig | None): + assert optim_config is not None - def create_optimizer(self, optim_config: OptimizerConfig): optimizer_type_ = optim_config.optimizer_ assert optimizer_type_ in OPTIMIZER_CLASS @@ -53,18 +49,21 @@ def create_optimizer(self, optim_config: OptimizerConfig): parameters.extend(adapter.get_tensors()) self.optimizer_ = OPTIMIZER_CLASS[optimizer_type_]( - parameters, **optim_config.to_fn_parameters()) + parameters, **optim_config.to_fn_parameters() + ) - def create_lr_scheduler(self, lr_scheduler_config: Optional[LRSchedulerConfig]): + def create_lr_scheduler(self, lr_scheduler_config: LRSchedulerConfig | None): assert self.optimizer_ is not None if lr_scheduler_config is None: + self.lr_scheduler_ = None return lr_scheduler_type_ = lr_scheduler_config.lr_scheduler_ assert lr_scheduler_type_ in LR_SCHEDULER_CLASS self.lr_scheduler_ = LR_SCHEDULER_CLASS[lr_scheduler_type_]( - self.optimizer_, **lr_scheduler_config.to_fn_parameters()) + self.optimizer_, **lr_scheduler_config.to_fn_parameters() # type: ignore + ) def switch_device(self, device: str) -> None: if self.device_ == device: diff --git a/mlora/executor/dispatcher/__init__.py b/mlora/executor/dispatcher/__init__.py index 14c4dfde..2652fe6c 100644 --- a/mlora/executor/dispatcher/__init__.py +++ b/mlora/executor/dispatcher/__init__.py @@ -1,13 +1,6 @@ -from .dispatcher import Dispatcher from .backend_dispatcher import BackendDispatcher +from .dispatcher import Dispatcher -DISPATCHER_CLASS = { - "default": Dispatcher, - "backend": BackendDispatcher -} +DISPATCHER_CLASS = {"default": Dispatcher, "backend": BackendDispatcher} -__all__ = [ - "Dispatcher", - "BackendDispatcher", - "DISPATCHER_CLASS" -] +__all__ = ["Dispatcher", "BackendDispatcher", "DISPATCHER_CLASS"] diff --git a/mlora/executor/dispatcher/backend_dispatcher.py b/mlora/executor/dispatcher/backend_dispatcher.py index 35b75223..9dbe6400 100644 --- a/mlora/executor/dispatcher/backend_dispatcher.py +++ b/mlora/executor/dispatcher/backend_dispatcher.py @@ -1,16 +1,15 @@ -from mlora.config.dispatcher import DispatcherConfig -from mlora.config.task import TaskConfig - import logging import threading from typing import override +from mlora.config.dispatcher import DispatcherConfig +from mlora.config.task import TaskConfig from .dispatcher import Dispatcher class BackendDispatcher(Dispatcher): - sem_: threading.Semaphore = None + sem_: threading.Semaphore def __init__(self, config: DispatcherConfig) -> None: super().__init__(config) diff --git a/mlora/executor/dispatcher/dispatcher.py b/mlora/executor/dispatcher/dispatcher.py index dbd7ee54..798cacf8 100644 --- a/mlora/executor/dispatcher/dispatcher.py +++ b/mlora/executor/dispatcher/dispatcher.py @@ -1,14 +1,14 @@ +import math +from typing import Any, Callable, Dict, List, Tuple + from mlora.config.dispatcher import DispatcherConfig from mlora.config.task import TaskConfig -from mlora.executor.task import Task, TASK_CLASS -from mlora.model.args import Tokens, Masks, MLoRADataConfig, MLoRAData - -import math -from typing import List, Callable, Tuple, Dict +from mlora.executor.task import TASK_CLASS, Task +from mlora.model.args import Masks, MLoRAData, MLoRADataConfig, Tokens -class DispatcherEvent(): - callback_list_: List[Callable] = None +class DispatcherEvent: + callback_list_: List[Callable] def __init__(self): self.callback_list_ = [] @@ -22,17 +22,17 @@ def notify(self, task: Task) -> None: class Dispatcher: - name_: str = "" + name_: str - ready_: List[Task] = 
[] - running_: List[Task] = [] - done_: List[Task] = [] + ready_: List[Task] + running_: List[Task] + done_: List[Task] - init_event_: DispatcherEvent = DispatcherEvent() - running_event_: DispatcherEvent = DispatcherEvent() - ready_event_: DispatcherEvent = DispatcherEvent() - done_event_: DispatcherEvent = DispatcherEvent() - step_event_: DispatcherEvent = DispatcherEvent() + init_event_: DispatcherEvent + running_event_: DispatcherEvent + ready_event_: DispatcherEvent + done_event_: DispatcherEvent + step_event_: DispatcherEvent concurrency_num_: int = 2 @@ -40,11 +40,18 @@ def __init__(self, config: DispatcherConfig) -> None: self.name_ = config.name_ self.concurrency_num_ = config.concurrency_num_ - def info(self) -> Dict[str, str]: - return { - "name": self.name_, - "concurrency_num": self.concurrency_num_ - } + self.ready_ = [] + self.running_ = [] + self.done_ = [] + + self.init_event_ = DispatcherEvent() + self.running_event_ = DispatcherEvent() + self.ready_event_ = DispatcherEvent() + self.done_event_ = DispatcherEvent() + self.step_event_ = DispatcherEvent() + + def info(self) -> Dict[str, Any]: + return {"name": self.name_, "concurrency_num": self.concurrency_num_} def register_hook(self, name: str, cb: Callable) -> None: event_map = { @@ -52,7 +59,7 @@ def register_hook(self, name: str, cb: Callable) -> None: "running": self.running_event_, "ready": self.ready_event_, "done": self.done_event_, - "step": self.step_event_ + "step": self.step_event_, } assert name in event_map @@ -87,8 +94,9 @@ def _dispatch_task_out(self): for task in done_task: self.done_event_.notify(task) - def _align_batch_tokens(self, batch_tokens: List[Tokens], - configs: List[MLoRADataConfig]) -> Tuple[List[Tokens], List[Masks]]: + def _align_batch_tokens( + self, batch_tokens: List[Tokens], configs: List[MLoRADataConfig] + ) -> Tuple[List[Tokens], List[Masks]]: max_seq_len = max(map(lambda x: len(x), batch_tokens)) max_seq_len = math.ceil(max_seq_len / 8) * 8 @@ -98,7 +106,8 @@ def _align_batch_tokens(self, batch_tokens: List[Tokens], s_idx = data_config.batch_start_idx_ e_idx = data_config.batch_end_idx_ batch_tokens[s_idx:e_idx], masks = data_config.expand_fn_( - batch_tokens[s_idx:e_idx], max_seq_len) + batch_tokens[s_idx:e_idx], max_seq_len + ) batch_masks.extend(masks) return batch_tokens, batch_masks @@ -119,12 +128,11 @@ def data(self) -> MLoRAData: start_idx = start_idx + len(data) # post process this batch data - batch_tokens, batch_masks = self._align_batch_tokens( - batch_tokens, data_configs) + batch_tokens, batch_masks = self._align_batch_tokens(batch_tokens, data_configs) - return MLoRAData(batch_tokens=batch_tokens, - batch_mask=batch_masks, - data_config=data_configs) + return MLoRAData( + batch_tokens=batch_tokens, batch_mask=batch_masks, data_config=data_configs + ) def step(self): for _, task in enumerate(self.running_): diff --git a/mlora/executor/dispatcher/pipeline_dispatcher.py b/mlora/executor/dispatcher/pipeline_dispatcher.py deleted file mode 100644 index 91072ab0..00000000 --- a/mlora/executor/dispatcher/pipeline_dispatcher.py +++ /dev/null @@ -1,55 +0,0 @@ -from mlora.model.tokenizer import Tokenizer -from mlora.config import MLoRAConfig - -from typing import Dict - -from .dispatcher import Dispatcher - - -class PipelineDispatcher(Dispatcher): - _adapter_backward_cnt_: Dict[str, int] = {} - _adapter_forward_cnt_: Dict[str, int] = {} - _adapter_accumulation_step_: Dict[str, int] = {} - - def __init__(self, - config: MLoRAConfig, - tokenizer: Tokenizer) -> None: - 
super().__init__(config, tokenizer) - for lora_config in config.lora_configs_: - adapter_name = lora_config.adapter_name_ - accumulation_step = lora_config.batch_size_ / lora_config.micro_batch_size_ - self._adapter_forward_cnt_[adapter_name] = 0 - self._adapter_backward_cnt_[adapter_name] = 0 - self._adapter_accumulation_step_[adapter_name] = accumulation_step - - def update_backward_cnt(self, adapter_name: str): - self._adapter_backward_cnt_[adapter_name] += 1 - if self._adapter_backward_cnt_[adapter_name] == self._adapter_accumulation_step_[adapter_name]: - self._adapter_forward_cnt_[adapter_name] = 0 - self._adapter_backward_cnt_[adapter_name] = 0 - - def update_forward_cnt(self, adapter_name: str): - self._adapter_forward_cnt_[adapter_name] += 1 - - def __check_adapter_available(self, adapter_name: str) -> bool: - return self._adapter_forward_cnt_[adapter_name] < self._adapter_accumulation_step_[adapter_name] - - def rigister_strategies(self): - self.rigister_strategy("pipe", self.pipe_dispatch_strategy) - - def pipe_dispatch_strategy(self) -> Dict[str, any]: - ret_train_data = {} - cnt = 0 - for task in self.running_train_task_: - assert not task.is_train_done() - - # check the adapter is available - if not self.__check_adapter_available(task.adapter_name_): - continue - self.update_forward_cnt(task.adapter_name_) - ret_train_data[task.adapter_name_] = task.get_train_data() - cnt += 1 - if cnt >= self.train_lora_simultaneously_num_: - break - - return ret_train_data diff --git a/mlora/executor/executor.py b/mlora/executor/executor.py index 59dc810b..e2187172 100644 --- a/mlora/executor/executor.py +++ b/mlora/executor/executor.py @@ -1,33 +1,32 @@ +import logging +from typing import Callable, Dict, Optional + +import torch + from mlora.config import MLoRAConfig, TaskConfig +from mlora.model.args import MLoRAData from mlora.model.llm import LLMModel from mlora.model.tokenizer import Tokenizer -from mlora.model.args import MLoRAData -import torch -import logging -from typing import Dict, Callable - -from .dispatcher import Dispatcher, DISPATCHER_CLASS +from .dispatcher import DISPATCHER_CLASS, Dispatcher from .task import Task class Executor: - model_: LLMModel = None - tokenizer_: Tokenizer = None + model_: LLMModel + tokenizer_: Tokenizer - dispatcher_: Dispatcher = None + dispatcher_: Dispatcher - def __init__(self, - model: LLMModel, - tokenizer: Tokenizer, - config: MLoRAConfig) -> None: + def __init__( + self, model: LLMModel, tokenizer: Tokenizer, config: MLoRAConfig + ) -> None: self.model_ = model self.tokenizer_ = tokenizer dispatcher_name = config.dispatcher_.name_ assert dispatcher_name in DISPATCHER_CLASS - self.dispatcher_ = DISPATCHER_CLASS[dispatcher_name]( - config.dispatcher_) + self.dispatcher_ = DISPATCHER_CLASS[dispatcher_name](config.dispatcher_) hook_func = { "init": self.__task_init_hook, @@ -44,14 +43,15 @@ def register_hook(self, name: str, cb: Callable): def __task_init_hook(self, task: Task): logging.info( - f"Init {task.task_type()} : {task.task_name()} task with adapters: {task.adapter_name()}") + f"Init {task.task_type()} : {task.task_name()} " + + f"task with adapters: {task.adapter_name()}" + ) # init the task's dataset # init the task's adapter weight task.prepare(self.model_.linears_info(), self.tokenizer_) def __task_to_running_hook(self, task: Task): - logging.info( - f"Base model load adapters: {task.adapter_name()}") + logging.info(f"Base model load adapters: {task.adapter_name()}") # move the task's adapter weight to the gpu # move the task's 
optimizer weight to the gpu # attach the adapter to the model @@ -61,8 +61,7 @@ def __task_to_running_hook(self, task: Task): self.model_.load_adapter(adapter_model) def __task_to_ready_hook(self, task: Task): - logging.info( - f"Base model offload adapters: {task.adapter_name()}") + logging.info(f"Base model offload adapters: {task.adapter_name()}") # offload the adapter # move the task's adapter weight to the cpu for adapter_name in task.adapter_name(): @@ -70,8 +69,7 @@ def __task_to_ready_hook(self, task: Task): task.switch_device("cpu") def __task_to_done_hook(self, task: Task): - logging.info( - f"Finish and base model offload adapter - {task.adapter_name()}") + logging.info(f"Finish and base model offload adapter - {task.adapter_name()}") # offload the adapter # move the task's adapter weight to the cpu for adapter_name in task.adapter_name(): @@ -92,15 +90,15 @@ def execute(self) -> None: output = self.model_.forward(data.model_data()) labels = torch.tensor(data.batch_tokens_, dtype=torch.long) - total_loss = None + total_loss: Optional[torch.Tensor] = None for config in data.data_config_: - loss = config.loss_fn_( - output, labels, torch.tensor(data.batch_mask_)) + loss = config.loss_fn_(output, labels, torch.tensor(data.batch_mask_)) if loss is None: continue total_loss = loss if total_loss is None else total_loss + loss - total_loss.backward() + if total_loss is not None: + total_loss.backward() self.dispatcher_.step() diff --git a/mlora/executor/task/__init__.py b/mlora/executor/task/__init__.py index d1b7ddeb..2e842d2c 100644 --- a/mlora/executor/task/__init__.py +++ b/mlora/executor/task/__init__.py @@ -1,19 +1,8 @@ +from .cpo_task import CPOTask +from .dpo_task import DPOTask from .task import Task from .train_task import TrainTask -from .dpo_task import DPOTask -from .cpo_task import CPOTask - -TASK_CLASS = { - "train": TrainTask, - "dpo": DPOTask, - "cpo": CPOTask -} +TASK_CLASS = {"train": TrainTask, "dpo": DPOTask, "cpo": CPOTask} -__all__ = [ - "Task", - "TASK_CLASS", - "TrainTask", - "DPOTask", - "CPOTask" -] +__all__ = ["Task", "TASK_CLASS", "TrainTask", "DPOTask", "CPOTask"] diff --git a/mlora/executor/task/cpo_task.py b/mlora/executor/task/cpo_task.py index 520c1269..dbd08fd4 100644 --- a/mlora/executor/task/cpo_task.py +++ b/mlora/executor/task/cpo_task.py @@ -1,17 +1,21 @@ -from mlora.model.args import LinearInfo, Tokens, MLoRADataConfig -from mlora.model.tokenizer import Tokenizer - -import torch import logging -import torch.nn.functional as F from collections import OrderedDict from typing import List, Tuple, override +import torch +import torch.nn.functional as F + +from mlora.config import CPOTaskConfig +from mlora.executor.context import TrainLoRAContext +from mlora.model.args import LinearInfo, MLoRADataConfig, Tokens +from mlora.model.tokenizer import Tokenizer + from .train_task import TrainTask class CPOTask(TrainTask): - now_epoch_: int = 0 + context_: TrainLoRAContext + config_: CPOTaskConfig @override def prepare(self, linears_info: OrderedDict[str, LinearInfo], tokenizer: Tokenizer): @@ -22,7 +26,7 @@ def prepare(self, linears_info: OrderedDict[str, LinearInfo], tokenizer: Tokeniz LOSS_CLASS = { "sigmoid": self.__cpo_loss_sigmoid, - "hinge": self.__cpo_loss_hinge + "hinge": self.__cpo_loss_hinge, } self.context_.set_loss_fn(LOSS_CLASS[self.config_.loss_type_]) @@ -36,22 +40,30 @@ def __cpo_loss_hinge(self, logits: torch.Tensor) -> torch.Tensor: @override def data(self, start_idx: int) -> Tuple[List[Tokens], List[MLoRADataConfig]]: logging.info( - 
f'Adapter {self.context_.name_} epoch: { - self.now_epoch_}/{self.config_.num_epochs_}' - f' iteration: {self.now_data_idx_}/{len(self.data_)} step: {self.now_step_}') + f"Adapter {self.context_.name_} epoch: { + self.now_epoch_}/{self.config_.num_epochs_}" + f" iteration: {self.now_data_idx_}/{len(self.data_)} step: {self.now_step_}" + ) data_idx_s = self.now_data_idx_ data_idx_e = self.now_data_idx_ + self.config_.mini_batch_size_ # get the train raw string - batch_str = self.prompter_.generate_prompt( - self.data_[data_idx_s:data_idx_e]) + batch_str = self.prompter_.generate_prompt(self.data_[data_idx_s:data_idx_e]) # convert the string to tokens - ret_tokens = list(map(lambda raw_str: self.tokenizer_.encode( - raw_str, bos=True, eos=True, cutoff_len=self.config_.cutoff_len_), batch_str)) + ret_tokens = list( + map( + lambda raw_str: self.tokenizer_.encode( + raw_str, bos=True, eos=True, cutoff_len=self.config_.cutoff_len_ + ), + batch_str, + ) + ) end_idx = start_idx + len(ret_tokens) - def loss_fn(input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + def loss_fn( + input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor + ) -> torch.Tensor: mask = ~mask[start_idx:end_idx:, 1:] data_len = end_idx - start_idx @@ -59,21 +71,20 @@ def loss_fn(input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> to data_len = data_len // 2 batch_input = input[start_idx:end_idx, :-1, :].contiguous() - batch_label = target[start_idx:end_idx, - 1:].contiguous().to(input.device) + batch_label = target[start_idx:end_idx, 1:].contiguous().to(input.device) mask = mask.long().to(input.device) # step1. calc the chose loss vacab_size = input.shape[-1] chose_input = batch_input[:data_len].view(-1, vacab_size) chose_label = batch_label[:data_len].view(-1) - loss_chosen: torch.Tensor = F.cross_entropy( - chose_input, chose_label) + loss_chosen: torch.Tensor = F.cross_entropy(chose_input, chose_label) # step2. 
calc the prefer loss logits = batch_input.log_softmax(-1) per_token_logps = torch.gather( - logits, dim=2, index=batch_label.unsqueeze(2)).squeeze(2) + logits, dim=2, index=batch_label.unsqueeze(2) + ).squeeze(2) logps = (per_token_logps * mask).sum(-1) @@ -87,8 +98,13 @@ def loss_fn(input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> to logging.info(f"Adapter {self.context_.name_} loss: {loss}") return loss - data_config = MLoRADataConfig(self.context_.name_, self.context_.type_, - start_idx, end_idx, - self._expand_batch_tokens, loss_fn) + data_config = MLoRADataConfig( + self.context_.name_, + self.context_.type_, + start_idx, + end_idx, + self._expand_batch_tokens, + loss_fn, + ) return ret_tokens, [data_config] diff --git a/mlora/executor/task/dpo_task.py b/mlora/executor/task/dpo_task.py index 07af4e71..30e190b3 100644 --- a/mlora/executor/task/dpo_task.py +++ b/mlora/executor/task/dpo_task.py @@ -1,19 +1,23 @@ -from mlora.model.tokenizer import Tokenizer -from mlora.model.modules import AdapterModel -from mlora.model.args import Tokens, MLoRADataConfig, LinearInfo -from mlora.executor.context import TaskContext, INFERENCECONTEXT_CLASS - import copy -import torch import logging +from typing import List, Optional, OrderedDict, Tuple, override + +import torch import torch.nn.functional as F -from typing import List, Tuple, OrderedDict, override + +from mlora.config import DPOTaskConfig +from mlora.executor.context import INFERENCECONTEXT_CLASS, TaskContext, TrainTaskContext +from mlora.model.args import LinearInfo, MLoRADataConfig, Tokens +from mlora.model.modules import AdapterModel +from mlora.model.tokenizer import Tokenizer from .train_task import TrainTask class DPOTask(TrainTask): - ref_context_: TaskContext = None + config_: DPOTaskConfig + context_: TrainTaskContext + ref_context_: Optional[TaskContext] @override def prepare(self, linears_info: OrderedDict[str, LinearInfo], tokenizer: Tokenizer): @@ -23,10 +27,7 @@ def prepare(self, linears_info: OrderedDict[str, LinearInfo], tokenizer: Tokeniz self._pre_context(linears_info) self._pre_ref_context(linears_info) - LOSS_CLASS = { - "sigmoid": self.__dpo_loss_sigmoid, - "ipo": self.__dpo_loss_ipo - } + LOSS_CLASS = {"sigmoid": self.__dpo_loss_sigmoid, "ipo": self.__dpo_loss_ipo} self.context_.set_loss_fn(LOSS_CLASS[self.config_.loss_type_]) @@ -37,13 +38,15 @@ def _pre_ref_context(self, linears_info: OrderedDict[str, LinearInfo]): ref_adapter_type = self.config_.reference_.type_ self.ref_context_ = INFERENCECONTEXT_CLASS[ref_adapter_type]( - self.config_.reference_, linears_info) + self.config_.reference_, linears_info + ) def __dpo_loss_sigmoid(self, logits: torch.Tensor) -> torch.Tensor: - loss = -F.logsigmoid(self.config_.beta_ * logits) * \ - (1 - self.config_.label_smoothing_) - \ - F.logsigmoid(-self.config_.beta_ * logits) * \ - self.config_.label_smoothing_ + loss = ( + -F.logsigmoid(self.config_.beta_ * logits) + * (1 - self.config_.label_smoothing_) + - F.logsigmoid(-self.config_.beta_ * logits) * self.config_.label_smoothing_ + ) return loss def __dpo_loss_ipo(self, logits: torch.Tensor) -> torch.Tensor: @@ -59,7 +62,7 @@ def adapter_model(self) -> List[AdapterModel]: @override def adapter_name(self) -> List[str]: - if self.ref_context_ is None: + if self.config_.reference_ is None: return super().adapter_name() return [self.config_.adapter_.name_, self.config_.reference_.name_] @@ -74,22 +77,29 @@ def switch_device(self, device: str): @override def data(self, start_idx: int) -> Tuple[List[Tokens], 
List[MLoRADataConfig]]: logging.info( - f'Task - {self.context_.name_} epoch: {self.now_epoch_}/{self.config_.num_epochs_}' - f' iteration: {self.now_data_idx_}/{len(self.data_)} step: {self.now_step_}') + f"Task - {self.context_.name_} " + f"epoch: {self.now_epoch_}/{self.config_.num_epochs_}" + f" iteration: {self.now_data_idx_}/{len(self.data_)} step: {self.now_step_}" + ) data_idx_s = self.now_data_idx_ data_idx_e = self.now_data_idx_ + self.config_.mini_batch_size_ # 0...mid is chosen data # mid.end is reject data - batch_str = self.prompter_.generate_prompt( - self.data_[data_idx_s:data_idx_e]) + batch_str = self.prompter_.generate_prompt(self.data_[data_idx_s:data_idx_e]) assert len(batch_str) % 2 == 0 ret_tokens = [] # for refrerence - ref_model_token = list(map(lambda raw_str: self.tokenizer_.encode( - raw_str, bos=True, eos=True, cutoff_len=self.config_.cutoff_len_), batch_str)) + ref_model_token = list( + map( + lambda raw_str: self.tokenizer_.encode( + raw_str, bos=True, eos=True, cutoff_len=self.config_.cutoff_len_ + ), + batch_str, + ) + ) policy_model_token = copy.deepcopy(ref_model_token) ret_tokens.extend(ref_model_token) @@ -104,17 +114,18 @@ def data(self, start_idx: int) -> Tuple[List[Tokens], List[MLoRADataConfig]]: policy_start_idx = ref_end_idx policy_end_idx = policy_start_idx + len(policy_model_token) - def loss_fn(input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + def loss_fn( + input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor + ) -> torch.Tensor: mask = ~mask[ref_start_idx:policy_end_idx, 1:] - logits = input[ref_start_idx:policy_end_idx, - :-1, :].log_softmax(-1) - labels = target[ref_start_idx:policy_end_idx, 1:].to( - input.device) + logits = input[ref_start_idx:policy_end_idx, :-1, :].log_softmax(-1) + labels = target[ref_start_idx:policy_end_idx, 1:].to(input.device) mask = mask.long().to(input.device) per_token_logps = torch.gather( - logits, dim=2, index=labels.unsqueeze(2)).squeeze(2) + logits, dim=2, index=labels.unsqueeze(2) + ).squeeze(2) logps = (per_token_logps * mask).sum(-1) @@ -123,15 +134,14 @@ def loss_fn(input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> to data_len = data_len // 4 ref_chosen_logps, ref_reject_logps, pi_chosen_logps, pi_reject_logps = [ - logps[i * data_len:(i + 1) * data_len] for i in range(4)] + logps[i * data_len : (i + 1) * data_len] for i in range(4) + ] pi = pi_chosen_logps - pi_reject_logps ri = ref_chosen_logps - ref_reject_logps - chosen_reward = (pi_chosen_logps - - ref_chosen_logps) * self.config_.beta_ - reject_reward = (pi_reject_logps - - ref_reject_logps) * self.config_.beta_ + chosen_reward = (pi_chosen_logps - ref_chosen_logps) * self.config_.beta_ + reject_reward = (pi_reject_logps - ref_reject_logps) * self.config_.beta_ logits = pi - ri @@ -140,7 +150,9 @@ def loss_fn(input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> to logging.info( f"Task - {self.context_.name_} loss: {loss}, " - f"chosen_rewards: {chosen_reward.mean()}, rejected_rewards: {reject_reward.mean()}") + f"chosen_rewards: {chosen_reward.mean()}, " + f"rejected_rewards: {reject_reward.mean()}" + ) return loss @@ -151,13 +163,23 @@ def loss_fn(input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> to ref_model_name = self.ref_context_.name_ ref_model_type = self.ref_context_.type_ - ref_model_config = MLoRADataConfig(ref_model_name, ref_model_type, - ref_start_idx, ref_end_idx, - self._expand_batch_tokens, lambda *_: None) - - policy_model_config = 
MLoRADataConfig(self.context_.name_, self.context_.type_, - policy_start_idx, policy_end_idx, - self._expand_batch_tokens, loss_fn) + ref_model_config = MLoRADataConfig( + ref_model_name, + ref_model_type, + ref_start_idx, + ref_end_idx, + self._expand_batch_tokens, + lambda *_: None, + ) + + policy_model_config = MLoRADataConfig( + self.context_.name_, + self.context_.type_, + policy_start_idx, + policy_end_idx, + self._expand_batch_tokens, + loss_fn, + ) return ret_tokens, [ref_model_config, policy_model_config] diff --git a/mlora/executor/task/task.py b/mlora/executor/task/task.py index 1223102a..fd8573af 100644 --- a/mlora/executor/task/task.py +++ b/mlora/executor/task/task.py @@ -1,41 +1,39 @@ +import logging +from abc import abstractmethod +from collections import OrderedDict +from typing import Callable, Dict, List, Optional, Tuple + +from datasets import load_dataset +from tqdm import tqdm + from mlora.config import TaskConfig -from mlora.prompter import Prompter, PrompterFactory +from mlora.executor.context import TRAINCONTEXT_CLASS, TaskContext +from mlora.model.args import LinearInfo, Masks, MLoRADataConfig, Tokens from mlora.model.modules import AdapterModel -from mlora.model.args import LinearInfo, Tokens, Masks, MLoRADataConfig from mlora.model.tokenizer import Tokenizer -from mlora.executor.context import TaskContext, TRAINCONTEXT_CLASS - -import logging -from tqdm import tqdm -from datasets import load_dataset -from collections import OrderedDict -from typing import Dict, Callable, List, Optional, Tuple -from abc import abstractmethod +from mlora.prompter import Prompter, PrompterFactory class Task: config_: TaskConfig - now_step_: int = 0 + now_step_: int - tokenizer_: Tokenizer = None - context_: TaskContext = None + tokenizer_: Tokenizer + context_: TaskContext - data_: List[Dict[str, str]] = None - now_data_idx_: int = 0 + data_: List[Dict[str, str]] + now_data_idx_: int - prompter_: Prompter = None + prompter_: Prompter - llm_name_: str = "" + llm_name_: str def __init__(self, config: TaskConfig, llm_name: str) -> None: self.config_ = config self.now_step_ = 1 - self.tokenizer_ = None - self.context_ = None - self.data_ = [] self.now_data_idx_ = 0 @@ -48,31 +46,31 @@ def prepare(self, linears_info: OrderedDict[str, LinearInfo], tokenizer: Tokeniz ... @abstractmethod - def done(self): - ... + def done(self): ... @abstractmethod - def step(self): - ... + def step(self): ... @abstractmethod - def is_done(self) -> bool: - ... + def is_done(self) -> bool: ... @abstractmethod - def data(self) -> Tuple[List[Tokens], List[MLoRADataConfig]]: - ... + def data(self, start_idx: int) -> Tuple[List[Tokens], List[MLoRADataConfig]]: ... + + @abstractmethod + def task_progress(self) -> int: ... 
def _pre_dataset(self): preprocess_func: Dict[str, Callable] = { "default": lambda data: data, "shuffle": lambda data: data.shuffle(), - "sort": lambda data: data.sort() + "sort": lambda data: data.sort(), } logging.info(f"Task load data from {self.config_.dataset_.data_path_}") - data = load_dataset("json", - data_files={"data_points": self.config_.dataset_.data_path_}) + data = load_dataset( + "json", data_files={"data_points": self.config_.dataset_.data_path_} + ) preprocess_type = self.config_.dataset_.preprocess_ if preprocess_type not in preprocess_func: @@ -80,9 +78,9 @@ def _pre_dataset(self): data = preprocess_func[preprocess_type](data) logging.info( - f'Adapter {self.config_.adapter_.name_} data size: { - len(data["data_points"])} ' - f'epoch: {self.config_.num_epochs_} batch size: {self.config_.batch_size_} / {self.config_.mini_batch_size_}') + f"Adapter {self.config_.adapter_.name_} " + f"data size: {len(data["data_points"])}" + ) for _, data_point in tqdm(enumerate(data["data_points"])): self.data_.append(data_point) @@ -91,11 +89,12 @@ def _pre_context(self, linears_info: OrderedDict[str, LinearInfo]): adapter_type = self.config_.adapter_.type_ assert adapter_type in TRAINCONTEXT_CLASS self.context_ = TRAINCONTEXT_CLASS[adapter_type]( - self.config_.adapter_, linears_info) + self.config_.adapter_, linears_info + ) - def _expand_batch_tokens(self, - batch_tokens: List[Tokens], - align_len: Optional[int] = None) -> Tuple[List[Tokens], List[Masks]]: + def _expand_batch_tokens( + self, batch_tokens: List[Tokens], align_len: Optional[int] = None + ) -> Tuple[List[Tokens], List[Masks]]: if align_len is None: align_len = max(map(lambda x: len(x), batch_tokens)) @@ -120,10 +119,5 @@ def task_type(self) -> str: def task_name(self) -> str: return self.config_.name_ - def task_progress(self) -> int: - total_step = len(self.data_) // self.config_.mini_batch_size_ - total_step = total_step * self.config_.num_epochs_ - return int((self.now_step_ / total_step) * 100) - def switch_device(self, device: str): self.context_.switch_device(device) diff --git a/mlora/executor/task/train_task.py b/mlora/executor/task/train_task.py index c9f4ac88..03b8e176 100644 --- a/mlora/executor/task/train_task.py +++ b/mlora/executor/task/train_task.py @@ -1,19 +1,24 @@ -from mlora.config import TaskConfig -from mlora.model.args import LinearInfo, Tokens, Masks, MLoRADataConfig -from mlora.model.tokenizer import Tokenizer - -import os import json -import torch import logging +import os from collections import OrderedDict from typing import Dict, List, Optional, Tuple, override +import torch + +from mlora.config import TaskConfig, TrainTaskConfig +from mlora.executor.context import TrainTaskContext +from mlora.model.args import LinearInfo, Masks, MLoRADataConfig, Tokens +from mlora.model.tokenizer import Tokenizer + from .task import Task class TrainTask(Task): - now_epoch_: int = 0 + now_epoch_: int + + context_: TrainTaskContext + config_: TrainTaskConfig def __init__(self, config: TaskConfig, llm_name: str) -> None: super().__init__(config, llm_name) @@ -26,49 +31,67 @@ def is_done(self) -> bool: @override def prepare(self, linears_info: OrderedDict[str, LinearInfo], tokenizer: Tokenizer): self.tokenizer_ = tokenizer - # prepare the dataset and context + # prepare the context and the dataset self._pre_dataset() self._pre_context(linears_info) @override def data(self, start_idx: int) -> Tuple[List[Tokens], List[MLoRADataConfig]]: logging.info( - f'Adapter {self.context_.name_} epoch: { - 
self.now_epoch_}/{self.config_.num_epochs_}' - f' iteration: {self.now_data_idx_}/{len(self.data_)} step: {self.now_step_}') + f"Adapter {self.context_.name_} epoch: { + self.now_epoch_}/{self.config_.num_epochs_}" + f" iteration: {self.now_data_idx_}/{len(self.data_)} step: {self.now_step_}" + ) data_idx_s = self.now_data_idx_ data_idx_e = self.now_data_idx_ + self.config_.mini_batch_size_ # get the train raw string - batch_str = self.prompter_.generate_prompt( - self.data_[data_idx_s:data_idx_e]) + batch_str = self.prompter_.generate_prompt(self.data_[data_idx_s:data_idx_e]) # convert the string to tokens - ret_tokens = list(map(lambda raw_str: self.tokenizer_.encode( - raw_str, bos=True, eos=True, cutoff_len=self.config_.cutoff_len_), batch_str)) + ret_tokens = list( + map( + lambda raw_str: self.tokenizer_.encode( + raw_str, bos=True, eos=True, cutoff_len=self.config_.cutoff_len_ + ), + batch_str, + ) + ) end_idx = start_idx + len(ret_tokens) - def loss_fn(input: torch.Tensor, target: torch.Tensor, _: torch.Tensor) -> torch.Tensor: + def loss_fn( + input: torch.Tensor, target: torch.Tensor, _: torch.Tensor + ) -> torch.Tensor: vacab_size = input.shape[-1] - loss_input = input[start_idx:end_idx, :-1, - :].contiguous().view(-1, vacab_size) - loss_target = target[start_idx:end_idx, - 1:].contiguous().view(-1).to(loss_input.device) + loss_input = ( + input[start_idx:end_idx, :-1, :].contiguous().view(-1, vacab_size) + ) + loss_target = ( + target[start_idx:end_idx, 1:] + .contiguous() + .view(-1) + .to(loss_input.device) + ) loss = self.context_.loss_fn_(loss_input, loss_target) logging.info(f"Adapter {self.context_.name_} loss: {loss}") return loss - data_config = MLoRADataConfig(self.context_.name_, self.context_.type_, - start_idx, end_idx, - self._expand_batch_tokens, loss_fn) + data_config = MLoRADataConfig( + self.context_.name_, + self.context_.type_, + start_idx, + end_idx, + self._expand_batch_tokens, + loss_fn, + ) return ret_tokens, [data_config] - def _expand_batch_tokens(self, - batch_tokens: List[Tokens], - align_len: Optional[int] = None) -> Tuple[List[Tokens], List[Masks]]: + def _expand_batch_tokens( + self, batch_tokens: List[Tokens], align_len: Optional[int] = None + ) -> Tuple[List[Tokens], List[Masks]]: if align_len is None: align_len = max(map(lambda x: len(x), batch_tokens)) @@ -89,8 +112,9 @@ def _save(self, dir_suffix: str = "", additional_info: Dict[str, str] = {}): if not os.path.exists(output_dir): os.makedirs(output_dir) - torch.save(self.context_.weight_dict(), - output_dir + os.sep + "adapter_model.bin") + torch.save( + self.context_.weight_dict(), output_dir + os.sep + "adapter_model.bin" + ) adapter_config: Dict[str, str] = {} adapter_config["base_model_name_or_path"] = self.llm_name_ @@ -128,3 +152,9 @@ def step(self): # task finish we also need to step if not stepd and self.now_epoch_ >= self.config_.num_epochs_: self.context_.step() + + @override + def task_progress(self) -> int: + total_step = len(self.data_) // self.config_.mini_batch_size_ + total_step = total_step * self.config_.num_epochs_ + return int((self.now_step_ / total_step) * 100) diff --git a/mlora/model/args.py b/mlora/model/args.py index 63d69e0f..4b4f937d 100644 --- a/mlora/model/args.py +++ b/mlora/model/args.py @@ -1,7 +1,8 @@ -import torch import logging -from typing import List, Callable, Optional, Tuple from dataclasses import dataclass +from typing import Callable, List, Optional, Tuple + +import torch from transformers import PretrainedConfig Tokens = List[int] @@ -10,20 +11,20 @@ 
@dataclass class LLMModelArgs: - name_or_path_: str = "" - dim_: int = 4096 - multiple_of_: int = 256 - n_heads_: int = 32 - n_kv_heads_: int = 32 - n_layers_: int = 32 - rope_theta_: float = 10000.0 - norm_eps_: float = 1e-06 - hidden_dropout_: float = 0.0 - vocab_size_: int = -1 - pad_token_id_: int = -1 - max_seq_len_: int = 4096 - device_: str = "" - dtype_: torch.dtype = None + name_or_path_: str + dim_: int + multiple_of_: int + n_heads_: int + n_kv_heads_: int + n_layers_: int + rope_theta_: float + norm_eps_: float + hidden_dropout_: float + vocab_size_: int + pad_token_id_: int + max_seq_len_: int + device_: str + dtype_: torch.dtype def __init__(self, config: PretrainedConfig): self.__from_pretrained_config(config) @@ -31,22 +32,35 @@ def __init__(self, config: PretrainedConfig): def __from_pretrained_config(self, config: PretrainedConfig): self.name_or_path_ = config.name_or_path self.dim_ = config.hidden_size + self.multiple_of_ = 256 self.n_heads_ = config.num_attention_heads if hasattr(config, "num_key_value_heads"): self.n_kv_heads_ = config.num_key_value_heads self.n_layers_ = config.num_hidden_layers + self.rope_theta_ = 10000.0 self.norm_eps_ = config.rms_norm_eps + self.hidden_dropout_ = 0.0 self.vocab_size_ = config.vocab_size self.pad_token_id_ = config.pad_token_id + self.max_seq_len_ = 4096 if hasattr(config, "max_sequence_length"): self.max_seq_len_ = config.max_sequence_length - if hasattr(config, "sliding_window") and self.max_seq_len_ > config.sliding_window: + + if ( + hasattr(config, "sliding_window") + and self.max_seq_len_ > config.sliding_window + ): logging.warning( - "Shrink max sequence length to window size of sliding window attention.") + "Shrink max sequence length to window size of sliding window attention." + ) self.max_seq_len_ = config.sliding_window + if hasattr(config, "rope_theta"): self.rope_theta_ = config.rope_theta + self.device_ = "" + self.dtype_ = torch.float32 + @dataclass class LinearInfo: @@ -57,37 +71,45 @@ class LinearInfo: @dataclass class ModelDataConfig: - adapter_name_: str = "" - adapter_type_: str = "" + adapter_name_: str + adapter_type_: str - batch_start_idx_: int = -1 - batch_end_idx_: int = -1 + batch_start_idx_: int + batch_end_idx_: int @dataclass class ModelData: - batch_tokens_: List[Tokens] = None - batch_mask_: List[Masks] = None - data_config_: List[ModelDataConfig] = None + batch_tokens_: List[Tokens] + batch_mask_: List[Masks] + data_config_: List[ModelDataConfig] - enable_checkpoint_: bool = True + enable_checkpoint_: bool class MLoRADataConfig: - adapter_name_: str = "" - adapter_type_: str = "" - - batch_start_idx_: int = -1 - batch_end_idx_: int = -1 - - expand_fn_: Callable[[List[Tokens], Optional[int]], - Tuple[List[Tokens], List[Masks]]] = None - loss_fn_: Callable[[torch.Tensor, torch.Tensor, - torch.Tensor], torch.Tensor] = None - - def __init__(self, adapter_name: str, adapter_type: str, - start_idx: int, end_idx: int, - expand_fn: Callable, loss_fn: Callable) -> None: + adapter_name_: str + adapter_type_: str + + batch_start_idx_: int + batch_end_idx_: int + + expand_fn_: Callable[ + [List[Tokens], Optional[int]], Tuple[List[Tokens], List[Masks]] + ] + loss_fn_: Callable[ + [torch.Tensor, torch.Tensor, torch.Tensor], Optional[torch.Tensor] + ] + + def __init__( + self, + adapter_name: str, + adapter_type: str, + start_idx: int, + end_idx: int, + expand_fn: Callable, + loss_fn: Callable, + ) -> None: self.adapter_name_ = adapter_name self.adapter_type_ = adapter_type self.batch_start_idx_ = start_idx @@ 
-97,28 +119,34 @@ def __init__(self, adapter_name: str, adapter_type: str, self.loss_fn_ = loss_fn def model_data_config(self) -> ModelDataConfig: - return ModelDataConfig(adapter_name_=self.adapter_name_, - adapter_type_=self.adapter_type_, - batch_start_idx_=self.batch_start_idx_, - batch_end_idx_=self.batch_end_idx_) + return ModelDataConfig( + adapter_name_=self.adapter_name_, + adapter_type_=self.adapter_type_, + batch_start_idx_=self.batch_start_idx_, + batch_end_idx_=self.batch_end_idx_, + ) class MLoRAData: # the datas: batch_size * token - batch_tokens_: List[Tokens] = None - batch_mask_: List[Masks] = None - data_config_: List[MLoRADataConfig] = None - - def __init__(self, - batch_tokens: List[Tokens], - batch_mask: List[Masks], - data_config: List[MLoRADataConfig]) -> None: + batch_tokens_: List[Tokens] + batch_mask_: List[Masks] + data_config_: List[MLoRADataConfig] + + def __init__( + self, + batch_tokens: List[Tokens], + batch_mask: List[Masks], + data_config: List[MLoRADataConfig], + ) -> None: self.batch_tokens_ = batch_tokens self.batch_mask_ = batch_mask self.data_config_ = data_config def model_data(self) -> ModelData: - return ModelData(batch_tokens_=self.batch_tokens_, - batch_mask_=self.batch_mask_, - data_config_=[config.model_data_config() for config in self.data_config_], - enable_checkpoint_=True) + return ModelData( + batch_tokens_=self.batch_tokens_, + batch_mask_=self.batch_mask_, + data_config_=[config.model_data_config() for config in self.data_config_], + enable_checkpoint_=True, + ) diff --git a/mlora/model/checkpoint/checkpoint.py b/mlora/model/checkpoint/checkpoint.py index c17497da..cae567a7 100644 --- a/mlora/model/checkpoint/checkpoint.py +++ b/mlora/model/checkpoint/checkpoint.py @@ -1,4 +1,4 @@ -from typing import Tuple, Callable +from typing import Callable, Tuple import torch diff --git a/mlora/model/llm/__init__.py b/mlora/model/llm/__init__.py index 4d136ba6..37208681 100644 --- a/mlora/model/llm/__init__.py +++ b/mlora/model/llm/__init__.py @@ -1,7 +1,4 @@ -from .model_llm import LLMModel from .model_llama import LlamaModel +from .model_llm import LLMModel -__all__ = [ - "LLMModel", - "LlamaModel" -] +__all__ = ["LLMModel", "LlamaModel"] diff --git a/mlora/model/llm/model_llama.py b/mlora/model/llm/model_llama.py index 667ff3bc..8cdb5023 100644 --- a/mlora/model/llm/model_llama.py +++ b/mlora/model/llm/model_llama.py @@ -1,13 +1,14 @@ -from mlora.model.modules import Embedding, Decoder, RMSNorm, OutputLayer, AdapterModel -from mlora.model.checkpoint import CheckpointRecomputeFunction -from mlora.model.args import LLMModelArgs, Masks, LinearInfo, ModelData -from mlora.profiler import nvtx_wrapper, set_backward_tracepoint +import logging +from collections import OrderedDict +from typing import Dict, List, Optional, Tuple, override import torch -import logging from transformers import AutoConfig, AutoModelForCausalLM -from collections import OrderedDict -from typing import Tuple, List, override + +from mlora.model.args import LinearInfo, LLMModelArgs, Masks, ModelData +from mlora.model.checkpoint import CheckpointRecomputeFunction +from mlora.model.modules import AdapterModel, Decoder, Embedding, OutputLayer, RMSNorm +from mlora.profiler import nvtx_wrapper, set_backward_tracepoint from .model_llm import LLMModel @@ -23,12 +24,14 @@ # -inf -inf -inf # -inf 0 -inf # -inf 0 0 -def precompute_mask(input_tokens: torch.Tensor, - n_heads: int, - device: str, - additional_mask: List[Masks] = None, - diagonal: int = 1, - dtype: torch.dtype = torch.float32) 
-> torch.Tensor: +def precompute_mask( + input_tokens: torch.Tensor, + n_heads: int, + device: str, + additional_mask: List[Masks] | None = None, + diagonal: int = 1, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: if input_tokens.dim() == 2: batch_size, seq_len = input_tokens.shape elif input_tokens.dim() == 3: @@ -37,13 +40,16 @@ def precompute_mask(input_tokens: torch.Tensor, raise Exception("input dim is not correct {input_tokens.dim}") TORCH_MIN_VALUE = torch.finfo(dtype).min - mask = torch.full((batch_size, n_heads, seq_len, seq_len), - TORCH_MIN_VALUE, device=device, dtype=dtype) + mask = torch.full( + (batch_size, n_heads, seq_len, seq_len), + TORCH_MIN_VALUE, + device=device, + dtype=dtype, + ) mask = torch.triu(mask, diagonal=diagonal) if additional_mask is not None: - masks_metric = torch.tensor( - additional_mask, dtype=torch.bool, device=device) + masks_metric = torch.tensor(additional_mask, dtype=torch.bool, device=device) masks_metric = masks_metric.view(batch_size, 1, 1, seq_len) masks_metric = masks_metric.expand(-1, n_heads, seq_len, -1) mask = torch.masked_fill(mask, masks_metric, TORCH_MIN_VALUE) @@ -53,11 +59,12 @@ def precompute_mask(input_tokens: torch.Tensor, return mask.to(device=device, dtype=dtype) -LlamaSequentialModuleIO = Tuple[torch.Tensor, # the input batch tokens - torch.Tensor, # the mask matrics - ModelData, # batch data config - bool # whether to use checkpoint - ] +LlamaSequentialModuleIO = Tuple[ + torch.Tensor, # the input batch tokens + torch.Tensor, # the mask matrics + ModelData, # batch data config + bool, # whether to use checkpoint +] LEN_LLAMA_SEQUENTIAL_MODULE_IO = 4 LlamaCompatibleModelTypes = ["mistral", "qwen2", "llama"] @@ -72,7 +79,6 @@ def name(self) -> str: return type(self.wrapper_module_).__name__ def forward(self, input: LlamaSequentialModuleIO) -> LlamaSequentialModuleIO: - assert isinstance(input, Tuple) assert len(input) == LEN_LLAMA_SEQUENTIAL_MODULE_IO assert isinstance(input[0], torch.Tensor) assert isinstance(input[1], torch.Tensor) @@ -85,28 +91,29 @@ def embedding_forward(): output = self.wrapper_module_.forward(input[0]) if input[-1]: output = output.requires_grad_(True) - return (output, ) + input[1:] + return (output,) + input[1:] def decoder_forward(): if input[-1]: output = CheckpointRecomputeFunction( - self.wrapper_module_.forward, *input[:-1]) + self.wrapper_module_.forward, *input[:-1] + ) set_backward_tracepoint(output.grad_fn, "b_checkpoint") else: output = self.wrapper_module_.forward(*input[:-1]) - return (output, ) + input[1:] + return (output,) + input[1:] @nvtx_wrapper("f_rmsnorm") def rmsnorm_forward(): output = self.wrapper_module_.forward(input[0]) set_backward_tracepoint(output.grad_fn, "b_rmsnorm") - return (output, ) + input[1:] + return (output,) + input[1:] @nvtx_wrapper("f_output") def output_layer_forward(): output = self.wrapper_module_.forward(input[0]) set_backward_tracepoint(output.grad_fn, "b_output") - return (output, ) + input[1:] + return (output,) + input[1:] forward_func_dict = { "Embedding": embedding_forward, @@ -116,17 +123,20 @@ def output_layer_forward(): } module_name = self.name() - assert module_name in forward_func_dict, f"error module name { + assert ( + module_name in forward_func_dict + ), f"error module name { module_name}" return forward_func_dict[module_name]() class LlamaModel(LLMModel): + seq_module_: torch.nn.Sequential + def __init__(self, args: LLMModelArgs): self.name_or_path_: str = args.name_or_path_ # sequential model - self.seq_module_: 
torch.nn.Sequential = None self.norm_eps_ = args.norm_eps_ @@ -142,12 +152,11 @@ def __init__(self, args: LLMModelArgs): @override def forward(self, input: ModelData) -> torch.Tensor: # train model or inference model: output is probs - tokens = torch.tensor(input.batch_tokens_, - dtype=torch.int64, - device=self.device_) + tokens = torch.tensor( + input.batch_tokens_, dtype=torch.int64, device=self.device_ + ) - mask = precompute_mask(tokens, self.n_heads_, - self.device_, input.batch_mask_) + mask = precompute_mask(tokens, self.n_heads_, self.device_, input.batch_mask_) if input.enable_checkpoint_: data = (tokens, mask, input, True) @@ -159,22 +168,31 @@ def forward(self, input: ModelData) -> torch.Tensor: return data[0] + @override @staticmethod - def from_pretrained(path: str, - device: str, - precision: str, - partial_model_to_device: List[int] = None) -> LLMModel: + def from_pretrained( + path: str, + device: str, + precision: str, + partial_model_to_device: Optional[List[int]] = None, + ) -> LLMModel: # create the device map for parallelism - def create_device_map(): + def create_device_map() -> str | Dict[str, str]: + device_map: str | Dict[str, str] if partial_model_to_device is None: device_map = device else: config = AutoConfig.from_pretrained(path) # Be careful, this is hard coded. - weight_map = ["model.embed_tokens", - *[f"model.layers.{layer_id}" for layer_id in range(0, config.num_hidden_layers)], - "model.norm", - "lm_head"] + weight_map = [ + "model.embed_tokens", + *[ + f"model.layers.{layer_id}" + for layer_id in range(0, config.num_hidden_layers) + ], + "model.norm", + "lm_head", + ] device_map = {map_item: "disk" for map_item in weight_map} for partial_weight in partial_model_to_device: device_map[weight_map[partial_weight]] = device @@ -184,7 +202,7 @@ def create_device_map(): load_type_dict = { "fp32": torch.float32, "fp16": torch.float16, - "bf16": torch.bfloat16 + "bf16": torch.bfloat16, } additional_load_args = { @@ -201,13 +219,15 @@ def create_device_map(): load_8bit = precision == "int8" from transformers import BitsAndBytesConfig + additional_load_args["torch_dtype"] = torch.float32 additional_load_args["quantization_config"] = BitsAndBytesConfig( load_in_4bit=load_4bit, load_in_8bit=load_8bit, # int8 only for GPU, fp32 for cpu llm_int8_enable_fp32_cpu_offload=True, - # do not hold the fp16 part, when forward and backward need to convert int8 to fp16 + # do not hold the fp16 part + # when forward and backward need to convert int8 to fp16 llm_int8_has_fp16_weight=False, # only for qlora 4bit bnb_4bit_compute_dtype=torch.float16, @@ -215,15 +235,15 @@ def create_device_map(): bnb_4bit_quant_type=precision, ) - llama_model = AutoModelForCausalLM.from_pretrained( - path, **additional_load_args) + llama_model = AutoModelForCausalLM.from_pretrained(path, **additional_load_args) if llama_model.config.model_type not in LlamaCompatibleModelTypes: assert f"unsupported model type { llama_model.config.model_type}, loading with llama compatible mode." 
logging.info( - f"loading llama compatible model - {llama_model.config.model_type}") + f"loading llama compatible model - {llama_model.config.model_type}" + ) llama_args = LLMModelArgs(llama_model.config) if llama_args.pad_token_id_ is None: @@ -232,31 +252,47 @@ def create_device_map(): llama_args.dtype_ = llama_model.dtype # load model from pretrained large model - model = LlamaModel.convert_model_from_huggingface( - llama_model, llama_args) + model = LlamaModel.convert_model_from_huggingface(llama_model, llama_args) return model @staticmethod - def convert_model_from_huggingface(llama_model: AutoModelForCausalLM, - llama_args: LLMModelArgs): + def convert_model_from_huggingface( + llama_model: AutoModelForCausalLM, llama_args: LLMModelArgs + ): llama_model.requires_grad_(False) - seq_model = OrderedDict() + seq_model: OrderedDict[str, torch.nn.Module] = OrderedDict() - seq_model.update({"embedding": LlamaSequentialWrapper(Embedding( - llama_model.model.embed_tokens.weight, llama_args.pad_token_id_))}) + seq_model.update( + { + "embedding": LlamaSequentialWrapper( + Embedding( + llama_model.model.embed_tokens.weight, llama_args.pad_token_id_ + ) + ) + } + ) for idx, target_layer in enumerate(llama_model.model.layers): decoder = Decoder(idx, llama_args) decoder.from_pretrained(target_layer, llama_args.norm_eps_) - seq_model.update( - {f"layer{idx}": LlamaSequentialWrapper(decoder)}) - - seq_model.update({"norm": LlamaSequentialWrapper(RMSNorm( - llama_model.model.norm.weight, llama_args.norm_eps_))}) - seq_model.update({"output": LlamaSequentialWrapper( - OutputLayer(llama_model.lm_head.weight, llama_args))}) + seq_model.update({f"layer{idx}": LlamaSequentialWrapper(decoder)}) + + seq_model.update( + { + "norm": LlamaSequentialWrapper( + RMSNorm(llama_model.model.norm.weight, llama_args.norm_eps_) + ) + } + ) + seq_model.update( + { + "output": LlamaSequentialWrapper( + OutputLayer(llama_model.lm_head.weight, llama_args) + ) + } + ) model = LlamaModel(llama_args) model.seq_module_ = torch.nn.Sequential(seq_model) diff --git a/mlora/model/llm/model_llm.py b/mlora/model/llm/model_llm.py index 7feb2d4e..26c14f68 100644 --- a/mlora/model/llm/model_llm.py +++ b/mlora/model/llm/model_llm.py @@ -1,29 +1,35 @@ +from abc import ABCMeta, abstractmethod +from collections import OrderedDict +from typing import List, Optional + from mlora.model.args import LinearInfo, ModelData from mlora.model.modules import AdapterModel -from collections import OrderedDict -from abc import ABCMeta, abstractmethod - class LLMModel(metaclass=ABCMeta): - name_or_path_: str = "" - device_: str = "" - vocab_size_: int = -1 - n_heads_: int = -1 - dim_: int = -1 + name_or_path_: str + device_: str + vocab_size_: int + n_heads_: int + dim_: int + + @abstractmethod + def forward(self, input: ModelData): ... + @staticmethod @abstractmethod - def forward(self, input: ModelData): - pass + def from_pretrained( + path: str, + device: str, + precision: str, + partial_model_to_device: Optional[List[int]] = None, + ) -> "LLMModel": ... @abstractmethod - def load_adapter(self, adapter_model: AdapterModel): - pass + def load_adapter(self, adapter_model: AdapterModel): ... @abstractmethod - def offload_adapter(self, adapter_name: str): - pass + def offload_adapter(self, adapter_name: str): ... @abstractmethod - def linears_info(self) -> OrderedDict[str, LinearInfo]: - pass + def linears_info(self) -> OrderedDict[str, LinearInfo]: ... 
diff --git a/mlora/model/modules/__init__.py b/mlora/model/modules/__init__.py index 9dec097e..9204fd21 100644 --- a/mlora/model/modules/__init__.py +++ b/mlora/model/modules/__init__.py @@ -1,12 +1,12 @@ +from .adapter import Adapter, AdapterModel +from .attention import Attention +from .decoder import Decoder from .embedding import Embedding from .linear import Linear +from .lora import LoRA, LoRAFunction +from .mlp import MLP from .output_layer import OutputLayer -from .adapter import Adapter, AdapterModel from .rms_norm import RMSNorm -from .lora import LoRAFunction, LoRA -from .attention import Attention -from .mlp import MLP -from .decoder import Decoder __all__ = [ "Embedding", @@ -19,5 +19,5 @@ "LoRAFunction", "Attention", "MLP", - "Decoder" + "Decoder", ] diff --git a/mlora/model/modules/adapter.py b/mlora/model/modules/adapter.py index 86dae059..6f0e4b9e 100644 --- a/mlora/model/modules/adapter.py +++ b/mlora/model/modules/adapter.py @@ -1,12 +1,12 @@ -import torch - -from typing import Dict, List from abc import abstractmethod +from typing import Dict, List + +import torch class Adapter(torch.nn.Module): - adapter_type_: str = "" - adapter_name_: str = "" + adapter_type_: str + adapter_name_: str def __init__(self, adapter_type: str, adapter_name: str): super().__init__() @@ -15,8 +15,7 @@ def __init__(self, adapter_type: str, adapter_name: str): self.adapter_name_ = adapter_name @abstractmethod - def get_tensors(self) -> List[torch.Tensor]: - ... + def get_tensors(self) -> List[torch.Tensor]: ... def disable_grad(self): for tensor in self.get_tensors(): diff --git a/mlora/model/modules/attention.py b/mlora/model/modules/attention.py index 4ce98243..038547db 100644 --- a/mlora/model/modules/attention.py +++ b/mlora/model/modules/attention.py @@ -1,25 +1,27 @@ -from mlora.model.modules import AdapterModel -from mlora.model.args import LLMModelArgs, LinearInfo, ModelData -from mlora.profiler import nvtx_range, set_backward_tracepoint - import math +from collections import OrderedDict +from typing import Dict, Optional, Tuple + import torch import torch.nn.functional as F -from collections import OrderedDict -from typing import Tuple, Dict, Optional + +from mlora.model.args import LinearInfo, LLMModelArgs, ModelData +from mlora.model.modules import AdapterModel +from mlora.profiler import nvtx_range, set_backward_tracepoint from .linear import Linear def rotate_half(x: torch.Tensor) -> torch.Tensor: # see the above ref - left_part = x[..., :x.shape[-1] // 2] - right_part = x[..., x.shape[-1] // 2:] + left_part = x[..., : x.shape[-1] // 2] + right_part = x[..., x.shape[-1] // 2 :] return torch.cat((-right_part, left_part), dim=-1) -def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, - cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def apply_rotary_emb( + xq: torch.Tensor, xk: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: # data shape is: batch_size * n_head * seq_len * n_dim xq_embed = (xq * cos) + (rotate_half(xq) * sin) xk_embed = (xk * cos) + (rotate_half(xk) * sin) @@ -30,17 +32,17 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: batch, n_kv_heads, seq_len, head_dim = x.shape if n_rep == 1: return x - x = x[:, :, None, :, :].expand( - batch, n_kv_heads, n_rep, seq_len, head_dim) + x = x[:, :, None, :, :].expand(batch, n_kv_heads, n_rep, seq_len, head_dim) x = x.reshape(batch, n_kv_heads * n_rep, seq_len, head_dim) return x -def precompute_rope_angle(dim: int, seq_len: int, theta: 
float, device: str) -> Tuple[torch.Tensor, torch.Tensor]: +def precompute_rope_angle( + dim: int, seq_len: int, theta: float, device: str +) -> Tuple[torch.Tensor, torch.Tensor]: # this implement is different with facebooksearch/llama # ref: https://github.com/huggingface/transformers/issues/25199 - angles = 1.0 / \ - (theta ** (torch.arange(0, dim, 2).float().to(device) / dim)) + angles = 1.0 / (theta ** (torch.arange(0, dim, 2).float().to(device) / dim)) seq = torch.arange(seq_len, device=device, dtype=angles.dtype) emb = torch.outer(seq, angles) emb = torch.cat((emb, emb), dim=-1) @@ -51,31 +53,37 @@ def precompute_rope_angle(dim: int, seq_len: int, theta: float, device: str) -> @torch.jit.script -def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - attention_score = torch.matmul( - query, key.transpose(2, 3)) / math.sqrt(query.size(-1)) +def scaled_dot_product_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + attention_score = torch.matmul(query, key.transpose(2, 3)) / math.sqrt( + query.size(-1) + ) if attention_mask is not None: attention_score = attention_score + attention_mask - attention_score = F.softmax( - attention_score, dim=-1, dtype=torch.float32).to(value.dtype) + attention_score = F.softmax(attention_score, dim=-1, dtype=torch.float32).to( + value.dtype + ) attention_score = torch.matmul(attention_score, value) attention_score = attention_score.transpose(1, 2).contiguous() return attention_score class Attention(torch.nn.Module): + wq_: Linear + wk_: Linear + wv_: Linear + wo_: Linear + def __init__(self, layer_id: int, args: LLMModelArgs): super().__init__() # use layer id to local the adapter self.layer_id_: int = layer_id - self.wq_: Linear = None # dim * dim - self.wk_: Linear = None # dim * dim - self.wv_: Linear = None # dim * dim - self.wo_: Linear = None # dim * dim - self.n_heads_ = args.n_heads_ self.n_kv_heads_ = args.n_kv_heads_ self.head_dim_ = args.dim_ // args.n_heads_ @@ -83,13 +91,13 @@ def __init__(self, layer_id: int, args: LLMModelArgs): # rope angle cos and sin self.cos_, self.sin_ = precompute_rope_angle( - args.dim_ // args.n_heads_, args.max_seq_len_, - args.rope_theta_, args.device_) + args.dim_ // args.n_heads_, + args.max_seq_len_, + args.rope_theta_, + args.device_, + ) - def forward(self, - data: torch.Tensor, - mask: torch.Tensor, - input_args: ModelData): + def forward(self, data: torch.Tensor, mask: torch.Tensor, input_args: ModelData): batch_size, max_seq_len, _ = data.shape xq = self.wq_.forward(data, input_args) @@ -98,12 +106,15 @@ def forward(self, # conver shape to multi head # the shape is batch_size * number_of_head * seq_len * dim_of_head - xq = xq.view(batch_size, max_seq_len, self.n_heads_, - self.head_dim_).transpose(1, 2) - xk = xk.view(batch_size, max_seq_len, self.n_kv_heads_, - self.head_dim_).transpose(1, 2) - xv = xv.view(batch_size, max_seq_len, self.n_kv_heads_, - self.head_dim_).transpose(1, 2) + xq = xq.view(batch_size, max_seq_len, self.n_heads_, self.head_dim_).transpose( + 1, 2 + ) + xk = xk.view( + batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_ + ).transpose(1, 2) + xv = xv.view( + batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_ + ).transpose(1, 2) # apply rotary embedding assert xq.dtype == xk.dtype @@ -146,7 +157,7 @@ def linear_dict(self) -> Dict[str, Linear]: 
f"layers.{self.layer_id_}.self_attn.q_proj": self.wq_, f"layers.{self.layer_id_}.self_attn.k_proj": self.wk_, f"layers.{self.layer_id_}.self_attn.v_proj": self.wv_, - f"layers.{self.layer_id_}.self_attn.o_proj": self.wo_ + f"layers.{self.layer_id_}.self_attn.o_proj": self.wo_, } def load_adapter(self, adapter_model: AdapterModel): @@ -164,8 +175,10 @@ def linears_info(self) -> OrderedDict[str, LinearInfo]: for name, module in self.linear_dict.items(): assert isinstance(module, Linear) - ret_val[name] = LinearInfo(name_=name, - in_dim_=module.weight_.in_features, - out_dim_=module.weight_.out_features) + ret_val[name] = LinearInfo( + name_=name, + in_dim_=module.weight_.in_features, + out_dim_=module.weight_.out_features, + ) return ret_val diff --git a/mlora/model/modules/decoder.py b/mlora/model/modules/decoder.py index 39ad95e3..fd913200 100644 --- a/mlora/model/modules/decoder.py +++ b/mlora/model/modules/decoder.py @@ -1,9 +1,10 @@ -from mlora.model.modules import AdapterModel -from mlora.model.args import LLMModelArgs, LinearInfo, ModelData -from mlora.profiler import nvtx_range, set_backward_tracepoint +from collections import OrderedDict import torch -from collections import OrderedDict + +from mlora.model.args import LinearInfo, LLMModelArgs, ModelData +from mlora.model.modules import AdapterModel +from mlora.profiler import nvtx_range, set_backward_tracepoint from .attention import Attention from .mlp import MLP @@ -11,26 +12,24 @@ class Decoder(torch.nn.Module): + attn_norm_: RMSNorm + mlp_norm_: RMSNorm + def __init__(self, layer_id: int, args: LLMModelArgs): super().__init__() self.layer_id_ = layer_id - self.attn_norm_: RMSNorm = None - self.mlp_norm_: RMSNorm = None - self.attn_: Attention = Attention(layer_id, args) self.mlp_: MLP = MLP(layer_id) - def forward(self, - hidden_states: torch.Tensor, - mask: torch.Tensor, - input_args: ModelData): + def forward( + self, hidden_states: torch.Tensor, mask: torch.Tensor, input_args: ModelData + ): # Attention with nvtx_range("f_attention_norm"): attn_norm_output = self.attn_norm_.forward(hidden_states) - set_backward_tracepoint( - attn_norm_output.grad_fn, "b_attention_norm") + set_backward_tracepoint(attn_norm_output.grad_fn, "b_attention_norm") attn_output = self.attn_.forward(attn_norm_output, mask, input_args) @@ -50,13 +49,13 @@ def forward(self, return hidden_states - def from_pretrained(self, - transformer_layer: torch.nn.Module, - norm_eps: float) -> None: + def from_pretrained( + self, transformer_layer: torch.nn.Module, norm_eps: float + ) -> None: self.mlp_norm_ = RMSNorm( - transformer_layer.post_attention_layernorm.weight, norm_eps) - self.attn_norm_ = RMSNorm( - transformer_layer.input_layernorm.weight, norm_eps) + transformer_layer.post_attention_layernorm.weight, norm_eps + ) + self.attn_norm_ = RMSNorm(transformer_layer.input_layernorm.weight, norm_eps) self.attn_.from_pretrained(transformer_layer.self_attn) self.mlp_.from_pretrained(transformer_layer.mlp) @@ -70,7 +69,7 @@ def offload_adapter(self, adapter_name: str): self.mlp_.offload_adapter(adapter_name) def linears_info(self) -> OrderedDict[str, LinearInfo]: - ret_val = OrderedDict() + ret_val: OrderedDict[str, LinearInfo] = OrderedDict() ret_val.update(self.attn_.linears_info()) ret_val.update(self.mlp_.linears_info()) return ret_val diff --git a/mlora/model/modules/embedding.py b/mlora/model/modules/embedding.py index 1ab81813..1a2c25e3 100644 --- a/mlora/model/modules/embedding.py +++ b/mlora/model/modules/embedding.py @@ -9,6 +9,5 @@ def __init__(self, 
embedding: torch.Tensor, pad_token: int): self.padding_idx_: int = pad_token def forward(self, tokens: torch.Tensor) -> torch.Tensor: - data = F.embedding(tokens, self.token_embedding_, - padding_idx=self.padding_idx_) + data = F.embedding(tokens, self.token_embedding_, padding_idx=self.padding_idx_) return data diff --git a/mlora/model/modules/linear.py b/mlora/model/modules/linear.py index a9bf1365..c6e46a68 100644 --- a/mlora/model/modules/linear.py +++ b/mlora/model/modules/linear.py @@ -1,10 +1,10 @@ -from mlora.model.args import ModelData -from mlora.profiler import nvtx_range, set_backward_tracepoint +from typing import Dict, List, Optional, Tuple -import torch import bitsandbytes +import torch -from typing import Dict, Tuple, List +from mlora.model.args import ModelData +from mlora.profiler import nvtx_range, set_backward_tracepoint from .adapter import Adapter from .lora import LoRA, LoRAFunction @@ -18,7 +18,8 @@ def __init__(self, weight: torch.nn.Module): if not isinstance(weight, torch.nn.Linear): assert isinstance(weight, bitsandbytes.nn.Linear8bitLt) or isinstance( - weight, bitsandbytes.nn.Linear4bit), f"error type - {type(weight)}." + weight, bitsandbytes.nn.Linear4bit + ), f"error type - {type(weight)}." else: weight.requires_grad_(False) @@ -38,32 +39,36 @@ def forward(self, data: torch.Tensor, input_args: ModelData) -> torch.Tensor: return self.__lora_forward(data, input_args, result) - def __lora_forward(self, - data: torch.Tensor, - input_args: ModelData, - result: torch.Tensor) -> torch.Tensor: + def __lora_forward( + self, data: torch.Tensor, input_args: ModelData, result: torch.Tensor + ) -> torch.Tensor: # split the data and result - dropouts: List[float] = [] - scalings: List[float] = [] - loras: Tuple[torch.Tensor] = () + dropouts: List[Optional[float]] = [] + scalings: List[Optional[float]] = [] + loras: Tuple[torch.Tensor | None, ...] 
= () for lora_config in input_args.data_config_: adapter_name = lora_config.adapter_name_ - if adapter_name not in self.adapters_ or not isinstance(self.adapters_[adapter_name], LoRA): + if adapter_name not in self.adapters_ or not isinstance( + self.adapters_[adapter_name], LoRA + ): loras += (None, None) dropouts.append(None) scalings.append(None) continue - loras += (self.adapters_[adapter_name].lora_a_, - self.adapters_[adapter_name].lora_b_) + loras += ( + self.adapters_[adapter_name].lora_a_, + self.adapters_[adapter_name].lora_b_, + ) dropouts.append(self.adapters_[adapter_name].dropout_) scalings.append(self.adapters_[adapter_name].scaling_) with nvtx_range("f_lora"): result = LoRAFunction.apply( - result, data, input_args, dropouts, scalings, *loras) + result, data, input_args, dropouts, scalings, *loras + ) set_backward_tracepoint(result.grad_fn, "b_lora") return result diff --git a/mlora/model/modules/lora.py b/mlora/model/modules/lora.py index d7f4392c..c78806b5 100644 --- a/mlora/model/modules/lora.py +++ b/mlora/model/modules/lora.py @@ -1,12 +1,12 @@ -from mlora.model.args import ModelData - import math +from typing import Any, Dict, List, Tuple, override + import torch import torch.nn.functional as F -from typing import Dict, List, override -from .adapter import Adapter +from mlora.model.args import ModelData +from .adapter import Adapter g_cached_range_tensor: Dict[torch.device, torch.Tensor] = {} # also max batch size @@ -19,30 +19,31 @@ def get_range_tensor(device: torch.device, batch_size: int = 1024): if device not in g_cached_range_tensor or batch_size > g_max_range: g_max_range = g_max_range if g_max_range > batch_size else batch_size g_cached_range_tensor[device] = torch.arange( - 0, g_max_range, step=1, device=device) + 0, g_max_range, step=1, device=device + ) return g_cached_range_tensor[device] class LoRAFunction(torch.autograd.Function): @staticmethod - def forward(ctx, - result: torch.Tensor, - data: torch.Tensor, - input_args: ModelData, - dropouts: List[float], - scalings: List[float], - *args): + def forward( + ctx, + result: torch.Tensor, + data: torch.Tensor, + input_args: ModelData, + dropouts: List[float], + scalings: List[float], + *args, + ): # the lora module is f32 precision data = data.to(torch.float32) - save_inputs = (data,) + save_inputs: Tuple[torch.Tensor | None, ...] 
= (data,) lora_range = get_range_tensor(data.device, data.shape[0]) - for lora_a, lora_b, lora_config, dropout, scaling in zip(args[::2], - args[1::2], - input_args.data_config_, - dropouts, - scalings): + for lora_a, lora_b, lora_config, dropout, scaling in zip( + args[::2], args[1::2], input_args.data_config_, dropouts, scalings + ): assert not ((lora_a is None) ^ (lora_b is None)) if lora_a is None and lora_b is None: save_inputs += (None, None, None) @@ -57,11 +58,13 @@ def forward(ctx, end_idx = lora_config.batch_end_idx_ # must ensure the dropout is not zero - # is dropout == 0, dropdata is a data's referece, so the data will be changed + # is dropout == 0 + # dropdata is a data's referece, so the data will be changed assert dropout != 0 drop_data = F.dropout( - data[start_idx:end_idx], p=dropout, training=True, inplace=False) + data[start_idx:end_idx], p=dropout, training=True, inplace=False + ) drop_data.mul_(scaling) drop_data = drop_data @ lora_a.transpose(0, 1) @@ -71,7 +74,8 @@ def forward(ctx, lora_data = lora_data.to(result.dtype) result.index_add_( - dim=0, index=lora_range[start_idx:end_idx], source=lora_data) + dim=0, index=lora_range[start_idx:end_idx], source=lora_data + ) save_inputs += (lora_a, lora_b, drop_data) @@ -83,13 +87,14 @@ def forward(ctx, return result @staticmethod - def backward(ctx, grad_output: torch.Tensor): + def backward(ctx: Any, *grad_outputs: Any) -> Any: + grad_output: torch.Tensor = grad_outputs[0] grad_result = None - grad_data = None + grad_data: torch.Tensor | None = None grad_input_args = None grad_dropouts = None grad_scalings = None - grad_loras = () + grad_loras: Tuple[torch.Tensor | None, ...] = () data = ctx.saved_tensors[0] loras = ctx.saved_tensors[1:] @@ -102,20 +107,25 @@ def backward(ctx, grad_output: torch.Tensor): # the lora module is fp32 precision grad_output = grad_output.to(torch.float32) lora_range = get_range_tensor( - grad_output.device, batch_size=grad_output.shape[0]) - for lora_a, lora_b, drop_data, dropout, scaling, lora_config in zip(loras[::3], - loras[1::3], - loras[2::3], - ctx.dropouts, - ctx.scalings, - ctx.input_args.data_config_): + grad_output.device, batch_size=grad_output.shape[0] + ) + for lora_a, lora_b, drop_data, dropout, scaling, lora_config in zip( + loras[::3], + loras[1::3], + loras[2::3], + ctx.dropouts, + ctx.scalings, + ctx.input_args.data_config_, + ): start_idx = lora_config.batch_start_idx_ end_idx = lora_config.batch_end_idx_ assert not ((lora_a is None) ^ (lora_b is None)) if lora_a is None and lora_b is None: grad_loras += (None, None) - grad_data.index_fill_( - dim=0, index=lora_range[start_idx:end_idx], value=0) + if grad_data is not None: + grad_data.index_fill_( + dim=0, index=lora_range[start_idx:end_idx], value=0 + ) continue # lora_data shape is batch_size * seq_len * in_dim @@ -125,7 +135,7 @@ def backward(ctx, grad_output: torch.Tensor): # bstage shape is batch_size * seq_len * r bstage = grad_y @ lora_b - bstage *= (scaling / (1 - dropout)) + bstage *= scaling / (1 - dropout) grad_a = torch.sum(bstage.transpose(1, 2) @ lora_data, dim=0) grad_b = torch.sum(grad_y.transpose(1, 2) @ drop_data, dim=0) @@ -135,35 +145,65 @@ def backward(ctx, grad_output: torch.Tensor): if grad_data is not None: grad_x = bstage @ lora_a grad_data.index_copy_( - dim=0, index=lora_range[start_idx:end_idx], source=grad_x) + dim=0, index=lora_range[start_idx:end_idx], source=grad_x + ) - return grad_result, grad_data, grad_input_args, grad_dropouts, grad_scalings, *grad_loras + return ( + grad_result, + 
grad_data, + grad_input_args, + grad_dropouts, + grad_scalings, + *grad_loras, + ) class LoRA(Adapter): - def __init__(self, adapter_name: str, in_dim: int, out_dim: int, r: int, alpha: int, dropout: float): + def __init__( + self, + adapter_name: str, + in_dim: int, + out_dim: int, + r: int, + alpha: int, + dropout: float, + ): super().__init__("lora", adapter_name) self.lora_a_: torch.Tensor = torch.zeros( - size=(r, in_dim), device="cpu", requires_grad=True, dtype=torch.float32) + size=(r, in_dim), device="cpu", requires_grad=True, dtype=torch.float32 + ) self.lora_b_: torch.Tensor = torch.zeros( - size=(out_dim, r), device="cpu", requires_grad=True, dtype=torch.float32) + size=(out_dim, r), device="cpu", requires_grad=True, dtype=torch.float32 + ) self.r_: int = r self.alpha_: int = alpha self.dropout_: float = dropout self.scaling_: float = alpha / r - def init_weight(self, lora_a: torch.Tensor = None, lora_b: torch.Tensor = None): + def init_weight( + self, lora_a: torch.Tensor | None = None, lora_b: torch.Tensor | None = None + ): if lora_a is None: torch.nn.init.kaiming_normal_(self.lora_a_, a=math.sqrt(5)) else: - self.lora_a_ = lora_a.to("cpu").detach().clone().to( - dtype=torch.float32).requires_grad_(True) + self.lora_a_ = ( + lora_a.to("cpu") + .detach() + .clone() + .to(dtype=torch.float32) + .requires_grad_(True) + ) if lora_b is not None: - self.lora_b_ = lora_b.to("cpu").detach().clone().to( - dtype=torch.float32).requires_grad_(True) + self.lora_b_ = ( + lora_b.to("cpu") + .detach() + .clone() + .to(dtype=torch.float32) + .requires_grad_(True) + ) @override def get_tensors(self) -> List[torch.Tensor]: diff --git a/mlora/model/modules/mlp.py b/mlora/model/modules/mlp.py index 7c2c763c..8b0b29d2 100644 --- a/mlora/model/modules/mlp.py +++ b/mlora/model/modules/mlp.py @@ -1,26 +1,27 @@ -from mlora.model.modules import AdapterModel -from mlora.model.args import LinearInfo, ModelData -from mlora.profiler import nvtx_range, set_backward_tracepoint +from collections import OrderedDict +from typing import Dict import torch import torch.nn.functional as F -from collections import OrderedDict -from typing import Dict + +from mlora.model.args import LinearInfo, ModelData +from mlora.model.modules import AdapterModel +from mlora.profiler import nvtx_range, set_backward_tracepoint from .linear import Linear class MLP(torch.nn.Module): + gate_: Linear # also gate FNN * dim + down_: Linear # also down dim * FNN + up_: Linear # also up FNN * dim + def __init__(self, layer_id: int): super().__init__() # use layer id to local the adapter self.layer_id_ = layer_id - self.gate_: Linear = None # also gate FNN * dim - self.down_: Linear = None # also down dim * FNN - self.up_: Linear = None # also up FNN * dim - def forward(self, data: torch.Tensor, input_args: ModelData) -> torch.Tensor: # feed forward fully connected with nvtx_range("f_mlp"): @@ -44,7 +45,7 @@ def linear_dict(self) -> Dict[str, Linear]: return { f"layers.{self.layer_id_}.mlp.gate_proj": self.gate_, f"layers.{self.layer_id_}.mlp.down_proj": self.down_, - f"layers.{self.layer_id_}.mlp.up_proj": self.up_ + f"layers.{self.layer_id_}.mlp.up_proj": self.up_, } def load_adapter(self, adapter_model: AdapterModel): @@ -62,8 +63,10 @@ def linears_info(self) -> OrderedDict[str, LinearInfo]: for name, module in self.linear_dict.items(): assert isinstance(module, Linear) - ret_val[name] = LinearInfo(name_=name, - in_dim_=module.weight_.in_features, - out_dim_=module.weight_.out_features) + ret_val[name] = LinearInfo( + name_=name, + 
in_dim_=module.weight_.in_features, + out_dim_=module.weight_.out_features, + ) return ret_val diff --git a/mlora/model/modules/output_layer.py b/mlora/model/modules/output_layer.py index 2656d268..217c5dc6 100644 --- a/mlora/model/modules/output_layer.py +++ b/mlora/model/modules/output_layer.py @@ -1,16 +1,21 @@ -from mlora.model.args import LLMModelArgs - import torch +from mlora.model.args import LLMModelArgs + class OutputLayer(torch.nn.Module): def __init__(self, weight: torch.Tensor, args: LLMModelArgs): super().__init__() self.lm_head_ = torch.nn.Linear( - args.dim_, args.vocab_size_, bias=False, device=args.device_, dtype=args.dtype_) + args.dim_, + args.vocab_size_, + bias=False, + device=args.device_, + dtype=args.dtype_, + ) with torch.no_grad(): - if weight.device == torch.device('meta'): + if weight.device == torch.device("meta"): self.lm_head_.weight = weight else: self.lm_head_.weight.copy_(weight) diff --git a/mlora/model/tokenizer/__init__.py b/mlora/model/tokenizer/__init__.py index f7d03115..562646c8 100644 --- a/mlora/model/tokenizer/__init__.py +++ b/mlora/model/tokenizer/__init__.py @@ -1,5 +1,3 @@ from .tokenizer import Tokenizer -__all__ = [ - "Tokenizer" -] +__all__ = ["Tokenizer"] diff --git a/mlora/model/tokenizer/tokenizer.py b/mlora/model/tokenizer/tokenizer.py index 6d330a16..eb2eb614 100644 --- a/mlora/model/tokenizer/tokenizer.py +++ b/mlora/model/tokenizer/tokenizer.py @@ -1,13 +1,15 @@ -from mlora.model.args import Tokens, Masks +from typing import Tuple from transformers import AutoTokenizer -from typing import Tuple + +from mlora.model.args import Masks, Tokens class Tokenizer: def __init__(self, model_path: str): self.tokenizer_ = AutoTokenizer.from_pretrained( - model_path, trust_remote_code=True) + model_path, trust_remote_code=True + ) self.n_words_ = self.tokenizer_.vocab_size self.bos_id_ = self.tokenizer_.bos_token_id self.eos_id_ = self.tokenizer_.eos_token_id @@ -19,7 +21,7 @@ def __init__(self, model_path: str): def encode(self, data: str, bos=True, eos=True, cutoff_len=4096) -> Tokens: tokens = self.tokenizer_.encode(data, add_special_tokens=False) - tokens = tokens[:cutoff_len - int(bos) - int(eos)] + tokens = tokens[: cutoff_len - int(bos) - int(eos)] if bos: tokens = [self.bos_id_] + tokens if eos: diff --git a/mlora/pipeline/function.py b/mlora/pipeline/function.py deleted file mode 100644 index c5bfb900..00000000 --- a/mlora/pipeline/function.py +++ /dev/null @@ -1,72 +0,0 @@ -from mlora.pipeline.messages import PipeMessage, PipeMessageType -from mlora.pipeline.transport import Transport -from mlora.model.args import MLoRABatchData - -import logging -import torch - - -class SendOperator(torch.autograd.Function): - # helper to reduce the activation memory - @staticmethod - def forward(ctx, - phony: torch.Tensor, - tensor_data: torch.Tensor, - transport: Transport, - msg_id: int, - input_args: MLoRABatchData): - assert isinstance(tensor_data, torch.Tensor) - - msg = PipeMessage(src_=transport.worker_name, - dst_=transport.next_worker_name, - msg_type_=PipeMessageType.ACTIVATIONS, - msg_id_=msg_id, - tensor_data_=tensor_data, - batch_data_=input_args) - transport.send_message(msg, False) - - return phony - - @staticmethod - def backward(ctx, grad_output): - assert ctx.grad_from_next_worker is not None - - return (None, ctx.grad_from_next_worker, None, None, None) - - -class RecvOperator(torch.autograd.Function): - # backward will auto send the grad to pre worker - @staticmethod - def forward(ctx, - phony: torch.Tensor, - transport: 
Transport, - msg: PipeMessage) -> torch.Tensor: - assert msg.msg_type_ == PipeMessageType.ACTIVATIONS - assert isinstance(msg.tensor_data_, torch.Tensor) - - ctx.msg_id_ = msg.msg_id_ - ctx.transport_ = transport - ctx.batch_data_ = msg.batch_data_ - - return msg.tensor_data_ * phony - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - # now only signle grad can be support - assert isinstance(grad_output, torch.Tensor) - - transport: Transport = ctx.transport_ - if hasattr(ctx, 'pre_stage_fn') and ctx.pre_stage_fn is not None: - ctx.pre_stage_fn() - - logging.debug(f"Send the gradients to {transport.prev_worker_name}") - transport.send_message(PipeMessage( - src_=transport.worker_name, - dst_=transport.prev_worker_name, - msg_type_=PipeMessageType.GRADIENTS, - msg_id_=ctx.msg_id_, - tensor_data_=grad_output, - batch_data_=ctx.batch_data_, - )) - - return (None, None, None) diff --git a/mlora/pipeline/messages.py b/mlora/pipeline/messages.py deleted file mode 100644 index 4ea3bf0c..00000000 --- a/mlora/pipeline/messages.py +++ /dev/null @@ -1,24 +0,0 @@ -from mlora.model.args import MLoRABatchData - -import torch - - -from dataclasses import dataclass -from enum import Enum - - -class PipeMessageType(Enum): - ACTIVATIONS = "ACTIVATIONS" - GRADIENTS = "GRADIENTS" - - -@dataclass() -class PipeMessage: - src_: str - dst_: str - - msg_type_: PipeMessageType - msg_id_: int - - tensor_data_: torch.Tensor - batch_data_: MLoRABatchData diff --git a/mlora/pipeline/pipe.py b/mlora/pipeline/pipe.py deleted file mode 100644 index 0dd9f3c1..00000000 --- a/mlora/pipeline/pipe.py +++ /dev/null @@ -1,332 +0,0 @@ -from mlora.pipeline.queue import DeviceSwapQueue -from mlora.pipeline.transport import RpcTransport -from mlora.pipeline.stream import CudaStream -from mlora.pipeline.messages import PipeMessage, PipeMessageType -from mlora.pipeline.function import RecvOperator, SendOperator -from mlora.model.llm.model_llm import LLMModel, precompute_mask -from mlora.model.args import MLoRADataConfig, MLoRABatchData -from mlora.executor.executor import MultiTrainerContext -from mlora.config import MLoRAConfig - -import torch -import uuid -import logging -import os -import json -import time - -from enum import Enum, auto -from typing import Dict, List - - -class PipelineDispatcher: - None - - -class WorkerRole(Enum): - HEAD = auto() - MID = auto() - TAIL = auto() - - -class Pipe(): - world_size_: int = -1 - rank_: int = -1 - device_: torch.device = None - role_: WorkerRole = None - - forward_stop_: bool = False - input_stop_: bool = False - forward_cnt_: int = 0 - backward_cache_: Dict[int, torch.Tensor] = {} - stop_signal_: torch.tensor = None - - model_partition_: torch.nn.Sequential = torch.nn.Sequential() - dispatcher_: PipelineDispatcher = None - n_heads_: int = -1 - - config_: MLoRAConfig = None - - multi_trainer_context_: MultiTrainerContext = None - input_queue_: DeviceSwapQueue = None - - def is_stop_signal(self, data: torch.tensor) -> bool: - return data.dtype == torch.long and torch.numel(data) == 1 - - def __init__(self, - model: LLMModel, - config: MLoRAConfig, - dispatcher: PipelineDispatcher, - device: torch.device, - rank: int, - balance: List[int]) -> None: - self.world_size_ = len(balance) - assert self.world_size_ == len(balance) - - self.rank_ = rank - self.device_ = device - self.balance_ = balance - - if rank == 0: - self.role_ = WorkerRole.HEAD - self.input_queue_ = DeviceSwapQueue( - torch.device('cpu'), device, 4, 'input_data_queue') - self.input_queue_.start() - elif rank == 
self.world_size_ - 1: - self.role_ = WorkerRole.TAIL - else: - self.role_ = WorkerRole.MID - - self.transport_ = RpcTransport( - self.rank_, self.world_size_, self.device_) - - self.default_stream_ = CudaStream( - torch.cuda.default_stream(self.device_)) - - self.config_ = config - self.dispatcher_ = dispatcher - - # need the config value, so must in the last stage to init - self.init_partition(model) - - def run(self): - if self.role_ == WorkerRole.HEAD: - self.forward_stop_ = True - if self.role_ != WorkerRole.HEAD: - self.input_stop_ = True - - while True: - if self.role_ != WorkerRole.TAIL: - self.process_backward() - - if not self.input_stop_: - self.process_input() - - if not self.forward_stop_: - self.process_forward() - - if len(self.backward_cache_) == 0 and self.forward_stop_ and self.input_stop_: - # no froward and backward request - break - - logging.info("Pipe done and to stop.") - - logging.info("saving all models...") - self.save_all_model() - - # clear the pipeline resource - self.stop() - - def stop(self): - transport = self.transport_ - if isinstance(transport, RpcTransport): - transport.stop() - logging.info("Transport stop.") - if self.input_queue_: - self.input_queue_.stop() - - def process_input(self): - def put_train_data(): - train_input = self.dispatcher_.get_train_data() - if not train_input: - # avoid the busy loop - time.sleep(1 / 10000000) - return - for lora_config in train_input.lora_batch_data_config_: - logging.info(f'load lora: {lora_config.adapter_name_}') - data = torch.tensor(train_input.batch_tokens_, - dtype=torch.int64, device="cpu") - msg = PipeMessage(self.device_, self.device_, PipeMessageType.ACTIVATIONS, - 0, data, train_input) - self.input_queue_.put(msg) - - assert self.role_ == WorkerRole.HEAD - assert not self.input_stop_ - - if not self.dispatcher_.check_task_done(): - put_train_data() - # fetch train data - msg = self.input_queue_.get_nowait() - if not msg: - return - train_input = msg.batch_data_ - data = self.forward(msg.tensor_data_, msg.batch_data_) - self.forward_cnt_ += 1 - else: - # stop - self.input_stop_ = True - train_input = None - data = torch.tensor( - [self.forward_cnt_], dtype=torch.long, device="cpu", requires_grad=False) - assert self.is_stop_signal(data) - logging.info("Forward done be signaled.") - - self.default_stream_.poll() - self.send_next_worker(data, train_input) - - def process_backward(self): - assert self.role_ != WorkerRole.TAIL - - message = self.transport_.recv_message( - PipeMessageType.GRADIENTS, block=False) - if message is None: - return - logging.info( - f"Recv the gradients - {str(message.msg_id_)[:8]} from {message.src_}") - - msg_id = message.msg_id_ - - assert msg_id in self.backward_cache_ - phony: torch.Tensor = self.backward_cache_[msg_id] - phony.grad_fn.grad_from_next_worker = message.tensor_data_ - phony.backward() - - self.trainer_step(message.batch_data_.lora_batch_data_config_) - - del self.backward_cache_[msg_id] - - def process_forward(self): - assert self.role_ != WorkerRole.HEAD - assert not self.forward_stop_ - - # recv the tensors from prev-worker - message = self.transport_.recv_message( - PipeMessageType.ACTIVATIONS, block=False) - if message is None: - return - logging.debug( - f"Recv the activations - {str(message.msg_id_)[:8]} from {message.src_}") - - # use RecvOperator get the real data - # the operator also auto send the backward grad to prev worker - if self.is_stop_signal(message.tensor_data_): - self.stop_signal_ = message.tensor_data_ - data = message.tensor_data_ - 
logging.info("Forward done be signaled.") - else: - data = RecvOperator.apply( - torch.tensor(1.0, requires_grad=True), self.transport_, message) - data.grad_fn.pre_stage_fn = self.default_stream_.poll - self.forward_cnt_ += 1 - data = self.forward(data, message.batch_data_) - - # stop signal may arrive before the activation - # so we check forward cnt to assure all activations have been processed - if self.stop_signal_ is not None and self.stop_signal_.item() == self.forward_cnt_: - self.forward_stop_ = True - - # mid worker need to send the result to next worker - if self.role_ != WorkerRole.TAIL: - self.default_stream_.poll() - return self.send_next_worker(data, message.batch_data_) - - # tail worker need to calc the backward - if not self.forward_stop_ and not self.is_stop_signal(message.tensor_data_): - lora_configs = message.batch_data_.lora_batch_data_config_ - total_loss = self.multi_trainer_context_.calc_loss( - message.batch_data_, data) - # backward doesn't need to save batch_tokens - message.batch_data_.batch_tokens_ = None - total_loss.backward() - - self.trainer_step(lora_configs) - - def trainer_step(self, lora_configs: List[MLoRADataConfig]): - for lora_config in lora_configs: - adapter_name = lora_config.adapter_name_ - self.multi_trainer_context_.step(adapter_name) - step_cnt = self.multi_trainer_context_.get_step_cnt(adapter_name) - if self.multi_trainer_context_.is_save_step(adapter_name): - self.save_model(adapter_name, f"{step_cnt}") - if self.role_ == WorkerRole.HEAD: - self.dispatcher_.update_backward_cnt(adapter_name) - - def save_all_model(self): - for adapter_name in self.multi_trainer_context_.trainer_context_: - self.save_model(adapter_name, "final") - - def save_model(self, adapter_name: str, dir_suffix: str = ""): - # create saved dir - lora_output_dir = adapter_name - if dir_suffix != "": - lora_output_dir += os.sep + adapter_name + "_" + dir_suffix - if not os.path.exists(lora_output_dir): - os.makedirs(lora_output_dir) - - # get lora weights - lora_weights = {} - for layer_model in self.model_partition_: - wrapper_module = layer_model.wrapper_module_ - if hasattr(wrapper_module, 'get_lora_weight_dict'): - lora_weight, _ = wrapper_module.get_lora_weight_dict( - adapter_name) - lora_weights.update(lora_weight) - - saved_path = lora_output_dir + os.sep + \ - f"adapter_model_{self.rank_}.bin" - logging.info(f'save {adapter_name} to {saved_path}') - torch.save(lora_weights, saved_path) - - # save json only on tail worker - if self.role_ == WorkerRole.TAIL: - context = self.multi_trainer_context_.get_trainer_context( - adapter_name) - with open(lora_output_dir + os.sep + "adapter_config.json", "w") as f: - json.dump(context.export_config(), f, indent=4) - - def send_next_worker(self, - tensor_data: torch.Tensor, - batch_data: MLoRABatchData) -> None: - assert isinstance(tensor_data, torch.Tensor) - assert batch_data is None or isinstance(batch_data, MLoRABatchData) - - msg_id = uuid.uuid4().int - assert msg_id not in self.backward_cache_ - - if self.is_stop_signal(tensor_data): - msg_id = -1 - - phony: torch.Tensor = SendOperator.apply(torch.tensor( - 1.0, requires_grad=True), tensor_data, self.transport_, msg_id, batch_data) - - if self.is_stop_signal(tensor_data): - return - - self.backward_cache_[msg_id] = phony - - def init_partition(self, model: LLMModel) -> None: - balance = self.balance_[self.rank_] - start_module_idx = sum(self.balance_[:self.rank_]) - logging.info( - f"RANK-{self.rank_} in device {self.device_} to load module layers " - f"from 
{start_module_idx} to {start_module_idx + balance}.") - - seq_model = model.sequential_module() - del seq_model[:start_module_idx] - del seq_model[balance:] - for idx in range(0, len(seq_model)): - self.model_partition_.append(seq_model[idx]) - - assert len(self.model_partition_) == balance - - self.n_heads_ = model.n_heads_ - worker_train_paramas: Dict[str, - List[torch.Tensor]] = model.get_train_paramas() - - self.multi_trainer_context_ = MultiTrainerContext( - self.config_, worker_train_paramas) - - del model - torch.cuda.empty_cache() - - def forward(self, - tensor_data: torch.Tensor, - batch_data: MLoRABatchData): - mask = precompute_mask(tensor_data, self.n_heads_, - self.device_, batch_data.batch_mask_) - data = (tensor_data, mask, batch_data, True) - - for seq in self.model_partition_: - data = seq.forward(data) - - return data[0] diff --git a/mlora/pipeline/queue.py b/mlora/pipeline/queue.py deleted file mode 100644 index 9bb1e207..00000000 --- a/mlora/pipeline/queue.py +++ /dev/null @@ -1,99 +0,0 @@ -from mlora.pipeline.messages import PipeMessage -from mlora.pipeline.stream import CudaStream - -import torch -import logging - -from queue import Queue -from threading import Thread -from typing import Optional - - -class DeviceSwapQueue: - def __init__(self, - source_device: torch.device, - target_device: torch.device, - target_size: int = 0, - queue_name: str = "default") -> None: - source_device_is_cpu: bool = True if source_device == torch.device( - "cpu") else False - target_device_is_cpu: bool = True if target_device == torch.device( - "cpu") else False - - assert source_device_is_cpu ^ target_device_is_cpu - - if source_device_is_cpu: - self.copy_stream_: CudaStream = CudaStream( - torch.cuda.Stream(target_device)) - else: - self.copy_stream_: CudaStream = CudaStream( - torch.cuda.Stream(source_device)) - - self.target_device_: torch.device = target_device - self.source_device_: torch.device = source_device - # TODO: change the size by the size of avaliable gpu memory - self.src_queue_: Queue = Queue() - self.dst_queue_: Queue = Queue(target_size) - - self.queue_name_: str = queue_name - - self.stop_: bool = False - - def swap_thread_loop(self): - try: - msg: PipeMessage = self.src_queue_.get(block=True, timeout=10) - except: - return - logging.debug( - f"{self.queue_name_} swap the message - {str(msg.msg_id_)[:8]} start.") - - # must ensure the msg.tensor_data_ sync done - with torch.cuda.stream(self.copy_stream_.stream_): - # do not use the pined_memory maybe can speedup - # need more test - copy_tensor = torch.zeros_like(msg.tensor_data_, device=self.target_device_).copy_( - msg.tensor_data_, non_blocking=True).detach() - msg.tensor_data_ = copy_tensor - # msg.tensor_data_ = msg.tensor_data_.to( - # self.target_device_, non_blocking=True).detach() - - self.copy_stream_.poll() - - logging.debug( - f"{self.queue_name_} swap the message - {str(msg.msg_id_)[:8]} device end.") - self.dst_queue_.put(msg, block=True) - - def swap_thread(self): - logging.info(f"DeviceSwapQueue - {self.queue_name_} start.") - while not self.stop_ or not self.src_queue_.empty(): - self.swap_thread_loop() - logging.info(f"DeviceSwapQueue - {self.queue_name_} stop.") - - def start(self): - self.swap_thread_ = Thread(target=self.swap_thread) - self.swap_thread_.start() - - def stop(self): - self.stop_ = True - self.swap_thread_.join() - - def get(self) -> PipeMessage: - return self.dst_queue_.get(block=True) - - def get_waitime(self, timeout: int = 10) -> Optional[PipeMessage]: - try: - return 
self.dst_queue_.get(block=True, timeout=timeout) - except: - return None - - def get_nowait(self) -> Optional[PipeMessage]: - try: - return self.dst_queue_.get_nowait() - except: - return None - - def put(self, msg: PipeMessage): - self.src_queue_.put(msg) - - def empty(self) -> bool: - return self.src_queue_.empty() and self.dst_queue_.empty() diff --git a/mlora/pipeline/stream.py b/mlora/pipeline/stream.py deleted file mode 100644 index 74eeca5d..00000000 --- a/mlora/pipeline/stream.py +++ /dev/null @@ -1,16 +0,0 @@ -import torch -import time - - -class CudaStream(): - stream_: torch.cuda.Stream = None - event_: torch.cuda.Event = None - - def __init__(self, stream: torch.cuda.Stream) -> None: - self.stream_ = stream - self.event_ = torch.cuda.Event() - - def poll(self) -> None: - self.event_.record(stream=self.stream_) - while not self.event_.query(): - time.sleep(1 / 1000) diff --git a/mlora/pipeline/transport.py b/mlora/pipeline/transport.py deleted file mode 100644 index bb805213..00000000 --- a/mlora/pipeline/transport.py +++ /dev/null @@ -1,179 +0,0 @@ -from mlora.pipeline.messages import PipeMessageType, PipeMessage -from mlora.pipeline.queue import DeviceSwapQueue - -import os -import logging -import torch -import torch.distributed.rpc - -from typing import Dict -from abc import ABC, abstractmethod -from threading import Thread - -# save by different message type -# recv/send queue will automatically change the tensors' device -RPCMessageRecvQueues: Dict[PipeMessageType, DeviceSwapQueue] = { - PipeMessageType.ACTIVATIONS: None, - PipeMessageType.GRADIENTS: None -} - -RPCMessageSendQueues: Dict[PipeMessageType, DeviceSwapQueue] = { - PipeMessageType.ACTIVATIONS: None, - PipeMessageType.GRADIENTS: None -} - - -def rpc_push_queue(msg: PipeMessage) -> None: - global RPCMessageRecvQueues - - assert msg.msg_type_ in RPCMessageRecvQueues, f"No this message type: {msg.msg_type_.value}" - assert RPCMessageRecvQueues[msg.msg_type_] is not None - - logging.debug( - f"RpcTransport async recv the message: {str(msg.msg_id_)[:8]}.") - RPCMessageRecvQueues[msg.msg_type_].put(msg) - - -class Transport(ABC): - rank_: int - device_: torch.device - - @property - def next_worker_name(self) -> str: - return f"worker-{self.rank_ + 1}" - - @property - def prev_worker_name(self) -> str: - return f"worker-{self.rank_ - 1}" - - @property - def worker_name(self) -> str: - return f"worker-{self.rank_}" - - @abstractmethod - def recv_message(self, msg_type: PipeMessageType, block: bool = False) -> PipeMessage: - pass - - @abstractmethod - def send_message(self, msg: PipeMessage, sync: bool = False) -> None: - pass - - -# rpc transport thread -class RpcTransport(Transport): - rank_: int = -1 - world_size_: int = -1 - worker_device_: torch.device = None - - stop_: bool = False - activations_send_thread_: Thread = None - gradients_send_thread_: Thread = None - - def __init__(self, rank: int, world_size: int, worker_device: torch.device) -> None: - self.rank_ = rank - self.world_size_ = world_size - self.worker_device_ = worker_device - - self.stop_: bool = False - - self.init_rpc() - self.init_device_swap_queue() - self.init_background_thread() - - def init_rpc(self) -> None: - if "MASTER_ADDR" not in os.environ: - os.environ["MASTER_ADDR"] = "localhost" - if "MASTER_PORT" not in os.environ: - os.environ["MASTER_PORT"] = "12355" - - assert self.rank_ > -1 - assert self.world_size_ > -1 - assert self.worker_device_ is not None - - # will be block when all world size's gpu join the group - 
torch.distributed.rpc.init_rpc( - f"worker-{self.rank_}", rank=self.rank_, world_size=self.world_size_) - - logging.info( - f"Init rpc with rank {self.rank_} world_size: {self.world_size_}") - - def init_device_swap_queue(self): - cpu_device = torch.device("cpu") - - global RPCMessageSendQueues - for key in RPCMessageSendQueues: - RPCMessageSendQueues[key] = DeviceSwapQueue( - self.worker_device_, cpu_device, queue_name=f"{key.value}_send") - RPCMessageSendQueues[key].start() - - global RPCMessageRecvQueues - for key in RPCMessageRecvQueues: - RPCMessageRecvQueues[key] = DeviceSwapQueue( - cpu_device, self.worker_device_, queue_name=f"{key.value}_recv") - RPCMessageRecvQueues[key].start() - - def init_background_thread(self): - self.gradients_send_thread_ = Thread( - target=self.send_loop, args=(PipeMessageType.GRADIENTS,)) - self.activations_send_thread_ = Thread( - target=self.send_loop, args=(PipeMessageType.ACTIVATIONS,)) - - self.gradients_send_thread_.start() - self.activations_send_thread_.start() - - def send_loop(self, msg_type: PipeMessageType): - global RPCMessageSendQueues - send_queue: DeviceSwapQueue = RPCMessageSendQueues[msg_type] - assert send_queue is not None - - while not self.stop_ or not send_queue.empty(): - msg = send_queue.get_waitime() - if msg is None: - continue - assert msg.tensor_data_.device == torch.device("cpu") - logging.debug( - f"RpcTransport async send the message: {str(msg.msg_id_)[:8]} to {msg.dst_}.") - torch.distributed.rpc.rpc_async( - msg.dst_, rpc_push_queue, args=(msg,)) - - def stop_send_loop(self): - global RPCMessageRecvQueues - global RPCMessageSendQueues - - # first should stop the recv queue - for key in RPCMessageRecvQueues: - RPCMessageRecvQueues[key].stop() - - # then stop the send queue - for key in RPCMessageSendQueues: - RPCMessageSendQueues[key].stop() - - self.stop_ = True - self.activations_send_thread_.join() - self.gradients_send_thread_.join() - - def stop_rpc(self): - torch.distributed.rpc.shutdown() - - def stop(self): - self.stop_send_loop() - self.stop_rpc() - - def recv_message(self, msg_type: PipeMessageType, block: bool = True) -> PipeMessage: - global RPCMessageRecvQueues - - assert msg_type in RPCMessageRecvQueues - recv_queue: DeviceSwapQueue = RPCMessageRecvQueues[msg_type] - - if block: - return recv_queue.get() - else: - return recv_queue.get_nowait() - - def send_message(self, msg: PipeMessage, sync: bool = False) -> None: - assert not sync, "RPC transport do not suppose sync == true!" 
- - global RPCMessageSendQueues - assert msg.msg_type_ in RPCMessageSendQueues - send_queue: DeviceSwapQueue = RPCMessageSendQueues[msg.msg_type_] - send_queue.put(msg) diff --git a/mlora/profiler/__init__.py b/mlora/profiler/__init__.py index 6232e652..d14e917e 100644 --- a/mlora/profiler/__init__.py +++ b/mlora/profiler/__init__.py @@ -1,10 +1,15 @@ -from .profiler import (setup_trace_mode, nvtx_range, nvtx_wrapper, - set_backward_tracepoint, grad_fn_nvtx_wrapper_by_tracepoint) +from .profiler import ( + grad_fn_nvtx_wrapper_by_tracepoint, + nvtx_range, + nvtx_wrapper, + set_backward_tracepoint, + setup_trace_mode, +) __all__ = [ "setup_trace_mode", "nvtx_range", "nvtx_wrapper", "set_backward_tracepoint", - "grad_fn_nvtx_wrapper_by_tracepoint" + "grad_fn_nvtx_wrapper_by_tracepoint", ] diff --git a/mlora/profiler/profiler.py b/mlora/profiler/profiler.py index b7497371..23762775 100644 --- a/mlora/profiler/profiler.py +++ b/mlora/profiler/profiler.py @@ -1,8 +1,8 @@ -import torch import logging - from contextlib import contextmanager -from typing import Callable, Tuple, Set, List +from typing import Callable, List, Set, Tuple + +import torch TRACEPOINT_KEY = "__tp_name" @@ -22,25 +22,26 @@ def is_trace_model() -> bool: def __get_scope_name(grad_fn: torch.autograd.graph.Node): - if TRACEPOINT_KEY in grad_fn.metadata: - return grad_fn.metadata[TRACEPOINT_KEY] + if TRACEPOINT_KEY in grad_fn.metadata(): + return grad_fn.metadata()[TRACEPOINT_KEY] return grad_fn.name() -def nvtx_range_wrapper(func: Callable, - msg: str): +def nvtx_range_wrapper(func: Callable, msg: str): if not is_trace_model(): return func def wrap(*args, **kwargs): with torch.cuda.nvtx.range(msg=msg): return func(*args, **kwargs) + return wrap def nvtx_wrapper(msg: str): def func_decorator(func): return nvtx_range_wrapper(func, msg) + return func_decorator @@ -57,8 +58,7 @@ def nvtx_range(msg, *args, **kwargs): g_scope_stack: List[str] = [] -def __nvtx_pre_hook_wrapper(func: Callable, - grad_fn: torch.autograd.graph.Node): +def __nvtx_pre_hook_wrapper(func: Callable, grad_fn: torch.autograd.graph.Node): global g_scope_stack scope_name = __get_scope_name(grad_fn) @@ -78,16 +78,17 @@ def wrap(*args, **kwargs): pass return func(*args, **kwargs) + return wrap -def __nvtx_hook_wrapper(func: Callable, - grad_fn: torch.autograd.graph.Node): +def __nvtx_hook_wrapper(func: Callable, grad_fn: torch.autograd.graph.Node): global g_scope_stack # do not capture the func object, will cost memory to hold it - is_last_node = not hasattr(grad_fn, "next_functions") or len( - grad_fn.next_functions) == 0 + is_last_node = ( + not hasattr(grad_fn, "next_functions") or len(grad_fn.next_functions) == 0 + ) def wrap(*args, **kwargs): @@ -96,6 +97,7 @@ def wrap(*args, **kwargs): torch.cuda.nvtx.range_pop() return func(*args, **kwargs) + return wrap @@ -103,7 +105,9 @@ def __grad_fn_pre_hook_dummy(grad_outputs: Tuple[torch.Tensor]) -> None: return None -def __grad_fn_hook_dummy(grad_inputs: Tuple[torch.Tensor], grad_outputs: Tuple[torch.Tensor]) -> None: +def __grad_fn_hook_dummy( + grad_inputs: Tuple[torch.Tensor], grad_outputs: Tuple[torch.Tensor] +) -> None: return None @@ -112,16 +116,16 @@ def __grad_fn_nvtx_wrapper(grad_fn: torch.autograd.graph.Node): return assert isinstance( - grad_fn, torch.autograd.graph.Node), f"error type: {type(grad_fn)}" + grad_fn, torch.autograd.graph.Node + ), f"error type: {type(grad_fn)}" - grad_fn.register_prehook( - __nvtx_pre_hook_wrapper(__grad_fn_pre_hook_dummy, grad_fn)) + 
grad_fn.register_prehook(__nvtx_pre_hook_wrapper(__grad_fn_pre_hook_dummy, grad_fn)) grad_fn.register_hook(__nvtx_hook_wrapper(__grad_fn_hook_dummy, grad_fn)) -def set_backward_tracepoint(grad_fn: torch.autograd.graph.Node, - tp_name: str, - recursion: bool = True): +def set_backward_tracepoint( + grad_fn: torch.autograd.graph.Node | None, tp_name: str, recursion: bool = True +): if not is_trace_model(): return # tp - tracepoint @@ -129,13 +133,14 @@ def set_backward_tracepoint(grad_fn: torch.autograd.graph.Node, return assert isinstance( - grad_fn, torch.autograd.graph.Node), f"error type: {type(grad_fn)}" + grad_fn, torch.autograd.graph.Node + ), f"error type: {type(grad_fn)}" - if TRACEPOINT_KEY in grad_fn.metadata: + if TRACEPOINT_KEY in grad_fn.metadata(): return if not recursion: - grad_fn.metadata[TRACEPOINT_KEY] = tp_name + grad_fn.metadata()[TRACEPOINT_KEY] = tp_name return visited: Set[torch.autograd.graph.Node] = set() @@ -145,7 +150,7 @@ def set_backward_tracepoint(grad_fn: torch.autograd.graph.Node, while len(to_visited_stack) > 0: to_visit = to_visited_stack.pop() - to_visit.metadata[TRACEPOINT_KEY] = tp_name + to_visit.metadata()[TRACEPOINT_KEY] = tp_name visited.add(to_visit) @@ -157,7 +162,7 @@ def set_backward_tracepoint(grad_fn: torch.autograd.graph.Node, continue if next_fn[0] in visited: continue - if TRACEPOINT_KEY in next_fn[0].metadata: + if TRACEPOINT_KEY in next_fn[0].metadata(): continue to_visited_stack.append(next_fn[0]) @@ -189,11 +194,3 @@ def grad_fn_nvtx_wrapper_by_tracepoint(grad_fn: torch.autograd.graph.Node): to_visited_stack.append(next_fn[0]) visited.clear() - - -def tensors_nvtx_wrapper_by_tracepoint(tensors: Tuple[torch.Tensor]): - if not is_trace_model(): - return - - for tensor in tensors: - grad_fn_nvtx_wrapper_by_tracepoint(tensor.grad_fn) diff --git a/mlora/profiler/traceviz.py b/mlora/profiler/traceviz.py index bbddc962..41368113 100644 --- a/mlora/profiler/traceviz.py +++ b/mlora/profiler/traceviz.py @@ -1,15 +1,17 @@ -import torch +from typing import Set, Tuple +import torch from graphviz import Digraph -from typing import Tuple -G_NODE_ATTR = dict(style='filled', - shape='box', - align='left', - fontsize='10', - ranksep='0.1', - height='0.2', - fontname='monospace') +G_NODE_ATTR = dict( + style="filled", + shape="box", + align="left", + fontsize="10", + ranksep="0.1", + height="0.2", + fontname="monospace", +) def __sizeof_fmt(num, suffix="B"): @@ -23,33 +25,29 @@ def __sizeof_fmt(num, suffix="B"): def __name_of_size(var: torch.Tensor): size_arr = ["%d" % v for v in var.size()] memory_size = var.element_size() * var.nelement() - return "[" + ", ".join(size_arr) + "] * " + str(var.element_size()) + " = " + __sizeof_fmt(memory_size / 8) + return ( + f'[{", ".join(size_arr)}] * {str(var.element_size())} ' + + f"= {__sizeof_fmt(memory_size / 8)}" + ) def __name_of_grad_fn(var: torch.autograd.graph.Node) -> str: class_name = var.name() split_index = class_name.rfind("::") if split_index != -1: - class_name = class_name[split_index + 2:] + class_name = class_name[split_index + 2 :] return class_name -def __add_the_attr_to_name(grad_name: str, - attr: str, - var: torch.Tensor) -> str: - if not torch.is_tensor(var): - return grad_name +def __add_the_attr_to_name(grad_name: str, attr: str, var: torch.Tensor) -> str: grad_name += "\n" grad_name += attr + " : " + __name_of_size(var) return grad_name -def __tuple_tensor_add_attr_to_name(grad_fn_name: str, - attr: str, - var: Tuple) -> str: +def __tuple_tensor_add_attr_to_name(grad_fn_name: str, attr: 
str, var: Tuple) -> str: for item_val in var: - grad_fn_name = __add_the_attr_to_name( - grad_fn_name, attr, item_val) + grad_fn_name = __add_the_attr_to_name(grad_fn_name, attr, item_val) return grad_fn_name @@ -64,50 +62,50 @@ def __dot_add_nodes(dot: Digraph, grad_fn: torch.autograd.graph.Node, visited: s for attr in dir(grad_fn): if not attr.startswith("_saved"): continue - var = getattr(grad_fn, attr) - - grad_fn_name = __add_the_attr_to_name(grad_fn_name, attr, var) - - if not isinstance(var, Tuple): - continue + var: torch.Tensor | Tuple[torch.Tensor, ...] = getattr(grad_fn, attr) - grad_fn_name = __tuple_tensor_add_attr_to_name(grad_fn_name, attr, var) + if isinstance(var, torch.Tensor): + var_tensor: torch.Tensor = var + grad_fn_name = __add_the_attr_to_name(grad_fn_name, attr, var_tensor) + else: + var_tuple: Tuple[torch.Tensor, ...] = var + grad_fn_name = __tuple_tensor_add_attr_to_name( + grad_fn_name, attr, var_tuple + ) - if "__tp_name" in grad_fn.metadata: - grad_fn_name += ("\ntracepoint : " + grad_fn.metadata["__tp_name"]) + if "__tp_name" in grad_fn.metadata(): + grad_fn_name += "\ntracepoint : " + grad_fn.metadata()["__tp_name"] dot.node(str(id(grad_fn)), grad_fn_name) if hasattr(grad_fn, "variable"): - var = grad_fn.variable - dot.node(str(id(var)), __name_of_size(var), fillcolor="lightblue") + grad_var: torch.Tensor = grad_fn.variable + dot.node(str(id(var)), __name_of_size(grad_var), fillcolor="lightblue") dot.edge(str(id(var)), str(id(grad_fn))) - if hasattr(grad_fn, 'saved_tensors'): + if hasattr(grad_fn, "saved_tensors"): for item_val in grad_fn.saved_tensors: - dot.node(str(id(item_val)), __name_of_size( - item_val), fillcolor='orange') + dot.node(str(id(item_val)), __name_of_size(item_val), fillcolor="orange") dot.edge(str(id(item_val)), str(id(grad_fn)), dir="none") - if hasattr(grad_fn, 'next_functions'): + if hasattr(grad_fn, "next_functions"): for item_grad in grad_fn.next_functions: if item_grad[0] is not None: __dot_add_nodes(dot, item_grad[0], visited) dot.edge(str(id(item_grad[0])), str(id(grad_fn))) -def __add_base_tensor(dot: Digraph, var: torch.Tensor, visited: set): +def __add_base_tensor(dot: Digraph, var: torch.Tensor, visited: Set[torch.Tensor]): assert isinstance(var, torch.Tensor) if var in visited: return visited.add(var) - dot.node(str(id(var)), __name_of_size( - var), fillcolor="darkolivegreen1") + dot.node(str(id(var)), __name_of_size(var), fillcolor="darkolivegreen1") if var.grad_fn: __dot_add_nodes(dot, var.grad_fn, visited) dot.edge(str(id(var.grad_fn)), str(id(var))) - if var._is_view(): + if var._base is not None: __add_base_tensor(dot, var._base, visited) dot.edge(str(id(var._base)), str(id(var)), style="dotted") @@ -115,7 +113,7 @@ def __add_base_tensor(dot: Digraph, var: torch.Tensor, visited: set): def trace(var: torch.Tensor, file_name: str): dot = Digraph(node_attr=G_NODE_ATTR, graph_attr=dict(size="12,12")) - visited = set() + visited: Set[torch.Tensor] = set() __add_base_tensor(dot, var, visited) visited.clear() diff --git a/mlora/prompter/__init__.py b/mlora/prompter/__init__.py index 69ddfaf8..5156b0c7 100644 --- a/mlora/prompter/__init__.py +++ b/mlora/prompter/__init__.py @@ -1,13 +1,12 @@ from mlora.config import DatasetConfig -from .prompter import Prompter -from .preference_data_prompter import PreferenceDataPrompter from .instruction_data_prompter import InstructionDataPrompter - +from .preference_data_prompter import PreferenceDataPrompter +from .prompter import Prompter _PROMPTER_CLASS = { "instruction": 
InstructionDataPrompter, - "preference": PreferenceDataPrompter + "preference": PreferenceDataPrompter, } diff --git a/mlora/prompter/instruction_data_prompter.py b/mlora/prompter/instruction_data_prompter.py index 850e35b2..fdf0d35b 100644 --- a/mlora/prompter/instruction_data_prompter.py +++ b/mlora/prompter/instruction_data_prompter.py @@ -1,8 +1,7 @@ -from .prompter import Prompter - - from typing import Dict, List, override +from .prompter import Prompter + class InstructionDataPrompter(Prompter): def __init__(self, template: str): diff --git a/mlora/prompter/preference_data_prompter.py b/mlora/prompter/preference_data_prompter.py index 216e9e4b..3eee0a75 100644 --- a/mlora/prompter/preference_data_prompter.py +++ b/mlora/prompter/preference_data_prompter.py @@ -1,18 +1,15 @@ -from .prompter import Prompter - - from typing import Dict, List, Tuple, override +from .prompter import Prompter + class PreferenceDataPrompter(Prompter): def __init__(self, template: str): super().__init__(template) def __generate_prompt(self, data_point: Dict[str, str]) -> Tuple[str, str]: - chosen_data = self.template_.render( - data_point=data_point, is_chosen=True) - reject_data = self.template_.render( - data_point=data_point, is_chosen=False) + chosen_data = self.template_.render(data_point=data_point, is_chosen=True) + reject_data = self.template_.render(data_point=data_point, is_chosen=False) return chosen_data, reject_data @override diff --git a/mlora/prompter/prompter.py b/mlora/prompter/prompter.py index 67af38a0..39024d5b 100644 --- a/mlora/prompter/prompter.py +++ b/mlora/prompter/prompter.py @@ -1,21 +1,20 @@ -import yaml -import jinja2 +from abc import abstractmethod from typing import Dict, List + +import jinja2 +import yaml from jinja2.sandbox import ImmutableSandboxedEnvironment -from abc import abstractmethod class Prompter: - template_: jinja2.Template = None + template_: jinja2.Template def __init__(self, template: str): with open(template) as fp: template_str = yaml.safe_load(fp) - jinja_env = ImmutableSandboxedEnvironment( - trim_blocks=True, lstrip_blocks=True) + jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True) self.template_ = jinja_env.from_string(template_str["template"]) @abstractmethod - def generate_prompt(self, data_points: List[Dict[str, str]]) -> List[str]: - ... + def generate_prompt(self, data_points: List[Dict[str, str]]) -> List[str]: ... 
diff --git a/mlora/server/__init__.py b/mlora/server/__init__.py index b843640d..c525270e 100644 --- a/mlora/server/__init__.py +++ b/mlora/server/__init__.py @@ -1,16 +1,22 @@ +from .adapter import router as adapter_router +from .dataset import router as dataset_router from .dispatcher import router as dispatcher_router from .file import router as file_router -from .dataset import router as dataset_router -from .adapter import router as adapter_router +from .pipe import m_create_task, m_dispatcher +from .storage import ( + db_get_obj, + db_get_str, + db_it_obj, + db_it_str, + db_put_obj, + db_put_str, + root_dir, + root_dir_list, + set_db, + set_root_dir, + set_root_dir_list, +) from .task import router as task_router -from .storage import (set_db, - db_get_str, db_put_str, - db_get_obj, db_put_obj, - db_it_str, db_it_obj, - root_dir, set_root_dir, - root_dir_list, set_root_dir_list) -from .pipe import m_dispatcher, m_create_task - __all__ = [ "dispatcher_router", @@ -20,11 +26,15 @@ "task_router", "m_dispatcher", "m_create_task", - "db_get_str", "db_put_str", "db_get_obj", "db_put_obj", - "db_it_str", "db_it_obj", + "db_get_str", + "db_put_str", + "db_get_obj", + "db_put_obj", + "db_it_str", + "db_it_obj", "set_db", "root_dir", "set_root_dir", "root_dir_list", - "set_root_dir_list" + "set_root_dir_list", ] diff --git a/mlora/server/adapter.py b/mlora/server/adapter.py index 77499c32..b46451b2 100644 --- a/mlora/server/adapter.py +++ b/mlora/server/adapter.py @@ -1,8 +1,9 @@ -import uuid import logging +import uuid + from fastapi import APIRouter, Request -from .storage import db_it_str, db_get_str, db_put_obj +from .storage import db_get_str, db_it_str, db_put_obj router = APIRouter() diff --git a/mlora/server/dataset.py b/mlora/server/dataset.py index 489cddd3..70b22c0b 100644 --- a/mlora/server/dataset.py +++ b/mlora/server/dataset.py @@ -1,12 +1,13 @@ -from mlora.prompter import PrompterFactory -from mlora.config import DatasetConfig - -import os import logging -from fastapi import APIRouter, Request +import os + from datasets import load_dataset +from fastapi import APIRouter, Request + +from mlora.config import DatasetConfig +from mlora.prompter import PrompterFactory -from .storage import db_it_str, db_get_str, db_put_obj, db_get_obj, root_dir_list +from .storage import db_get_obj, db_get_str, db_it_str, db_put_obj, root_dir_list router = APIRouter() @@ -21,7 +22,7 @@ def get_dataset(): @router.get("/showcase") def showcase_dataset(name: str): - dataset = db_get_obj(f'__dataset__{name}') + dataset = db_get_obj(f"__dataset__{name}") if dataset is None: return {"message": "the dataset not exist"} @@ -29,15 +30,18 @@ def showcase_dataset(name: str): dataset_config = DatasetConfig(dataset) dataset_config.data_path_ = os.path.join( - root_dir_list()["data"], dataset_config.data_path_) + root_dir_list()["data"], dataset_config.data_path_ + ) dataset_config.prompt_path_ = os.path.join( - root_dir_list()["prompt"], dataset_config.prompt_path_) + root_dir_list()["prompt"], dataset_config.prompt_path_ + ) prompter = PrompterFactory.create(dataset_config) # just read one item data_points = load_dataset( - "json", data_files=dataset_config.data_path_, split="train[:1]") + "json", data_files=dataset_config.data_path_, split="train[:1]" + ) ret = prompter.generate_prompt(data_points) @@ -64,7 +68,7 @@ async def post_dataset(request: Request): "data": data_file["file_path"], "prompt": prompt_file["file_path"], "prompt_type": prompt_file["prompt_type"], - "preprocess": req["preprocess"] + 
"preprocess": req["preprocess"], } logging.info(f'Create new dataset: {req["name"]}') diff --git a/mlora/server/file.py b/mlora/server/file.py index 3b67f318..b70c3a63 100644 --- a/mlora/server/file.py +++ b/mlora/server/file.py @@ -1,9 +1,10 @@ +import logging import os import uuid -import logging + from fastapi import APIRouter, UploadFile -from .storage import root_dir_list, db_get_str, db_put_obj, db_it_obj +from .storage import db_get_str, db_it_obj, db_put_obj, root_dir_list router = APIRouter() @@ -11,12 +12,15 @@ def get_local_file(file_type: str): ret = [] for key, value in db_it_obj(file_type): - ret.append({"name": key[len(file_type):], "file": value}) + ret.append({"name": key[len(file_type) :], "file": value}) return ret def save_local_file(file_type: str, name: str, data_file: UploadFile): + if data_file.filename is None: + return {"message": "error file name"} + file_postfix = data_file.filename.split(".")[-1] if file_postfix != "json" and file_postfix != "yaml": return {"message": "unsupport file type"} @@ -45,9 +49,7 @@ def get_data(): def post_data(name: str, data_file: UploadFile): file_name = save_local_file("data", name, data_file) - db_put_obj(f"__data__{name}", { - "file_path": file_name - }) + db_put_obj(f"__data__{name}", {"file_path": file_name}) return {"message": "success"} @@ -61,9 +63,12 @@ def get_prompt(): def post_prompt(name: str, prompt_type: str, data_file: UploadFile): file_name = save_local_file("prompt", name, data_file) - db_put_obj(f"__prompt__{name}", { - "file_path": file_name, - "prompt_type": prompt_type, - }) + db_put_obj( + f"__prompt__{name}", + { + "file_path": file_name, + "prompt_type": prompt_type, + }, + ) return {"message": "success"} diff --git a/mlora/server/storage.py b/mlora/server/storage.py index 2a72f6c1..46ccf647 100644 --- a/mlora/server/storage.py +++ b/mlora/server/storage.py @@ -1,14 +1,17 @@ import json -import plyvel from typing import Dict +import plyvel + # define the root_dir __g_db: plyvel.DB = None __g_root_dir: str = "" -__g_root_dir_list = {"data": "./datas", - "prompt": "./prompts", - "adapter": "./adapters", - "db": "./db"} +__g_root_dir_list = { + "data": "./datas", + "prompt": "./prompts", + "adapter": "./adapters", + "db": "./db", +} def db() -> plyvel.DB: diff --git a/mlora/server/task.py b/mlora/server/task.py index fb3bdd70..06e2ace2 100644 --- a/mlora/server/task.py +++ b/mlora/server/task.py @@ -1,18 +1,18 @@ -from mlora.config import DatasetConfig, TASKCONFIG_CLASS, ADAPTERCONFIG_CLASS - -import os import logging +import os + from fastapi import APIRouter, Request -from .storage import root_dir_list, db_get_str, db_put_obj, db_it_str, db_get_obj +from mlora.config import ADAPTERCONFIG_CLASS, TASKCONFIG_CLASS, DatasetConfig + from .pipe import g_s_create_task +from .storage import db_get_obj, db_get_str, db_it_str, db_put_obj, root_dir_list router = APIRouter() def complete_path(obj, dir_type: str, file_name: str): - obj[file_name] = os.path.join( - root_dir_list()[dir_type], "./" + obj[file_name]) + obj[file_name] = os.path.join(root_dir_list()[dir_type], "./" + obj[file_name]) return obj @@ -55,8 +55,7 @@ async def post_task(request: Request): if adapter is None: return {"message": "can not found the reference adapter"} adapter = complete_path(adapter, "adapter", "path") - adapters[adapter["name"] - ] = ADAPTERCONFIG_CLASS[adapter["type"]](adapter) + adapters[adapter["name"]] = ADAPTERCONFIG_CLASS[adapter["type"]](adapter) task_conf = TASKCONFIG_CLASS[req["type"]](req, adapters, datasets) diff --git 
a/mlora/utils/__init__.py b/mlora/utils/__init__.py index d147ea4a..f4331b04 100644 --- a/mlora/utils/__init__.py +++ b/mlora/utils/__init__.py @@ -1,11 +1,6 @@ from .cmd import get_cmd_args, get_server_cmd_args from .loader import load_model -from .setup import ( - setup_seed, - setup_logging, - setup_cuda_check, - setup_trace_mode -) +from .setup import setup_cuda_check, setup_logging, setup_seed, setup_trace_mode __all__ = [ "get_cmd_args", @@ -14,5 +9,5 @@ "setup_seed", "setup_logging", "setup_cuda_check", - "setup_trace_mode" + "setup_trace_mode", ] diff --git a/mlora/utils/cmd.py b/mlora/utils/cmd.py index b79d97bf..c4ed39c8 100644 --- a/mlora/utils/cmd.py +++ b/mlora/utils/cmd.py @@ -2,50 +2,67 @@ def _add_base_cmd(parser): - parser.add_argument('--base_model', type=str, required=True, - help='Path to or name of base model') - parser.add_argument('--model_type', type=str, default="llama", - help='The model type, support: llama, chatglm') - parser.add_argument('--device', type=str, default='cuda:0', - help='Specify which GPU to be used, default is cuda:0') - parser.add_argument('--precision', type=str, default="int8", - help='Load model with different precision, include int8') - parser.add_argument('--seed', type=int, default=42, - help='Random seed in integer, default is 42') + parser.add_argument( + "--base_model", type=str, required=True, help="Path to or name of base model" + ) + parser.add_argument( + "--model_type", + type=str, + default="llama", + help="The model type, support: llama, chatglm", + ) + parser.add_argument( + "--device", + type=str, + default="cuda:0", + help="Specify which GPU to be used, default is cuda:0", + ) + parser.add_argument( + "--precision", + type=str, + default="int8", + help="Load model with different precision, include int8", + ) + parser.add_argument( + "--seed", type=int, default=42, help="Random seed in integer, default is 42" + ) # the argument about pipeline - parser.add_argument('--pipeline', action="store_true", - help="Train the LoRA model use the pipeline parallelism") - parser.add_argument('--rank', type=int, default=-1, - help="The device's rank number") - parser.add_argument('--balance', type=int, nargs="+", - help="The model's balance") + parser.add_argument( + "--pipeline", + action="store_true", + help="Train the LoRA model use the pipeline parallelism", + ) + parser.add_argument("--rank", type=int, default=-1, help="The device's rank number") + parser.add_argument("--balance", type=int, nargs="+", help="The model's balance") # configuration about log - parser.add_argument('--log_level', type=str, default="INFO", - help="Set the log level.") - parser.add_argument('--log_file', type=str, - help="Save log to specific file.") + parser.add_argument( + "--log_level", type=str, default="INFO", help="Set the log level." 
+    )
+    parser.add_argument("--log_file", type=str, help="Save log to specific file.")
     return parser
 
 
 def get_cmd_args():
-    parser = argparse.ArgumentParser(description='m-LoRA main program')
+    parser = argparse.ArgumentParser(description="m-LoRA main program")
     parser = _add_base_cmd(parser)
     # configuration
-    parser.add_argument('--config', type=str, required=True,
-                        help='Path to finetune configuration')
+    parser.add_argument(
+        "--config", type=str, required=True, help="Path to finetune configuration"
+    )
     # the argument about the trace mode
-    parser.add_argument('--trace', action="store_true",
-                        help="enbale the trace mode.")
+    parser.add_argument("--trace", action="store_true", help="enbale the trace mode.")
     return parser.parse_args()
 
 
 def get_server_cmd_args():
-    parser = argparse.ArgumentParser(description='m-LoRA server program')
+    parser = argparse.ArgumentParser(description="m-LoRA server program")
     parser = _add_base_cmd(parser)
     # configuration about dispatcher
-    parser.add_argument('--concurrency_num', type=int, default=2,
-                        help='The concurrency num of task')
-    parser.add_argument('--root', type=str, required=True,
-                        help='The root dir to save data')
+    parser.add_argument(
+        "--concurrency_num", type=int, default=2, help="The concurrency num of task"
+    )
+    parser.add_argument(
+        "--root", type=str, required=True, help="The root dir to save data"
+    )
     return parser.parse_args()
 
diff --git a/mlora/utils/loader.py b/mlora/utils/loader.py
index 24fb18b5..cb87a917 100644
--- a/mlora/utils/loader.py
+++ b/mlora/utils/loader.py
@@ -1,11 +1,10 @@
-from mlora.model.llm import LLMModel, LlamaModel
-from mlora.model.tokenizer import Tokenizer
-
 import logging
+from typing import Tuple
 
-from typing import Tuple, Dict
+from mlora.model.llm import LlamaModel, LLMModel
+from mlora.model.tokenizer import Tokenizer
 
-MODEL_TYPE_DICT: Dict[str, LLMModel] = {
+MODEL_TYPE_DICT = {
     "llama": LlamaModel,
 }
 
@@ -16,22 +15,29 @@ def load_partial_model(args) -> LLMModel:
     assert len(args.balance) >= args.rank
 
     logging.info(
-        f"Pipeline parallelism, rank is {args.rank} and balance is {args.balance}.")
+        f"Pipeline parallelism, rank is {args.rank} and balance is {args.balance}."
+    )
 
     partial_model_to_device = [
-        index + sum(args.balance[:args.rank])for index in range(0, args.balance[args.rank])]
+        index + sum(args.balance[: args.rank])
+        for index in range(0, args.balance[args.rank])
+    ]
 
-    return MODEL_TYPE_DICT[args.model_type].from_pretrained(path=args.base_model,
-                                                            device=args.device,
-                                                            precision=args.precision,
-                                                            partial_model_to_device=partial_model_to_device)
+    return MODEL_TYPE_DICT[args.model_type].from_pretrained(
+        path=args.base_model,
+        device=args.device,
+        precision=args.precision,
+        partial_model_to_device=partial_model_to_device,
+    )
 
 
 def load_full_model(args) -> LLMModel:
-    return MODEL_TYPE_DICT[args.model_type].from_pretrained(path=args.base_model,
-                                                            device=args.device,
-                                                            precision=args.precision,
-                                                            partial_model_to_device=None)
+    return MODEL_TYPE_DICT[args.model_type].from_pretrained(
+        path=args.base_model,
+        device=args.device,
+        precision=args.precision,
+        partial_model_to_device=None,
+    )
 
 
 def load_model(args) -> Tuple[Tokenizer, LLMModel]:
diff --git a/mlora/utils/setup.py b/mlora/utils/setup.py
index 36834e51..f8ab2421 100644
--- a/mlora/utils/setup.py
+++ b/mlora/utils/setup.py
@@ -1,8 +1,10 @@
-import mlora
-import torch
 import logging
 import random
+from typing import List, Optional
 
+import torch
+
+import mlora
 import mlora.profiler
 
 
@@ -13,26 +15,30 @@ def setup_seed(seed):
     torch.backends.cudnn.deterministic = True
 
 
-def setup_logging(log_level: str = "INFO", log_file: str = None):
+def setup_logging(log_level: str = "INFO", log_file: Optional[str] = None):
     # set the logger
-    log_handlers = [logging.StreamHandler()]
+    log_handlers: List[logging.Handler] = [logging.StreamHandler()]
     if log_file is not None:
         log_handlers.append(logging.FileHandler(log_file))
 
-    logging.basicConfig(format="[%(asctime)s] m-LoRA: %(message)s",
-                        level=log_level,
-                        handlers=log_handlers,
-                        force=True)
+    logging.basicConfig(
+        format="[%(asctime)s] m-LoRA: %(message)s",
+        level=log_level,
+        handlers=log_handlers,
+        force=True,
+    )
 
 
 def setup_cuda_check():
     # check the enviroment
     if torch.cuda.is_available():
-        logging.info('NVIDIA CUDA initialized successfully.')
-        logging.info('Total %i GPU(s) detected.' % torch.cuda.device_count())
+        logging.info("NVIDIA CUDA initialized successfully.")
+        logging.info("Total %i GPU(s) detected." % torch.cuda.device_count())
     else:
         logging.error(
-            'm-LoRA requires NVIDIA CUDA computing capacity. Please check your PyTorch installation.')
+            "m-LoRA requires NVIDIA CUDA computing capacity. "
+            "Please check your PyTorch installation."
+        )
         exit(1)
 
 
diff --git a/pyproject.toml b/pyproject.toml
index f6f048ce..64010fad 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,7 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
-ci_test = ["pytest", "flake8", "lizard"]
+ci_test = ["pytest", "flake8", "lizard", "black", "isort", "mypy"]
 test = ["peft", "setuptools"]
 debug = ["graphviz"]
 deploy = ["rich", "fastapi", "plyvel", "uvicorn", "InquirerPy"]
diff --git a/tests/lora_op_test.py b/tests/lora_op_test.py
index c844f101..6ad97775 100644
--- a/tests/lora_op_test.py
+++ b/tests/lora_op_test.py
@@ -39,10 +39,15 @@ def lora_pytorch(self):
     def lora_mlora(self):
         lora_a, lora_b, in_data, weight = self.set_test_tensor()
         input_args = ModelData(
-            data_config_=[ModelDataConfig("", "", 0, 2)])
+            batch_tokens_=[],
+            batch_mask_=[],
+            data_config_=[ModelDataConfig("", "", 0, 2)],
+            enable_checkpoint_=False,
+        )
 
-        weight = LoRAFunction.apply(weight, in_data, input_args, [
-                                    1e-4], [2.0], lora_a, lora_b)
+        weight = LoRAFunction.apply(
+            weight, in_data, input_args, [1e-4], [2.0], lora_a, lora_b
+        )
 
         loss = weight.sum()
         self.mlora_loss = loss.item()
@@ -61,5 +66,5 @@ def test_lora(self):
         assert torch.allclose(self.py_grad_a, self.mlora_grad_a, 1e-4)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()