From cd3a960f81579b4173ae79aa14d075e94651e8ce Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 26 Apr 2024 06:41:35 +0800
Subject: [PATCH] add llava to llamaboard

---
 examples/inference/web_demo.sh           |  1 +
 src/llmtuner/extras/constants.py         | 15 +++++++++++++++
 src/llmtuner/webui/chatter.py            |  1 +
 src/llmtuner/webui/common.py             |  5 +++++
 src/llmtuner/webui/components/chatbot.py |  7 ++++---
 src/llmtuner/webui/components/export.py  |  3 +++
 src/llmtuner/webui/components/infer.py   | 14 ++++++++++----
 src/llmtuner/webui/components/top.py     | 16 ++++++++++------
 src/llmtuner/webui/engine.py             |  1 +
 src/llmtuner/webui/interface.py          |  4 ++--
 src/llmtuner/webui/locales.py            | 11 +++++++++++
 src/llmtuner/webui/manager.py            |  1 +
 src/llmtuner/webui/runner.py             |  2 ++
 13 files changed, 66 insertions(+), 15 deletions(-)

diff --git a/examples/inference/web_demo.sh b/examples/inference/web_demo.sh
index 201be2b497..8d6ed09db7 100644
--- a/examples/inference/web_demo.sh
+++ b/examples/inference/web_demo.sh
@@ -1,4 +1,5 @@
 #!/bin/bash
+# add `--visual_inputs True` to load MLLM
 
 CUDA_VISIBLE_DEVICES=0 python ../../src/web_demo.py \
     --model_name_or_path meta-llama/Llama-2-7b-hf \
diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py
index 9f7d5c4684..269905300b 100644
--- a/src/llmtuner/extras/constants.py
+++ b/src/llmtuner/extras/constants.py
@@ -28,6 +28,8 @@
 
 METHODS = ["full", "freeze", "lora"]
 
+MLLM_LIST = ["LLaVA1.5"]
+
 MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]
 
 PEFT_METHODS = ["lora"]
@@ -566,6 +568,19 @@ def register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "LLaVA1.5-7B-Chat": {
+            DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
+        },
+        "LLaVA1.5-13B-Chat": {
+            DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
+        },
+    },
+    template="vicuna",
+)
+
+
 register_model_group(
     models={
         "Mistral-7B-v0.1": {
diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py
index 5aa8f56311..a92f6ef7ba 100644
--- a/src/llmtuner/webui/chatter.py
+++ b/src/llmtuner/webui/chatter.py
@@ -79,6 +79,7 @@ def load_model(self, data) -> Generator[str, None, None]:
             template=get("top.template"),
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
+            visual_inputs=get("top.visual_inputs"),
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             infer_backend=get("infer.infer_backend"),
         )
diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py
index 659c35c385..9af4c43917 100644
--- a/src/llmtuner/webui/common.py
+++ b/src/llmtuner/webui/common.py
@@ -9,6 +9,7 @@
     DATA_CONFIG,
     DEFAULT_MODULE,
     DEFAULT_TEMPLATE,
+    MLLM_LIST,
     PEFT_METHODS,
     STAGES_USE_PAIR_DATA,
     SUPPORTED_MODELS,
@@ -105,6 +106,10 @@ def get_template(model_name: str) -> str:
     return "default"
 
 
+def get_visual(model_name: str) -> bool:
+    return get_prefix(model_name) in MLLM_LIST
+
+
 def list_adapters(model_name: str, finetuning_type: str) -> "gr.Dropdown":
     if finetuning_type not in PEFT_METHODS:
         return gr.Dropdown(value=[], choices=[], interactive=False)
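Note on the new `get_visual` helper: it keys off the model-name prefix, so any family listed in `MLLM_LIST` auto-ticks the new checkbox for all of its registered sizes. A minimal, self-contained sketch of the lookup (the body of `get_prefix` is an assumption here; the patch only shows its call site):

    MLLM_LIST = ["LLaVA1.5"]

    def get_prefix(model_name: str) -> str:
        # assumed behavior: keep the family name before the first dash,
        # e.g. "LLaVA1.5-7B-Chat" -> "LLaVA1.5"
        return model_name.split("-")[0]

    def get_visual(model_name: str) -> bool:
        return get_prefix(model_name) in MLLM_LIST

    assert get_visual("LLaVA1.5-7B-Chat")
    assert not get_visual("Mistral-7B-v0.1")
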
"Component", Dict[str, "Component"]]: with gr.Column(visible=visible) as chat_box: chatbot = gr.Chatbot(show_copy_button=True) messages = gr.State([]) @@ -29,7 +29,7 @@ def create_chat_box( system = gr.Textbox(show_label=False) tools = gr.Textbox(show_label=False, lines=4) - with gr.Column(): + with gr.Column() as image_box: image = gr.Image(type="numpy") query = gr.Textbox(show_label=False, lines=8) @@ -55,13 +55,14 @@ def create_chat_box( clear_btn.click(lambda: ([], []), outputs=[chatbot, messages]) return ( - chat_box, chatbot, messages, dict( + chat_box=chat_box, role=role, system=system, tools=tools, + image_box=image_box, image=image, query=query, submit_btn=submit_btn, diff --git a/src/llmtuner/webui/components/export.py b/src/llmtuner/webui/components/export.py index ebccac25c9..4c2247366b 100644 --- a/src/llmtuner/webui/components/export.py +++ b/src/llmtuner/webui/components/export.py @@ -27,6 +27,7 @@ def save_model( adapter_path: List[str], finetuning_type: str, template: str, + visual_inputs: bool, export_size: int, export_quantization_bit: int, export_quantization_dataset: str, @@ -66,6 +67,7 @@ def save_model( adapter_name_or_path=adapter_name_or_path, finetuning_type=finetuning_type, template=template, + visual_inputs=visual_inputs, export_dir=export_dir, export_hub_model_id=export_hub_model_id or None, export_size=export_size, @@ -105,6 +107,7 @@ def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: engine.manager.get_elem_by_id("top.adapter_path"), engine.manager.get_elem_by_id("top.finetuning_type"), engine.manager.get_elem_by_id("top.template"), + engine.manager.get_elem_by_id("top.visual_inputs"), export_size, export_quantization_bit, export_quantization_dataset, diff --git a/src/llmtuner/webui/components/infer.py b/src/llmtuner/webui/components/infer.py index d565347e5d..970f4629c2 100644 --- a/src/llmtuner/webui/components/infer.py +++ b/src/llmtuner/webui/components/infer.py @@ -28,15 +28,21 @@ def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]: input_elems.update({infer_backend}) elem_dict.update(dict(infer_backend=infer_backend, load_btn=load_btn, unload_btn=unload_btn, info_box=info_box)) - chat_box, chatbot, messages, chat_elems = create_chat_box(engine, visible=False) - elem_dict.update(dict(chat_box=chat_box, **chat_elems)) + chatbot, messages, chat_elems = create_chat_box(engine, visible=False) + elem_dict.update(chat_elems) load_btn.click(engine.chatter.load_model, input_elems, [info_box]).then( - lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box] + lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]] ) unload_btn.click(engine.chatter.unload_model, input_elems, [info_box]).then( lambda: ([], []), outputs=[chatbot, messages] - ).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_box]) + ).then(lambda: gr.Column(visible=engine.chatter.loaded), outputs=[chat_elems["chat_box"]]) + + engine.manager.get_elem_by_id("top.visual_inputs").change( + lambda enabled: gr.Column(visible=enabled), + [engine.manager.get_elem_by_id("top.visual_inputs")], + [chat_elems["image_box"]], + ) return elem_dict diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py index c67d7cc57a..a75a4d62be 100644 --- a/src/llmtuner/webui/components/top.py +++ b/src/llmtuner/webui/components/top.py @@ -3,7 +3,7 @@ from ...data import templates from ...extras.constants import METHODS, SUPPORTED_MODELS from ...extras.packages import is_gradio_available -from ..common import 
diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py
index c67d7cc57a..a75a4d62be 100644
--- a/src/llmtuner/webui/components/top.py
+++ b/src/llmtuner/webui/components/top.py
@@ -3,7 +3,7 @@
 from ...data import templates
 from ...extras.constants import METHODS, SUPPORTED_MODELS
 from ...extras.packages import is_gradio_available
-from ..common import get_model_path, get_template, list_adapters, save_config
+from ..common import get_model_path, get_template, get_visual, list_adapters, save_config
 from ..utils import can_quantize
 
 
@@ -30,14 +30,17 @@ def create_top() -> Dict[str, "Component"]:
 
     with gr.Accordion(open=False) as advanced_tab:
         with gr.Row():
-            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none")
-            template = gr.Dropdown(choices=list(templates.keys()), value="default")
-            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
-            booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none")
+            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
+            template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=2)
+            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
+            booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
+            visual_inputs = gr.Checkbox(scale=1)
 
     model_name.change(list_adapters, [model_name, finetuning_type], [adapter_path], queue=False).then(
         get_model_path, [model_name], [model_path], queue=False
-    ).then(get_template, [model_name], [template], queue=False)  # do not save config since the below line will save
+    ).then(get_template, [model_name], [template], queue=False).then(
+        get_visual, [model_name], [visual_inputs], queue=False
+    )  # do not save config since the below line will save
 
     model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
 
@@ -59,4 +62,5 @@ def create_top() -> Dict[str, "Component"]:
         template=template,
         rope_scaling=rope_scaling,
         booster=booster,
+        visual_inputs=visual_inputs,
     )
diff --git a/src/llmtuner/webui/engine.py b/src/llmtuner/webui/engine.py
index b9ee61d2aa..cebac3b90a 100644
--- a/src/llmtuner/webui/engine.py
+++ b/src/llmtuner/webui/engine.py
@@ -43,6 +43,7 @@ def resume(self):
         init_dict["train.output_dir"] = {"value": "train_{}".format(get_time())}
         init_dict["train.config_path"] = {"value": "{}.json".format(get_time())}
         init_dict["eval.output_dir"] = {"value": "eval_{}".format(get_time())}
+        init_dict["infer.image_box"] = {"visible": False}
 
         if user_config.get("last_model", None):
             init_dict["top.model_name"] = {"value": user_config["last_model"]}
diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py
index 0359d082b9..abca16c556 100644
--- a/src/llmtuner/webui/interface.py
+++ b/src/llmtuner/webui/interface.py
@@ -58,8 +58,8 @@ def create_web_demo() -> gr.Blocks:
         lang = gr.Dropdown(choices=["en", "zh"])
         engine.manager.add_elems("top", dict(lang=lang))
 
-        chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
-        engine.manager.add_elems("infer", dict(chat_box=chat_box, **chat_elems))
+        _, _, chat_elems = create_chat_box(engine, visible=True)
+        engine.manager.add_elems("infer", chat_elems)
 
         demo.load(engine.resume, outputs=engine.manager.get_elem_list(), concurrency_limit=None)
         lang.change(engine.change_lang, [lang], engine.manager.get_elem_list(), queue=False)
diff --git a/src/llmtuner/webui/locales.py b/src/llmtuner/webui/locales.py
index 8e93efd6ce..d341c7b6e4 100644
--- a/src/llmtuner/webui/locales.py
+++ b/src/llmtuner/webui/locales.py
@@ -129,6 +129,17 @@
             "label": "加速方式",
         },
     },
+    "visual_inputs": {
+        "en": {
+            "label": "Visual inputs",
+        },
+        "ru": {
+            "label": "визуальные входы",
+        },
+        "zh": {
+            "label": "图像输入",
+        },
+    },
     "training_stage": {
         "en": {
             "label": "Stage",
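For reference, `LOCALES` maps element name → language → component kwargs, so the single `visual_inputs` entry above labels the checkbox in every supported language. A toy lookup illustrating the shape of the table (the `label_for` helper is hypothetical, for illustration only):

    LOCALES = {
        "visual_inputs": {
            "en": {"label": "Visual inputs"},
            "ru": {"label": "визуальные входы"},
            "zh": {"label": "图像输入"},
        },
    }

    def label_for(elem_name: str, lang: str) -> dict:
        # hypothetical helper: fall back to English when a locale is missing
        return LOCALES[elem_name].get(lang, LOCALES[elem_name]["en"])

    print(label_for("visual_inputs", "zh"))  # {'label': '图像输入'}
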
a67c0995b2..f65fa80466 100644
--- a/src/llmtuner/webui/manager.py
+++ b/src/llmtuner/webui/manager.py
@@ -60,4 +60,5 @@ def get_base_elems(self) -> Set["Component"]:
             self._id_to_elem["top.template"],
             self._id_to_elem["top.rope_scaling"],
             self._id_to_elem["top.booster"],
+            self._id_to_elem["top.visual_inputs"],
         }
diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py
index 77d5ea98cd..8054484f1f 100644
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -124,6 +124,7 @@ def _parse_train_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
+            visual_inputs=get("top.visual_inputs"),
             dataset_dir=get("train.dataset_dir"),
             dataset=",".join(get("train.dataset")),
             cutoff_len=get("train.cutoff_len"),
@@ -224,6 +225,7 @@ def _parse_eval_args(self, data: Dict["Component", Any]) -> Dict[str, Any]:
             rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
             flash_attn="fa2" if get("top.booster") == "flashattn2" else "auto",
             use_unsloth=(get("top.booster") == "unsloth"),
+            visual_inputs=get("top.visual_inputs"),
             dataset_dir=get("eval.dataset_dir"),
             dataset=",".join(get("eval.dataset")),
             cutoff_len=get("eval.cutoff_len"),
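End-to-end, the flag now flows from the `top.visual_inputs` checkbox into the load, export, train, and eval argument dicts. A hedged smoke test outside the UI (treat `ChatModel` accepting a plain args dict as an assumption about this commit's API; the argument names follow this patch):

    from llmtuner.chat import ChatModel

    chat_model = ChatModel(
        dict(
            model_name_or_path="llava-hf/llava-1.5-7b-hf",  # registered above as LLaVA1.5-7B-Chat
            template="vicuna",
            visual_inputs=True,  # the flag this patch threads through the webui
        )
    )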