Skip to content

Commit

Permalink
[improvement] add code formatter - black and isort, static typing (#227)
Browse files Browse the repository at this point in the history
checker mypy.
  • Loading branch information
yezhengmao1 committed Jun 28, 2024
1 parent dcb641f commit a9bb7a3
Show file tree
Hide file tree
Showing 79 changed files with 1,401 additions and 2,127 deletions.
11 changes: 10 additions & 1 deletion .github/workflows/pr-clean-code-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,16 @@ jobs:
lizard -l python ./mlora -C 12
- name: Lint with flake8
run: |
flake8 . --count --show-source --statistics --max-line-length=127 --max-complexity 15 --ignore=E722,W504
flake8 ./mlora --count --show-source --statistics --max-line-length=88 --max-complexity 15 --ignore=E203,W503,E704
- name: Lint with black
run: |
black --check ./mlora
- name: Lint with isort
run: |
isort ./mlora --check --profile black
- name: Static code check with mypy
run: |
mypy ./mlora --ignore-missing-imports --non-interactive --install-types --check-untyped-defs
- name: Test with pytest
run: |
pytest
8 changes: 7 additions & 1 deletion .github/workflows/pre-commit
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

lizard -l python ./mlora -C 12

flake8 . --count --show-source --statistics --max-line-length=127 --max-complexity 15 --ignore=E722,W504
black --check ./mlora

isort ./mlora --check --profile black

flake8 ./mlora --count --show-source --statistics --max-line-length=88 --max-complexity 15 --ignore=E203,W503,E704

mypy ./mlora --ignore-missing-imports --non-interactive --install-types --check-untyped-defs

pytest
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Firstly, you should clone this repository and install dependencies:
# Clone Repository
git clone https://github.com/TUDB-Labs/mLoRA
cd mLoRA
# Install requirements
# Install requirements need the Python >= 3.12
pip install .
```

Expand Down Expand Up @@ -123,8 +123,14 @@ Submit a pull request with a detailed explanation of your changes.

You can use the pre-commit to check your code.
```bash
# Install requirements
pip install .[ci_test]
ln -s ../../.github/workflows/pre-commit .git/hooks/pre-commit
```
Or just call the script to check your code
```bash
.github/workflows/pre-commit
```

## Citation
Please cite the repo if you use the code in this repo.
Expand Down
8 changes: 4 additions & 4 deletions mlora/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .setting import G_HOST, G_PORT
from .adapter import do_adapter, help_adapter
from .dataset import do_dataset, help_dataset
from .dispatcher import do_dispatcher, help_dispatcher
from .file import do_file, help_file
from .dataset import do_dataset, help_dataset
from .adapter import do_adapter, help_adapter
from .setting import G_HOST, G_PORT
from .task import do_task, help_task

__all__ = [
Expand All @@ -17,5 +17,5 @@
"help_adapter",
"do_adapter",
"help_task",
"do_task"
"do_task",
]
83 changes: 36 additions & 47 deletions mlora/cli/adapter.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import json
from typing import Any, Dict

import requests
from InquirerPy import inquirer
from InquirerPy import validator
from InquirerPy import inquirer, validator
from InquirerPy.base import Choice
from rich import print
from rich.table import Table
from rich.box import ASCII
from typing import Dict
from rich.table import Table

from .setting import url


def list_adapter(obj):
ret = requests.get(url() + "/adapter")
ret = json.loads(ret.text)
ret_items = json.loads(ret.text)

table = Table(show_header=True, show_lines=True, box=ASCII)
table.add_column("name", justify="center")
Expand All @@ -23,63 +23,58 @@ def list_adapter(obj):

obj.ret_ = []

for item in ret:
item = json.loads(item)
for ret_item in ret_items:
item = json.loads(ret_item)
table.add_row(item["name"], item["type"], item["path"], item["state"])
obj.ret_.append(item["name"])

obj.pret_ = table


def adapter_type_set(adapter_conf: Dict[str, any]):
def adapter_type_set(adapter_conf: Dict[str, Any]):
adapter_type = inquirer.select(
message="type:", choices=["lora", "loraplus"]).execute()
message="type:", choices=["lora", "loraplus"]
).execute()
adapter_conf["type"] = adapter_type

if adapter_type == "loraplus":
lr_ratio = inquirer.number(
message="lr_ratio:",
float_allowed=True,
default=8.0,
replace_mode=True
message="lr_ratio:", float_allowed=True, default=8.0, replace_mode=True
).execute()
adapter_conf["lr_ratio"] = lr_ratio

return adapter_conf


def adapter_optimizer_set(adapter_conf: Dict[str, any]):
def adapter_optimizer_set(adapter_conf: Dict[str, Any]):
optimizer = inquirer.select(
message="optimizer:", choices=["adamw", "sgd"]).execute()
message="optimizer:", choices=["adamw", "sgd"]
).execute()
adapter_conf["optimizer"] = optimizer

lr = inquirer.number(
message="learning rate:",
float_allowed=True,
default=3e-4,
replace_mode=True
message="learning rate:", float_allowed=True, default=3e-4, replace_mode=True
).execute()
adapter_conf["lr"] = lr

if optimizer == "sgd":
momentum = inquirer.number(
message="momentum:",
float_allowed=True,
default=0.0,
replace_mode=True
message="momentum:", float_allowed=True, default=0.0, replace_mode=True
).execute()
adapter_conf["momentum"] = momentum
return adapter_conf


def adapter_lr_scheduler_set(adapter_conf: Dict[str, any]):
def adapter_lr_scheduler_set(adapter_conf: Dict[str, Any]):
need_lr_scheduler = inquirer.confirm(
message="Need learning rate scheduler:", default=False).execute()
message="Need learning rate scheduler:", default=False
).execute()
if not need_lr_scheduler:
return adapter_conf

lr_scheduler_type = inquirer.select(
message="optimizer:", choices=["cosine"]).execute()
message="optimizer:", choices=["cosine"]
).execute()
adapter_conf["lrscheduler"] = lr_scheduler_type

if lr_scheduler_type == "cosine":
Expand All @@ -101,24 +96,15 @@ def adapter_lr_scheduler_set(adapter_conf: Dict[str, any]):
return adapter_conf


def adapter_set(adapter_conf: Dict[str, any]):
r = inquirer.number(
message="rank:",
default=32
).execute()
def adapter_set(adapter_conf: Dict[str, Any]):
r = inquirer.number(message="rank:", default=32).execute()
adapter_conf["r"] = r

alpha = inquirer.number(
message="alpha:",
default=64
).execute()
alpha = inquirer.number(message="alpha:", default=64).execute()
adapter_conf["alpha"] = alpha

dropout = inquirer.number(
message="dropout:",
float_allowed=True,
replace_mode=True,
default=0.05
message="dropout:", float_allowed=True, replace_mode=True, default=0.05
).execute()
adapter_conf["dropout"] = dropout

Expand All @@ -133,13 +119,15 @@ def adapter_set(adapter_conf: Dict[str, any]):
}
target_modules = inquirer.checkbox(
message="target_modules:",
choices=[Choice("q_proj", enabled=True),
Choice("k_proj", enabled=True),
Choice("v_proj", enabled=True),
Choice("o_proj", enabled=True),
Choice("gate_proj", enabled=False),
Choice("down_proj", enabled=False),
Choice("up_proj", enabled=False)]
choices=[
Choice("q_proj", enabled=True),
Choice("k_proj", enabled=True),
Choice("v_proj", enabled=True),
Choice("o_proj", enabled=True),
Choice("gate_proj", enabled=False),
Choice("down_proj", enabled=False),
Choice("up_proj", enabled=False),
],
).execute()
for target in target_modules:
adapter_conf["target_modules"][target] = True
Expand All @@ -152,7 +140,8 @@ def create_adapter():

name = inquirer.text(
message="name:",
validate=validator.EmptyInputValidator("Input should not be empty")).execute()
validate=validator.EmptyInputValidator("Input should not be empty"),
).execute()
adapter_conf["name"] = name

adapter_conf = adapter_type_set(adapter_conf)
Expand Down
75 changes: 42 additions & 33 deletions mlora/cli/dataset.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import json

import requests
from InquirerPy import inquirer, separator
from rich import print
from rich.table import Table
from rich.box import ASCII
from rich.table import Table

from .setting import url
from .file import list_file
from .setting import url


def list_dataset(obj):
ret = requests.get(url() + "/dataset")
ret = json.loads(ret.text)
ret_items = json.loads(ret.text)

table = Table(show_header=True, show_lines=True, box=ASCII)
table.add_column("name", justify="center")
Expand All @@ -22,21 +23,22 @@ def list_dataset(obj):

obj.ret_ = []

for item in ret:
item = json.loads(item)
table.add_row(item["name"],
item["data_name"],
item["prompt_name"],
item["prompt_type"],
item["preprocess"])
for ret_item in ret_items:
item = json.loads(ret_item)
table.add_row(
item["name"],
item["data_name"],
item["prompt_name"],
item["prompt_type"],
item["preprocess"],
)
obj.ret_.append(item["name"])

obj.pret_ = table


def create_dataset(obj):
name = inquirer.text(
message="name:").execute()
name = inquirer.text(message="name:").execute()

list_file(obj, "data")
all_train_data = [item["name"] for item in obj.ret_]
Expand All @@ -51,26 +53,33 @@ def create_dataset(obj):
return

use_train = inquirer.select(
message="train data file:", choices=[separator.Separator(),
*all_train_data,
separator.Separator()]).execute()
message="train data file:",
choices=[separator.Separator(), *all_train_data, separator.Separator()],
).execute()
use_prompt = inquirer.select(
message="prompt template file:", choices=[separator.Separator(),
*all_prompt,
separator.Separator()]).execute()
message="prompt template file:",
choices=[separator.Separator(), *all_prompt, separator.Separator()],
).execute()
use_preprocess = inquirer.select(
message="data preprocessing:", choices=[separator.Separator(),
"default",
"shuffle",
"sort",
separator.Separator()]).execute()

ret = requests.post(url() + "/dataset", json={
"name": name,
"data_name": use_train,
"prompt_name": use_prompt,
"preprocess": use_preprocess
})
message="data preprocessing:",
choices=[
separator.Separator(),
"default",
"shuffle",
"sort",
separator.Separator(),
],
).execute()

ret = requests.post(
url() + "/dataset",
json={
"name": name,
"data_name": use_train,
"prompt_name": use_prompt,
"preprocess": use_preprocess,
},
)

print(json.loads(ret.text))

Expand All @@ -84,9 +93,9 @@ def showcase_dataset(obj):
return

use_dataset = inquirer.select(
message="dataset name:", choices=[separator.Separator(),
*all_dataset,
separator.Separator()]).execute()
message="dataset name:",
choices=[separator.Separator(), *all_dataset, separator.Separator()],
).execute()

ret = requests.get(url() + f"/showcase?name={use_dataset}")
ret = json.loads(ret.text)
Expand Down
7 changes: 4 additions & 3 deletions mlora/cli/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json

import requests
from rich import print
from rich.table import Table
from rich.box import ASCII
from rich.table import Table

from .setting import url

Expand All @@ -13,12 +14,12 @@ def help_dispatcher(_):

def do_dispatcher(*_):
ret = requests.get(url() + "/dispatcher")
ret = json.loads(ret.text)
ret_text = json.loads(ret.text)

table = Table(show_header=True, show_lines=True, box=ASCII)
table.add_column("Item", justify="center")
table.add_column("Value", justify="center")
for item, value in ret.items():
for item, value in ret_text.items():
table.add_row(item, str(value))

print(table)
Loading

0 comments on commit a9bb7a3

Please sign in to comment.