From 40c352809a282cebd4030c41dafc4c7f38c9a6fe Mon Sep 17 00:00:00 2001 From: yezhem Date: Sat, 29 Jun 2024 16:53:29 +0000 Subject: [PATCH] [feature] support to terminate task and train other task. --- mlora/cli/adapter.py | 46 +++++++++++++-- mlora/cli/dataset.py | 43 +++++++++++++- mlora/cli/file.py | 78 ++++++++++++++----------- mlora/cli/task.py | 67 ++++++++++++++++----- mlora/executor/dispatcher/dispatcher.py | 27 +++++++-- mlora/executor/executor.py | 11 ++++ mlora/executor/task/dpo_task.py | 7 +++ mlora/executor/task/task.py | 24 ++++++-- mlora/executor/task/train_task.py | 10 +++- mlora/server/__init__.py | 5 +- mlora/server/adapter.py | 19 +++++- mlora/server/dataset.py | 31 ++++++++-- mlora/server/file.py | 43 ++++++++++---- mlora/server/pipe.py | 6 ++ mlora/server/task.py | 49 +++++++++++++--- mlora_server.py | 69 +++++++++++++++++----- 16 files changed, 423 insertions(+), 112 deletions(-) diff --git a/mlora/cli/adapter.py b/mlora/cli/adapter.py index 2b663b88..7fa5bff9 100644 --- a/mlora/cli/adapter.py +++ b/mlora/cli/adapter.py @@ -2,7 +2,7 @@ from typing import Any, Dict import requests -from InquirerPy import inquirer, validator +from InquirerPy import inquirer, separator, validator from InquirerPy.base import Choice from rich import print from rich.box import ASCII @@ -20,20 +20,24 @@ def list_adapter(obj): table.add_column("type", justify="center") table.add_column("dir", justify="center") table.add_column("state", justify="center") + table.add_column("task", justify="center") obj.ret_ = [] for ret_item in ret_items: item = json.loads(ret_item) - table.add_row(item["name"], item["type"], item["path"], item["state"]) - obj.ret_.append(item["name"]) + table.add_row( + item["name"], item["type"], item["path"], item["state"], item["task"] + ) + obj.ret_.append((item["name"], item["state"], item["task"])) obj.pret_ = table def adapter_type_set(adapter_conf: Dict[str, Any]): adapter_type = inquirer.select( - message="type:", choices=["lora", "loraplus"] + message="type:", + choices=[separator.Separator(), "lora", "loraplus", separator.Separator()], ).execute() adapter_conf["type"] = adapter_type @@ -48,7 +52,8 @@ def adapter_type_set(adapter_conf: Dict[str, Any]): def adapter_optimizer_set(adapter_conf: Dict[str, Any]): optimizer = inquirer.select( - message="optimizer:", choices=["adamw", "sgd"] + message="optimizer:", + choices=[separator.Separator(), "adamw", "sgd", separator.Separator()], ).execute() adapter_conf["optimizer"] = optimizer @@ -73,7 +78,8 @@ def adapter_lr_scheduler_set(adapter_conf: Dict[str, Any]): return adapter_conf lr_scheduler_type = inquirer.select( - message="optimizer:", choices=["cosine"] + message="optimizer:", + choices=[separator.Separator(), "cosine", separator.Separator()], ).execute() adapter_conf["lrscheduler"] = lr_scheduler_type @@ -120,6 +126,7 @@ def adapter_set(adapter_conf: Dict[str, Any]): target_modules = inquirer.checkbox( message="target_modules:", choices=[ + separator.Separator(), Choice("q_proj", enabled=True), Choice("k_proj", enabled=True), Choice("v_proj", enabled=True), @@ -127,6 +134,7 @@ def adapter_set(adapter_conf: Dict[str, Any]): Choice("gate_proj", enabled=False), Choice("down_proj", enabled=False), Choice("up_proj", enabled=False), + separator.Separator(), ], ).execute() for target in target_modules: @@ -154,12 +162,36 @@ def create_adapter(): print(json.loads(ret.text)) +def delete_adapter(obj): + list_adapter(obj) + all_adapters = obj.ret_ + all_adapters = [ + item for item in all_adapters if item[2] == "NO" or 
item[1] == "DONE" + ] + + if len(all_adapters) == 0: + print("no adapter, please create one") + return + + adapter_name = inquirer.select( + message="adapter name:", + choices=[separator.Separator(), *all_adapters, separator.Separator()], + ).execute() + + ret = requests.delete(url() + f"/adapter?name={adapter_name[0]}") + ret = json.loads(ret.text) + + print(ret) + + def help_adapter(_): print("Usage of adapter:") print(" ls") print(" list all the adapter.") print(" create") print(" create a new adapter.") + print(" delete") + print(" delete a adapter.") def do_adapter(obj, args): @@ -170,5 +202,7 @@ def do_adapter(obj, args): return print(obj.pret_) elif args[0] == "create": return create_adapter() + elif args[0] == "delete": + return delete_adapter(obj) help_adapter(None) diff --git a/mlora/cli/dataset.py b/mlora/cli/dataset.py index 09fe3096..5caa6494 100644 --- a/mlora/cli/dataset.py +++ b/mlora/cli/dataset.py @@ -1,7 +1,7 @@ import json import requests -from InquirerPy import inquirer, separator +from InquirerPy import inquirer, separator, validator from rich import print from rich.box import ASCII from rich.table import Table @@ -38,7 +38,10 @@ def list_dataset(obj): def create_dataset(obj): - name = inquirer.text(message="name:").execute() + name = inquirer.text( + message="name:", + validate=validator.EmptyInputValidator("name should not be empty"), + ).execute() list_file(obj, "data") all_train_data = [item["name"] for item in obj.ret_] @@ -56,10 +59,22 @@ def create_dataset(obj): message="train data file:", choices=[separator.Separator(), *all_train_data, separator.Separator()], ).execute() + use_prompt = inquirer.select( message="prompt template file:", choices=[separator.Separator(), *all_prompt, separator.Separator()], ).execute() + + use_prompter = inquirer.select( + message="prompter:", + choices=[ + separator.Separator(), + "instruction", + "preference", + separator.Separator(), + ], + ).execute() + use_preprocess = inquirer.select( message="data preprocessing:", choices=[ @@ -77,6 +92,7 @@ def create_dataset(obj): "name": name, "data_name": use_train, "prompt_name": use_prompt, + "prompt_type": use_prompter, "preprocess": use_preprocess, }, ) @@ -84,6 +100,25 @@ def create_dataset(obj): print(json.loads(ret.text)) +def delete_dataset(obj): + list_dataset(obj) + all_dataset = obj.ret_ + + if len(all_dataset) == 0: + print("no dataset, please create one") + return + + dataset_name = inquirer.select( + message="dataset name:", + choices=[separator.Separator(), *all_dataset, separator.Separator()], + ).execute() + + ret = requests.delete(url() + f"/dataset?name={dataset_name}") + ret = json.loads(ret.text) + + print(ret) + + def showcase_dataset(obj): list_dataset(obj) all_dataset = obj.ret_ @@ -109,6 +144,8 @@ def help_dataset(_): print(" list all the dataset.") print(" create") print(" create a new dataset.") + print(" delete") + print(" delete a dataset.") print(" showcase") print(" display training data composed of prompt and dataset.") @@ -121,6 +158,8 @@ def do_dataset(obj, args): return print(obj.pret_) elif args[0] == "create": return create_dataset(obj) + elif args[0] == "delete": + return delete_dataset(obj) elif args[0] == "showcase": return showcase_dataset(obj) diff --git a/mlora/cli/file.py b/mlora/cli/file.py index ba5fba55..e23742f4 100644 --- a/mlora/cli/file.py +++ b/mlora/cli/file.py @@ -8,7 +8,7 @@ from .setting import url -g_file_type_map = {"train data": "data", "prompt data": "prompt"} +g_file_type_map = {"train data": "data", "prompt template": 
"prompt"} def list_file(obj, file_type: str): @@ -18,13 +18,9 @@ def list_file(obj, file_type: str): table = Table(show_header=True, show_lines=True, box=ASCII) table.add_column("name", justify="center") table.add_column("file", justify="center") - if file_type == "prompt": - table.add_column("prompter", justify="center") for item in ret_items: - row_data = [item["name"], item["file"]["file_path"]] - if file_type == "prompt": - row_data.append(item["file"]["prompt_type"]) + row_data = [item["name"], item["file"]] table.add_row(*row_data) obj.ret_ = ret_items @@ -32,30 +28,18 @@ def list_file(obj, file_type: str): def upload_file(): - name = inquirer.text( - message="name:", - validate=validator.EmptyInputValidator("name should not be empty"), - ).execute() - file_type = inquirer.select( message="file type:", choices=[separator.Separator(), *g_file_type_map.keys(), separator.Separator()], ).execute() file_type = g_file_type_map[file_type] - post_url = url() + f"/{file_type}?name={name}" + name = inquirer.text( + message="name:", + validate=validator.EmptyInputValidator("name should not be empty"), + ).execute() - if file_type == "prompt": - prompt_type = inquirer.select( - message="prompter type:", - choices=[ - separator.Separator(), - "instruction", - "preference", - separator.Separator(), - ], - ).execute() - post_url += f"&prompt_type={prompt_type}" + post_url = url() + f"/{file_type}?name={name}" path = inquirer.filepath( message="file path:", @@ -69,12 +53,38 @@ def upload_file(): print(json.loads(ret.text)) +def delete_file(obj): + list_file(obj, "data") + data_file_list = [("data", item["name"]) for item in obj.ret_] + + list_file(obj, "prompt") + prompt_file_list = [("prompt", item["name"]) for item in obj.ret_] + + chose_item = inquirer.select( + message="file name:", + choices=[ + separator.Separator(), + *data_file_list, + *prompt_file_list, + separator.Separator(), + ], + ).execute() + + delete_url = url() + f"/{chose_item[0]}?name={chose_item[1]}" + + ret = requests.delete(delete_url) + + print(json.loads(ret.text)) + + def help_file(_): print("Usage of file:") print(" ls") - print(" list the usable data or prompt data.") + print(" list the train or prompt data.") print(" upload") - print(" upload a training data or prompt data.") + print(" upload a train or prompt data.") + print(" delete") + print(" delete a train or prompt data.") def do_file(obj, args): @@ -82,18 +92,16 @@ def do_file(obj, args): if args[0] == "ls": # to chose file type - file_type = inquirer.select( - message="type:", - choices=[ - separator.Separator(), - *g_file_type_map.keys(), - separator.Separator(), - ], - ).execute() - file_type = g_file_type_map[file_type] - list_file(obj, file_type) + list_file(obj, "data") + print("Data files:") + print(obj.pret_) + + list_file(obj, "prompt") + print("Prompt files:") return print(obj.pret_) elif args[0] == "upload": return upload_file() + elif args[0] == "delete": + return delete_file(obj) help_file(None) diff --git a/mlora/cli/task.py b/mlora/cli/task.py index f0825dd2..63275c31 100644 --- a/mlora/cli/task.py +++ b/mlora/cli/task.py @@ -1,7 +1,7 @@ import json import requests -from InquirerPy import inquirer +from InquirerPy import inquirer, separator, validator from rich import print from rich.box import ASCII from rich.table import Table @@ -22,16 +22,18 @@ def list_task(obj): table.add_column("adapter", justify="center") table.add_column("state", justify="center") + obj.ret_ = [] for ret_item in ret_items: item = json.loads(ret_item) table.add_row( 
item["name"], item["type"], item["dataset"], item["adapter"], item["state"] ) + obj.ret_.append((item["name"], item["state"])) obj.pret_ = table -def task_type_set(task_conf, all_adapters): +def task_type_set(obj, task_conf): if task_conf["type"] == "dpo" or task_conf["type"] == "cpo": beta = inquirer.number( message="beta:", float_allowed=True, default=0.1, replace_mode=True @@ -48,21 +50,30 @@ def task_type_set(task_conf, all_adapters): if task_conf["type"] == "cpo": loss_type = inquirer.select( - message="loss_type:", choices=["sigmoid", "hinge"] + message="loss_type:", + choices=[separator.Separator(), "sigmoid", "hinge", separator.Separator()], ).execute() task_conf["loss_type"] = loss_type if task_conf["type"] == "dpo": loss_type = inquirer.select( - message="loss_type:", choices=["sigmoid", "ipo"] + message="loss_type:", + choices=[separator.Separator(), "sigmoid", "ipo", separator.Separator()], ).execute() task_conf["loss_type"] = loss_type - all_adapters.append("base") + list_adapter(obj) + all_ref_adapters = [ + item + for item in obj.ret_ + if item[1] == "DONE" and item[0] != task_conf["adapter"] + ] + all_ref_adapters.append(("base", "use the base llm model")) reference = inquirer.select( - message="reference model:", choices=all_adapters + message="reference model:", + choices=[separator.Separator(), *all_ref_adapters, separator.Separator()], ).execute() - task_conf["reference"] = reference + task_conf["reference"] = reference[0] return task_conf @@ -110,11 +121,15 @@ def create_task(obj): task_conf = {} task_type = inquirer.select( - message="type:", choices=["train", "dpo", "cpo"] + message="type:", + choices=[separator.Separator(), "train", "dpo", "cpo", separator.Separator()], ).execute() task_conf["type"] = task_type - name = inquirer.text(message="name:").execute() + name = inquirer.text( + message="name:", + validate=validator.EmptyInputValidator("Input should not be empty"), + ).execute() task_conf["name"] = name list_dataset(obj) @@ -124,7 +139,10 @@ def create_task(obj): print("no dataset, please create one") return - dataset = inquirer.select(message="dataset:", choices=all_dataset).execute() + dataset = inquirer.select( + message="dataset:", + choices=[separator.Separator(), *all_dataset, separator.Separator()], + ).execute() task_conf["dataset"] = dataset list_adapter(obj) @@ -134,10 +152,13 @@ def create_task(obj): print("no adapter can be train, please create one") return - adapter = inquirer.select(message="train adapter:", choices=all_adapters).execute() - task_conf["adapter"] = adapter + adapter = inquirer.select( + message="train adapter:", + choices=[separator.Separator(), *all_adapters, separator.Separator()], + ).execute() + task_conf["adapter"] = adapter[0] - task_conf = task_type_set(task_conf, all_adapters.copy()) + task_conf = task_type_set(obj, task_conf) task_conf = task_set(task_conf) ret = requests.post(url() + "/task", json=task_conf) @@ -145,12 +166,30 @@ def create_task(obj): print(json.loads(ret.text)) +def delete_task(obj): + list_task(obj) + all_task = obj.ret_ + + delete_task = inquirer.select( + message="termiate task:", + choices=[separator.Separator(), *all_task, separator.Separator()], + ).execute() + + delete_task_name = delete_task[0] + + ret = requests.delete(url() + f"/task?name={delete_task_name}") + + print(json.loads(ret.text)) + + def help_task(_): print("Usage of task:") print(" ls") print(" list the task.") print(" create") print(" create a task.") + print(" delete") + print(" delete a task.") def do_task(obj, args): @@ -161,5 
+200,7 @@ def do_task(obj, args): return print(obj.pret_) elif args[0] == "create": return create_task(obj) + elif args[0] == "delete": + return delete_task(obj) help_task(None) diff --git a/mlora/executor/dispatcher/dispatcher.py b/mlora/executor/dispatcher/dispatcher.py index 798cacf8..5ff8b435 100644 --- a/mlora/executor/dispatcher/dispatcher.py +++ b/mlora/executor/dispatcher/dispatcher.py @@ -26,13 +26,13 @@ class Dispatcher: ready_: List[Task] running_: List[Task] - done_: List[Task] init_event_: DispatcherEvent running_event_: DispatcherEvent ready_event_: DispatcherEvent done_event_: DispatcherEvent step_event_: DispatcherEvent + terminate_event_: DispatcherEvent concurrency_num_: int = 2 @@ -42,13 +42,13 @@ def __init__(self, config: DispatcherConfig) -> None: self.ready_ = [] self.running_ = [] - self.done_ = [] self.init_event_ = DispatcherEvent() self.running_event_ = DispatcherEvent() self.ready_event_ = DispatcherEvent() self.done_event_ = DispatcherEvent() self.step_event_ = DispatcherEvent() + self.terminate_event_ = DispatcherEvent() def info(self) -> Dict[str, Any]: return {"name": self.name_, "concurrency_num": self.concurrency_num_} @@ -60,6 +60,7 @@ def register_hook(self, name: str, cb: Callable) -> None: "ready": self.ready_event_, "done": self.done_event_, "step": self.step_event_, + "terminate": self.terminate_event_, } assert name in event_map @@ -72,10 +73,23 @@ def add_task(self, config: TaskConfig, llm_name: str): self.init_event_.notify(task) self.ready_.append(task) + def notify_terminate_task(self, task_name: str): + for task in [*self.running_, *self.ready_]: + if task.task_name() != task_name: + continue + task.notify_terminate() + def is_done(self) -> bool: return len(self.running_) == 0 and len(self.ready_) == 0 def _dispatch_task_in(self): + # ready task to terminate + terminate_task = [task for task in self.ready_ if task.is_terminate()] + self.ready_ = [task for task in self.ready_ if not task.is_terminate()] + + for task in terminate_task: + self.terminate_event_.notify(task) + # ready task to running task assert len(self.running_) <= self.concurrency_num_ if len(self.running_) == self.concurrency_num_: @@ -87,10 +101,15 @@ def _dispatch_task_in(self): self.running_event_.notify(task) def _dispatch_task_out(self): - # running task to ready task or done task + # running task to terminate + terminate_task = [task for task in self.running_ if task.is_terminate()] + self.running_ = [task for task in self.running_ if not task.is_terminate()] + for task in terminate_task: + self.terminate_event_.notify(task) + + # running task to ready done_task = [task for task in self.running_ if task.is_done()] self.running_ = [task for task in self.running_ if not task.is_done()] - self.done_.extend(done_task) for task in done_task: self.done_event_.notify(task) diff --git a/mlora/executor/executor.py b/mlora/executor/executor.py index e2187172..e13660b1 100644 --- a/mlora/executor/executor.py +++ b/mlora/executor/executor.py @@ -33,6 +33,7 @@ def __init__( "running": self.__task_to_running_hook, "ready": self.__task_to_ready_hook, "done": self.__task_to_done_hook, + "terminate": self.__task_to_terminate_hook, } for hook, cb in hook_func.items(): @@ -77,12 +78,22 @@ def __task_to_done_hook(self, task: Task): task.switch_device("cpu") task.done() + def __task_to_terminate_hook(self, task: Task): + logging.info(f"Task - {task.task_name()} terminate.") + for adapter_name in task.adapter_name(): + self.model_.offload_adapter(adapter_name) + task.switch_device("cpu") + 
task.terminate() + def dispatcher_info(self) -> Dict[str, str]: return self.dispatcher_.info() def add_task(self, config: TaskConfig): self.dispatcher_.add_task(config, self.model_.name_or_path_) + def notify_terminate_task(self, task_name: str): + self.dispatcher_.notify_terminate_task(task_name) + def execute(self) -> None: while not self.dispatcher_.is_done(): data: MLoRAData = self.dispatcher_.data() diff --git a/mlora/executor/task/dpo_task.py b/mlora/executor/task/dpo_task.py index 30e190b3..24553dae 100644 --- a/mlora/executor/task/dpo_task.py +++ b/mlora/executor/task/dpo_task.py @@ -190,3 +190,10 @@ def done(self): del self.context_ if self.ref_context_ is not None: del self.ref_context_ + + @override + def terminate(self): + # release the context + del self.context_ + if self.ref_context_ is not None: + del self.ref_context_ diff --git a/mlora/executor/task/task.py b/mlora/executor/task/task.py index fd8573af..5e1b2f74 100644 --- a/mlora/executor/task/task.py +++ b/mlora/executor/task/task.py @@ -17,27 +17,30 @@ class Task: config_: TaskConfig - now_step_: int - + prompter_: Prompter tokenizer_: Tokenizer + context_: TaskContext data_: List[Dict[str, str]] now_data_idx_: int + now_step_: int - prompter_: Prompter - + terminate_: bool + # need_terminal_ the llm name just for export the config file llm_name_: str def __init__(self, config: TaskConfig, llm_name: str) -> None: self.config_ = config - self.now_step_ = 1 + self.prompter_ = PrompterFactory.create(config.dataset_) self.data_ = [] self.now_data_idx_ = 0 + self.now_step_ = 1 + + self.terminate_ = False - self.prompter_ = PrompterFactory.create(config.dataset_) self.llm_name_ = llm_name @abstractmethod @@ -48,6 +51,9 @@ def prepare(self, linears_info: OrderedDict[str, LinearInfo], tokenizer: Tokeniz @abstractmethod def done(self): ... + @abstractmethod + def terminate(self): ... + @abstractmethod def step(self): ... @@ -60,6 +66,12 @@ def data(self, start_idx: int) -> Tuple[List[Tokens], List[MLoRADataConfig]]: .. @abstractmethod def task_progress(self) -> int: ... 
+    def notify_terminate(self):
+        self.terminate_ = True
+
+    def is_terminate(self) -> bool:
+        return self.terminate_
+
     def _pre_dataset(self):
         preprocess_func: Dict[str, Callable] = {
             "default": lambda data: data,
diff --git a/mlora/executor/task/train_task.py b/mlora/executor/task/train_task.py
index 03b8e176..a69129e1 100644
--- a/mlora/executor/task/train_task.py
+++ b/mlora/executor/task/train_task.py
@@ -38,9 +38,9 @@ def prepare(self, linears_info: OrderedDict[str, LinearInfo], tokenizer: Tokeniz
     @override
     def data(self, start_idx: int) -> Tuple[List[Tokens], List[MLoRADataConfig]]:
         logging.info(
-            f"Adapter {self.context_.name_} epoch: {
-                self.now_epoch_}/{self.config_.num_epochs_}"
-            f" iteration: {self.now_data_idx_}/{len(self.data_)} step: {self.now_step_}"
+            f"Adapter {self.context_.name_} "
+            f"epoch: {self.now_epoch_}/{self.config_.num_epochs_} "
+            f"iteration: {self.now_data_idx_}/{len(self.data_)} step: {self.now_step_}"
         )
         data_idx_s = self.now_data_idx_
         data_idx_e = self.now_data_idx_ + self.config_.mini_batch_size_
@@ -130,6 +130,10 @@ def done(self):
         # release the context
         del self.context_
 
+    @override
+    def terminate(self):
+        del self.context_
+
     @override
     def step(self):
         stepd: bool = False
diff --git a/mlora/server/__init__.py b/mlora/server/__init__.py
index c525270e..35bb1bbb 100644
--- a/mlora/server/__init__.py
+++ b/mlora/server/__init__.py
@@ -2,8 +2,9 @@
 from .dataset import router as dataset_router
 from .dispatcher import router as dispatcher_router
 from .file import router as file_router
-from .pipe import m_create_task, m_dispatcher
+from .pipe import m_create_task, m_dispatcher, m_notify_terminate_task
 from .storage import (
+    db_del,
     db_get_obj,
     db_get_str,
     db_it_obj,
@@ -26,12 +27,14 @@
     "task_router",
     "m_dispatcher",
     "m_create_task",
+    "m_notify_terminate_task",
     "db_get_str",
     "db_put_str",
     "db_get_obj",
     "db_put_obj",
     "db_it_str",
     "db_it_obj",
+    "db_del",
     "set_db",
     "root_dir",
     "set_root_dir",
diff --git a/mlora/server/adapter.py b/mlora/server/adapter.py
index b46451b2..67159468 100644
--- a/mlora/server/adapter.py
+++ b/mlora/server/adapter.py
@@ -3,7 +3,7 @@
 
 from fastapi import APIRouter, Request
 
-from .storage import db_get_str, db_it_str, db_put_obj
+from .storage import db_del, db_get_obj, db_get_str, db_it_str, db_put_obj
 
 router = APIRouter()
 
@@ -27,9 +27,26 @@ async def post_adapter(request: Request):
 
     req["path"] = adapter_dir
     req["state"] = "UNK"
+    req["task"] = "NO"
 
     logging.info(f"Create new adapter: {req}")
 
     db_put_obj(f'__adapter__{req["name"]}', req)
 
     return {"message": "success"}
+
+
+@router.delete("/adapter")
+def delete_adapter(name: str):
+    adapter = db_get_obj(f"__adapter__{name}")
+
+    if adapter is None:
+        return {"message": "the adapter does not exist"}
+
+    # an adapter can only be deleted when no unfinished task is training it
+    if adapter["task"] != "NO" and adapter["state"] != "DONE":
+        return {"message": "adapter with an unfinished task, cannot be deleted"}
+
+    db_del(f"__adapter__{name}")
+
+    return {"message": "delete the adapter"}
diff --git a/mlora/server/dataset.py b/mlora/server/dataset.py
index 70b22c0b..c1031e0d 100644
--- a/mlora/server/dataset.py
+++ b/mlora/server/dataset.py
@@ -7,7 +7,14 @@
 from mlora.config import DatasetConfig
 from mlora.prompter import PrompterFactory
 
-from .storage import db_get_obj, db_get_str, db_it_str, db_put_obj, root_dir_list
+from .storage import (
+    db_del,
+    db_get_obj,
+    db_get_str,
+    db_it_str,
+    db_put_obj,
+    root_dir_list,
+)
 
 router = APIRouter()
 
@@ -52,8 +59,8 @@ def showcase_dataset(name: str):
 async def post_dataset(request: 
Request): req = await request.json() - data_file = db_get_obj(f'__data__{req["data_name"]}') - prompt_file = db_get_obj(f'__prompt__{req["prompt_name"]}') + data_file = db_get_str(f'__data__{req["data_name"]}') + prompt_file = db_get_str(f'__prompt__{req["prompt_name"]}') if data_file is None or prompt_file is None: return {"message": "error parameters"} @@ -65,9 +72,9 @@ async def post_dataset(request: Request): "name": req["name"], "data_name": req["data_name"], "prompt_name": req["prompt_name"], - "data": data_file["file_path"], - "prompt": prompt_file["file_path"], - "prompt_type": prompt_file["prompt_type"], + "data": data_file, + "prompt": prompt_file, + "prompt_type": req["prompt_type"], "preprocess": req["preprocess"], } @@ -76,3 +83,15 @@ async def post_dataset(request: Request): db_put_obj(f'__dataset__{req["name"]}', dataset) return {"message": "success"} + + +@router.delete("/dataset") +def delete_dataset(name: str): + dataset = db_get_obj(f"__dataset__{name}") + + if dataset is None: + return {"message": "the dataset not exist"} + + db_del(f"__dataset__{name}") + + return {"message": "delete the dataset"} diff --git a/mlora/server/file.py b/mlora/server/file.py index b70c3a63..4dd46de8 100644 --- a/mlora/server/file.py +++ b/mlora/server/file.py @@ -4,14 +4,14 @@ from fastapi import APIRouter, UploadFile -from .storage import db_get_str, db_it_obj, db_put_obj, root_dir_list +from .storage import db_del, db_get_str, db_it_str, db_put_str, root_dir_list router = APIRouter() def get_local_file(file_type: str): ret = [] - for key, value in db_it_obj(file_type): + for key, value in db_it_str(file_type): ret.append({"name": key[len(file_type) :], "file": value}) return ret @@ -22,7 +22,8 @@ def save_local_file(file_type: str, name: str, data_file: UploadFile): return {"message": "error file name"} file_postfix = data_file.filename.split(".")[-1] - if file_postfix != "json" and file_postfix != "yaml": + check_postfix = ["json", "yaml"] + if file_postfix not in check_postfix: return {"message": "unsupport file type"} if db_get_str(f"__{file_type}__{name}") is not None: @@ -49,26 +50,44 @@ def get_data(): def post_data(name: str, data_file: UploadFile): file_name = save_local_file("data", name, data_file) - db_put_obj(f"__data__{name}", {"file_path": file_name}) + db_put_str(f"__data__{name}", file_name) return {"message": "success"} +@router.delete("/data") +def delete_data(name: str): + file_name: str | None = db_get_str(f"__data__{name}") + + if file_name is None: + return {"message": "file not exist"} + + db_del(f"__data__{name}") + + return {"message": "delete success"} + + @router.get("/prompt") def get_prompt(): return get_local_file("__prompt__") @router.post("/prompt") -def post_prompt(name: str, prompt_type: str, data_file: UploadFile): +def post_prompt(name: str, data_file: UploadFile): file_name = save_local_file("prompt", name, data_file) - db_put_obj( - f"__prompt__{name}", - { - "file_path": file_name, - "prompt_type": prompt_type, - }, - ) + db_put_str(f"__prompt__{name}", file_name) return {"message": "success"} + + +@router.delete("/prompt") +def delete_prompt(name: str): + file_name: str | None = db_get_str(f"__prompt__{name}") + + if file_name is None: + return {"message": "file not exist"} + + db_del(f"__prompt__{name}") + + return {"message": "delete success"} diff --git a/mlora/server/pipe.py b/mlora/server/pipe.py index 4f3f1efc..b97e9d28 100644 --- a/mlora/server/pipe.py +++ b/mlora/server/pipe.py @@ -5,6 +5,7 @@ # g_m: model side use, g_s: server side use 
g_m_dispatcher, g_s_dispatcher = multiprocessing.Pipe(True) g_m_create_task, g_s_create_task = multiprocessing.Pipe(True) +g_m_notify_terminate_task, g_s_notify_terminate_task = multiprocessing.Pipe(True) def m_dispatcher() -> multiprocessing.connection.Connection: @@ -15,3 +16,8 @@ def m_dispatcher() -> multiprocessing.connection.Connection: def m_create_task() -> multiprocessing.connection.Connection: global g_m_create_task return g_m_create_task + + +def m_notify_terminate_task() -> multiprocessing.connection.Connection: + global g_m_notify_terminate_task + return g_m_notify_terminate_task diff --git a/mlora/server/task.py b/mlora/server/task.py index 06e2ace2..90d61de8 100644 --- a/mlora/server/task.py +++ b/mlora/server/task.py @@ -5,8 +5,15 @@ from mlora.config import ADAPTERCONFIG_CLASS, TASKCONFIG_CLASS, DatasetConfig -from .pipe import g_s_create_task -from .storage import db_get_obj, db_get_str, db_it_str, db_put_obj, root_dir_list +from .pipe import g_s_create_task, g_s_notify_terminate_task +from .storage import ( + db_del, + db_get_obj, + db_get_str, + db_it_str, + db_put_obj, + root_dir_list, +) router = APIRouter() @@ -39,6 +46,7 @@ async def post_task(request: Request): if adapter is None: return {"message": "can not found the adapter"} + # create the task config for executor datasets = {} adapters = {} # complete the storage path @@ -51,19 +59,46 @@ async def post_task(request: Request): # dpo need add the reference adapter if "reference" in req and req["reference"] != "base": - adapter = db_get_obj(f'__adapter__{req["reference"]}') - if adapter is None: + ref_adapter = db_get_obj(f'__adapter__{req["reference"]}') + if ref_adapter is None: return {"message": "can not found the reference adapter"} - adapter = complete_path(adapter, "adapter", "path") - adapters[adapter["name"]] = ADAPTERCONFIG_CLASS[adapter["type"]](adapter) + ref_adapter = complete_path(ref_adapter, "adapter", "path") + adapters[ref_adapter["name"]] = ADAPTERCONFIG_CLASS[ref_adapter["type"]]( + ref_adapter + ) task_conf = TASKCONFIG_CLASS[req["type"]](req, adapters, datasets) - logging.info(f"Create new task: {req["name"]}") + logging.info(f"Create new task: {req["name"]} with adapter") + # set the task's state req["state"] = "UNK" db_put_obj(f'__task__{req["name"]}', req) + # set the adapter's state + adapter = db_get_obj(f'__adapter__{req["adapter"]}') + adapter["task"] = req["name"] + db_put_obj(f'__adapter__{req["adapter"]}', adapter) + g_s_create_task.send(task_conf) return {"message": "success"} + + +@router.delete("/task") +def terminate_task(name: str): + task = db_get_obj(f"__task__{name}") + + if task is None: + return {"message": "the task not exist"} + + if task["state"] == "DONE": + db_del(f"__task__{name}") + return {"message": "delete the done task."} + + g_s_notify_terminate_task.send(name) + + task["state"] = "TERMINATING" + db_put_obj(f"__task__{name}", task) + + return {"message": f"to terminate the task {name}, wait."} diff --git a/mlora_server.py b/mlora_server.py index d21d54bd..3e1db8e7 100644 --- a/mlora_server.py +++ b/mlora_server.py @@ -23,7 +23,6 @@ import mlora.server import os -import json import plyvel import logging import uvicorn @@ -33,27 +32,49 @@ m_task_done, s_task_done = multiprocessing.Pipe(True) m_task_step, s_task_step = multiprocessing.Pipe(True) +m_task_terminate, s_task_terminate = multiprocessing.Pipe(True) def backend_server_set_task_state(task_name: str, state: str): - task_info = mlora.server.db_get_str(f'__task__{task_name}') - task_info = 
json.loads(task_info)
+    # to get the task, and set its state
+    task_info = mlora.server.db_get_obj(f"__task__{task_name}")
+    if task_info is None:
+        logging.info(f"the task {task_name} may have been terminated.")
+        return
+
     task_info["state"] = state
-    mlora.server.db_put_str(f'__task__{task_name}', json.dumps(task_info))
-    # to get the adapter in the task, and to set it done
+    mlora.server.db_put_obj(f"__task__{task_name}", task_info)
+
+    # to get the adapter in the task, and to set its state
     adapter_name = task_info["adapter"]
-    adapter_info = mlora.server.db_get_str(f'__adapter__{adapter_name}')
-    adapter_info = json.loads(adapter_info)
+    adapter_info = mlora.server.db_get_obj(f"__adapter__{adapter_name}")
     adapter_info["state"] = state
-    mlora.server.db_put_str(f'__adapter__{adapter_name}', json.dumps(adapter_info))
+    mlora.server.db_put_obj(f"__adapter__{adapter_name}", adapter_info)
+
+
+def backend_server_delete_task(task_name: str):
+    # to get the task, and clear its adapter's task binding
+    task_info = mlora.server.db_get_obj(f"__task__{task_name}")
+    if task_info is None:
+        logging.info(f"the task {task_name} may have been terminated.")
+        return
+
+    # to get the adapter in the task, and to set its state
+    adapter_name = task_info["adapter"]
+    adapter_info = mlora.server.db_get_obj(f"__adapter__{adapter_name}")
+    adapter_info["task"] = "NO"
+    mlora.server.db_put_obj(f"__adapter__{adapter_name}", adapter_info)
+
+    mlora.server.db_del(f"__task__{task_name}")
 
 
 def backend_server_run_fn(args):
     mlora.server.set_root_dir(args.root)
     root_dir_list = mlora.server.root_dir_list()
-    root_dir_list = dict(map(lambda kv: (kv[0], os.path.join(
-        args.root, kv[1])), root_dir_list.items()))
+    root_dir_list = dict(
+        map(lambda kv: (kv[0], os.path.join(args.root, kv[1])), root_dir_list.items())
+    )
 
     mlora.server.set_root_dir_list(root_dir_list)
 
@@ -72,7 +93,7 @@ def backend_server_run_fn(args):
     mLoRAServer.include_router(mlora.server.adapter_router)
     mLoRAServer.include_router(mlora.server.task_router)
 
-    web_thread = threading.Thread(target=uvicorn.run, args=(mLoRAServer, ))
+    web_thread = threading.Thread(target=uvicorn.run, args=(mLoRAServer,))
 
     logging.info("Start the backend web server run thread")
     web_thread.start()
@@ -84,12 +105,19 @@ def backend_server_run_fn(args):
             backend_server_set_task_state(task_name, "DONE")
         if s_task_step.poll(timeout=0.1):
             task_name, progress = s_task_step.recv()
+            # a step event may arrive after the done event
+            if progress >= 100:
+                continue
             backend_server_set_task_state(task_name, str(progress) + "%")
+        if s_task_terminate.poll(timeout=0.1):
+            task_name = s_task_terminate.recv()
+            backend_server_delete_task(task_name)
 
 
 def backend_model_run_fn(executor: mlora.executor.Executor):
     m_dispatcher = mlora.server.m_dispatcher()
     m_create_task = mlora.server.m_create_task()
+    m_terminate_task = mlora.server.m_notify_terminate_task()
 
     while True:
         if m_dispatcher.poll(timeout=0.1):
@@ -98,6 +126,9 @@ def backend_model_run_fn(executor: mlora.executor.Executor):
         if m_create_task.poll(timeout=0.1):
             task_conf = m_create_task.recv()
             executor.add_task(task_conf)
+        if m_terminate_task.poll(timeout=0.1):
+            task_name = m_terminate_task.recv()
+            executor.notify_terminate_task(task_name)
 
 
 def task_done_callback_fn(task: mlora.executor.task.Task):
@@ -111,6 +142,11 @@ def task_step_callback_fn(task: mlora.executor.task.Task):
     m_task_step.send((task_name, task.task_progress()))
 
 
+def task_terminate_callback_fn(task: mlora.executor.task.Task):
+    task_name = task.task_name()
+    m_task_terminate.send(task_name)
+
+
 if __name__ == 
"__main__": args = mlora.utils.get_server_cmd_args() @@ -119,18 +155,19 @@ def task_step_callback_fn(task: mlora.executor.task.Task): mlora.utils.setup_cuda_check() backend_server_run_process = multiprocessing.Process( - target=backend_server_run_fn, args=(args,)) + target=backend_server_run_fn, args=(args,) + ) backend_server_run_process.start() logging.info("Start the backend model run process") tokenizer, model = mlora.utils.load_model(args) - config = mlora.config.MLoRAServerConfig({ - "name": "backend", - "concurrency_num": args.concurrency_num - }) + config = mlora.config.MLoRAServerConfig( + {"name": "backend", "concurrency_num": args.concurrency_num} + ) executor = mlora.executor.Executor(model, tokenizer, config) executor.register_hook("done", task_done_callback_fn) executor.register_hook("step", task_step_callback_fn) + executor.register_hook("terminate", task_terminate_callback_fn) # model to execute the task execute_thread = threading.Thread(target=executor.execute, args=())