Skip to content

Commit

Permalink
add llava to llamaboard
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Apr 25, 2024
1 parent e83e2fa commit cd3a960
Show file tree
Hide file tree
Showing 13 changed files with 66 additions and 15 deletions.
1 change: 1 addition & 0 deletions examples/inference/web_demo.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/bin/bash
# add `--visual_inputs True` to load MLLM

CUDA_VISIBLE_DEVICES=0 python ../../src/web_demo.py \
--model_name_or_path meta-llama/Llama-2-7b-hf \
Expand Down
15 changes: 15 additions & 0 deletions src/llmtuner/extras/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

METHODS = ["full", "freeze", "lora"]

MLLM_LIST = ["LLaVA1.5"]

MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]

PEFT_METHODS = ["lora"]
Expand Down Expand Up @@ -566,6 +568,19 @@ def register_model_group(
)


register_model_group(
models={
"LLaVA1.5-7B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
},
"LLaVA1.5-13B-Chat": {
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
},
},
template="vicuna",
)


register_model_group(
models={
"Mistral-7B-v0.1": {
Expand Down
1 change: 1 addition & 0 deletions src/llmtuner/webui/chatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def load_model(self, data) -> Generator[str, None, None]:
template=get("top.template"),
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
infer_backend=get("infer.infer_backend"),
)
Expand Down
5 changes: 5 additions & 0 deletions src/llmtuner/webui/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
DATA_CONFIG,
DEFAULT_MODULE,
DEFAULT_TEMPLATE,
MLLM_LIST,
PEFT_METHODS,
STAGES_USE_PAIR_DATA,
SUPPORTED_MODELS,
Expand Down Expand Up @@ -105,6 +106,10 @@ def get_template(model_name: str) -> str:
return "default"


def get_visual(model_name: str) -> bool:
return get_prefix(model_name) in MLLM_LIST


def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
if finetuning_type not in PEFT_METHODS:
return gr.Dropdown(value=[], choices=[], interactive=False)
Expand Down
7 changes: 4 additions & 3 deletions src/llmtuner/webui/components/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

def create_chat_box(
engine: "Engine", visible: bool = False
) -> Tuple["gr.Column", "Component", "Component", Dict[str, "Component"]]:
) -> Tuple["Component", "Component", Dict[str, "Component"]]:
with gr.Column(visible=visible) as chat_box:
chatbot = gr.Chatbot(show_copy_button=True)
messages = gr.State([])
Expand All @@ -29,7 +29,7 @@ def create_chat_box(
system = gr.Textbox(show_label=False)
tools = gr.Textbox(show_label=False, lines=4)

with gr.Column():
with gr.Column() as image_box:
image = gr.Image(type="numpy")

query = gr.Textbox(show_label=False, lines=8)
Expand All @@ -55,13 +55,14 @@ def create_chat_box(
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])

return (
chat_box,
chatbot,
messages,
dict(
chat_box=chat_box,
role=role,
system=system,
tools=tools,
image_box=image_box,
image=image,
query=query,
submit_btn=submit_btn,
Expand Down
3 changes: 3 additions & 0 deletions src/llmtuner/webui/components/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def save_model(
adapter_path: List[str],
finetuning_type: str,
template: str,
visual_inputs: bool,
export_size: int,
export_quantization_bit: int,
export_quantization_dataset: str,
Expand Down Expand Up @@ -66,6 +67,7 @@ def save_model(
adapter_name_or_path=adapter_name_or_path,
finetuning_type=finetuning_type,
template=template,
visual_inputs=visual_inputs,
export_dir=export_dir,
export_hub_model_id=export_hub_model_id or None,
export_size=export_size,
Expand Down Expand Up @@ -105,6 +107,7 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
engine.manager.get_elem_by_id("top.adapter_path"),
engine.manager.get_elem_by_id("top.finetuning_type"),
engine.manager.get_elem_by_id("top.template"),
engine.manager.get_elem_by_id("top.visual_inputs"),
export_size,
export_quantization_bit,
export_quantization_dataset,
Expand Down
14 changes: 10 additions & 4 deletions src/llmtuner/webui/components/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,21 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
input_elems.update({infer_backend})
elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))

chat_box, chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
elem_dict.update(dict(chat_box=chat_box, **chat_elems))
chatbot, messages, chat_elems = create_chat_box(engine, visible=False)
elem_dict.update(chat_elems)

load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then(
lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box]
lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]]
)

unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then(
lambda: ([], []), outputs=[chatbot, messages]
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box])
).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]])

engine.manager.get_elem_by_id("top.visual_inputs").change(
lambda enabled: gr.Column(visible=enabled),
[engine.manager.get_elem_by_id("top.visual_inputs")],
[chat_elems["image_box"]],
)

return elem_dict
16 changes: 10 additions & 6 deletions src/llmtuner/webui/components/top.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ...data import templates
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ...extras.packages import is_gradio_available
from ..common import get_model_path, get_template, list_adapters, save_config
from ..common import get_model_path, get_template, get_visual, list_adapters, save_config
from ..utils import can_quantize


Expand All @@ -30,14 +30,17 @@ def create_top() -> Dict[str, "Component"]:

with gr.Accordion(open=False) as advanced_tab:
with gr.Row():
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
template = gr.Dropdown(choices=list(templates.keys()), value="default")
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none")
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=2)
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
visual_inputs = gr.Checkbox(scale=1)

model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
get_model_path, [model_name], [model_path], queue=False
).then(get_template, [model_name], [template], queue=False) # do not save config since the below line will save
).then(get_template, [model_name], [template], queue=False).then(
get_visual, [model_name], [visual_inputs], queue=False
) # do not save config since the below line will save

model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)

Expand All @@ -59,4 +62,5 @@ def create_top() -> Dict[str, "Component"]:
template=template,
rope_scaling=rope_scaling,
booster=booster,
visual_inputs=visual_inputs,
)
1 change: 1 addition & 0 deletions src/llmtuner/webui/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def resume(self):
init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
init_dict["train.config_path"] = {"value": "{}.json".format(get_time())}
init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
init_dict["infer.image_box"] = {"visible": False}

if user_config.get("last_model", None):
init_dict["top.model_name"] = {"value": user_config["last_model"]}
Expand Down
4 changes: 2 additions & 2 deletions src/llmtuner/webui/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def create_web_demo() -> gr.Blocks:
lang = gr.Dropdown(choices=["en", "zh"])
engine.manager.add_elems("top", dict(lang=lang))

chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
engine.manager.add_elems("infer", dict(chat_box=chat_box, **chat_elems))
_, _, chat_elems = create_chat_box(engine, visible=True)
engine.manager.add_elems("infer", chat_elems)

demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
Expand Down
11 changes: 11 additions & 0 deletions src/llmtuner/webui/locales.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,17 @@
"label": "加速方式",
},
},
"visual_inputs": {
"en": {
"label": "Visual inputs",
},
"ru": {
"label": "визуальные входы",
},
"zh": {
"label": "图像输入",
},
},
"training_stage": {
"en": {
"label": "Stage",
Expand Down
1 change: 1 addition & 0 deletions src/llmtuner/webui/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,5 @@ def get_base_elems(self) -> Set["Component"]:
self._id_to_elem["top.template"],
self._id_to_elem["top.rope_scaling"],
self._id_to_elem["top.booster"],
self._id_to_elem["top.visual_inputs"],
}
2 changes: 2 additions & 0 deletions src/llmtuner/webui/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
dataset_dir=get("train.dataset_dir"),
dataset=",".join(get("train.dataset")),
cutoff_len=get("train.cutoff_len"),
Expand Down Expand Up @@ -224,6 +225,7 @@ def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
use_unsloth=(get("top.booster") == "unsloth"),
visual_inputs=get("top.visual_inputs"),
dataset_dir=get("eval.dataset_dir"),
dataset=",".join(get("eval.dataset")),
cutoff_len=get("eval.cutoff_len"),
Expand Down

0 comments on commit cd3a960

Please sign in to comment.