Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feature] support to terminate task and train other task. #228

Merged
merged 1 commit into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions mlora/cli/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, Dict

import requests
from InquirerPy import inquirer, validator
from InquirerPy import inquirer, separator, validator
from InquirerPy.base import Choice
from rich import print
from rich.box import ASCII
Expand All @@ -20,20 +20,24 @@ def list_adapter(obj):
table.add_column("type", justify="center")
table.add_column("dir", justify="center")
table.add_column("state", justify="center")
table.add_column("task", justify="center")

obj.ret_ = []

for ret_item in ret_items:
item = json.loads(ret_item)
table.add_row(item["name"], item["type"], item["path"], item["state"])
obj.ret_.append(item["name"])
table.add_row(
item["name"], item["type"], item["path"], item["state"], item["task"]
)
obj.ret_.append((item["name"], item["state"], item["task"]))

obj.pret_ = table


def adapter_type_set(adapter_conf: Dict[str, Any]):
adapter_type = inquirer.select(
message="type:", choices=["lora", "loraplus"]
message="type:",
choices=[separator.Separator(), "lora", "loraplus", separator.Separator()],
).execute()
adapter_conf["type"] = adapter_type

Expand All @@ -48,7 +52,8 @@ def adapter_type_set(adapter_conf: Dict[str, Any]):

def adapter_optimizer_set(adapter_conf: Dict[str, Any]):
optimizer = inquirer.select(
message="optimizer:", choices=["adamw", "sgd"]
message="optimizer:",
choices=[separator.Separator(), "adamw", "sgd", separator.Separator()],
).execute()
adapter_conf["optimizer"] = optimizer

Expand All @@ -73,7 +78,8 @@ def adapter_lr_scheduler_set(adapter_conf: Dict[str, Any]):
return adapter_conf

lr_scheduler_type = inquirer.select(
message="optimizer:", choices=["cosine"]
message="optimizer:",
choices=[separator.Separator(), "cosine", separator.Separator()],
).execute()
adapter_conf["lrscheduler"] = lr_scheduler_type

Expand Down Expand Up @@ -120,13 +126,15 @@ def adapter_set(adapter_conf: Dict[str, Any]):
target_modules = inquirer.checkbox(
message="target_modules:",
choices=[
separator.Separator(),
Choice("q_proj", enabled=True),
Choice("k_proj", enabled=True),
Choice("v_proj", enabled=True),
Choice("o_proj", enabled=True),
Choice("gate_proj", enabled=False),
Choice("down_proj", enabled=False),
Choice("up_proj", enabled=False),
separator.Separator(),
],
).execute()
for target in target_modules:
Expand Down Expand Up @@ -154,12 +162,36 @@ def create_adapter():
print(json.loads(ret.text))


def delete_adapter(obj):
list_adapter(obj)
all_adapters = obj.ret_
all_adapters = [
item for item in all_adapters if item[2] == "NO" or item[1] == "DONE"
]

if len(all_adapters) == 0:
print("no adapter, please create one")
return

adapter_name = inquirer.select(
message="adapter name:",
choices=[separator.Separator(), *all_adapters, separator.Separator()],
).execute()

ret = requests.delete(url() + f"/adapter?name={adapter_name[0]}")
ret = json.loads(ret.text)

print(ret)


def help_adapter(_):
print("Usage of adapter:")
print(" ls")
print(" list all the adapter.")
print(" create")
print(" create a new adapter.")
print(" delete")
print(" delete a adapter.")


def do_adapter(obj, args):
Expand All @@ -170,5 +202,7 @@ def do_adapter(obj, args):
return print(obj.pret_)
elif args[0] == "create":
return create_adapter()
elif args[0] == "delete":
return delete_adapter(obj)

help_adapter(None)
43 changes: 41 additions & 2 deletions mlora/cli/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json

import requests
from InquirerPy import inquirer, separator
from InquirerPy import inquirer, separator, validator
from rich import print
from rich.box import ASCII
from rich.table import Table
Expand Down Expand Up @@ -38,7 +38,10 @@ def list_dataset(obj):


def create_dataset(obj):
name = inquirer.text(message="name:").execute()
name = inquirer.text(
message="name:",
validate=validator.EmptyInputValidator("name should not be empty"),
).execute()

list_file(obj, "data")
all_train_data = [item["name"] for item in obj.ret_]
Expand All @@ -56,10 +59,22 @@ def create_dataset(obj):
message="train data file:",
choices=[separator.Separator(), *all_train_data, separator.Separator()],
).execute()

use_prompt = inquirer.select(
message="prompt template file:",
choices=[separator.Separator(), *all_prompt, separator.Separator()],
).execute()

use_prompter = inquirer.select(
message="prompter:",
choices=[
separator.Separator(),
"instruction",
"preference",
separator.Separator(),
],
).execute()

use_preprocess = inquirer.select(
message="data preprocessing:",
choices=[
Expand All @@ -77,13 +92,33 @@ def create_dataset(obj):
"name": name,
"data_name": use_train,
"prompt_name": use_prompt,
"prompt_type": use_prompter,
"preprocess": use_preprocess,
},
)

print(json.loads(ret.text))


def delete_dataset(obj):
list_dataset(obj)
all_dataset = obj.ret_

if len(all_dataset) == 0:
print("no dataset, please create one")
return

dataset_name = inquirer.select(
message="dataset name:",
choices=[separator.Separator(), *all_dataset, separator.Separator()],
).execute()

ret = requests.delete(url() + f"/dataset?name={dataset_name}")
ret = json.loads(ret.text)

print(ret)


def showcase_dataset(obj):
list_dataset(obj)
all_dataset = obj.ret_
Expand All @@ -109,6 +144,8 @@ def help_dataset(_):
print(" list all the dataset.")
print(" create")
print(" create a new dataset.")
print(" delete")
print(" delete a dataset.")
print(" showcase")
print(" display training data composed of prompt and dataset.")

Expand All @@ -121,6 +158,8 @@ def do_dataset(obj, args):
return print(obj.pret_)
elif args[0] == "create":
return create_dataset(obj)
elif args[0] == "delete":
return delete_dataset(obj)
elif args[0] == "showcase":
return showcase_dataset(obj)

Expand Down
78 changes: 43 additions & 35 deletions mlora/cli/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .setting import url

g_file_type_map = {"train data": "data", "prompt data": "prompt"}
g_file_type_map = {"train data": "data", "prompt template": "prompt"}


def list_file(obj, file_type: str):
Expand All @@ -18,44 +18,28 @@ def list_file(obj, file_type: str):
table = Table(show_header=True, show_lines=True, box=ASCII)
table.add_column("name", justify="center")
table.add_column("file", justify="center")
if file_type == "prompt":
table.add_column("prompter", justify="center")

for item in ret_items:
row_data = [item["name"], item["file"]["file_path"]]
if file_type == "prompt":
row_data.append(item["file"]["prompt_type"])
row_data = [item["name"], item["file"]]
table.add_row(*row_data)

obj.ret_ = ret_items
obj.pret_ = table


def upload_file():
name = inquirer.text(
message="name:",
validate=validator.EmptyInputValidator("name should not be empty"),
).execute()

file_type = inquirer.select(
message="file type:",
choices=[separator.Separator(), *g_file_type_map.keys(), separator.Separator()],
).execute()
file_type = g_file_type_map[file_type]

post_url = url() + f"/{file_type}?name={name}"
name = inquirer.text(
message="name:",
validate=validator.EmptyInputValidator("name should not be empty"),
).execute()

if file_type == "prompt":
prompt_type = inquirer.select(
message="prompter type:",
choices=[
separator.Separator(),
"instruction",
"preference",
separator.Separator(),
],
).execute()
post_url += f"&prompt_type={prompt_type}"
post_url = url() + f"/{file_type}?name={name}"

path = inquirer.filepath(
message="file path:",
Expand All @@ -69,31 +53,55 @@ def upload_file():
print(json.loads(ret.text))


def delete_file(obj):
list_file(obj, "data")
data_file_list = [("data", item["name"]) for item in obj.ret_]

list_file(obj, "prompt")
prompt_file_list = [("prompt", item["name"]) for item in obj.ret_]

chose_item = inquirer.select(
message="file name:",
choices=[
separator.Separator(),
*data_file_list,
*prompt_file_list,
separator.Separator(),
],
).execute()

delete_url = url() + f"/{chose_item[0]}?name={chose_item[1]}"

ret = requests.delete(delete_url)

print(json.loads(ret.text))


def help_file(_):
print("Usage of file:")
print(" ls")
print(" list the usable data or prompt data.")
print(" list the train or prompt data.")
print(" upload")
print(" upload a training data or prompt data.")
print(" upload a train or prompt data.")
print(" delete")
print(" delete a train or prompt data.")


def do_file(obj, args):
args = args.split(" ")

if args[0] == "ls":
# to chose file type
file_type = inquirer.select(
message="type:",
choices=[
separator.Separator(),
*g_file_type_map.keys(),
separator.Separator(),
],
).execute()
file_type = g_file_type_map[file_type]
list_file(obj, file_type)
list_file(obj, "data")
print("Data files:")
print(obj.pret_)

list_file(obj, "prompt")
print("Prompt files:")
return print(obj.pret_)
elif args[0] == "upload":
return upload_file()
elif args[0] == "delete":
return delete_file(obj)

help_file(None)
Loading
Loading