diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 3130f86e42..fba48486c6 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -88,6 +88,16 @@ class LoraArguments:
             )
         },
     )
+    lora_parameters: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Name(s) of nn.Parameter to apply LoRA to directly (passed to peft's `target_parameters`). "
+                "Use commas to separate multiple parameters. Only effective when `lora_target` is `all`. "
+                "Useful for MoE models with expert weights stored as parameters."
+            )
+        },
+    )
     loraplus_lr_ratio: Optional[float] = field(
         default=None,
         metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
@@ -524,6 +534,7 @@ def split_arg(arg):
         self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
         self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
         self.lora_target: list[str] = split_arg(self.lora_target)
+        self.lora_parameters: Optional[list[str]] = split_arg(self.lora_parameters)
         self.oft_target: list[str] = split_arg(self.oft_target)
         self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
         self.galore_target: list[str] = split_arg(self.galore_target)
diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py
index d9522d39dd..10c4f9707f 100644
--- a/src/llamafactory/model/adapter.py
+++ b/src/llamafactory/model/adapter.py
@@ -198,8 +198,14 @@ def _setup_lora_tuning(
         logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
 
     if is_trainable and adapter_to_resume is None:  # create new lora weights while training
+        target_modules = []
+        target_parameters = []
         if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
-            target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
+            if finetuning_args.lora_parameters:  # adapt the specified nn.Parameters directly
+                logger.info_rank0("Using specified LoRA parameters: {}".format(finetuning_args.lora_parameters))
+                target_parameters = finetuning_args.lora_parameters
+            else:
+                target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
         else:
             target_modules = finetuning_args.lora_target
 
@@ -235,6 +241,7 @@ def _setup_lora_tuning(
             "use_rslora": finetuning_args.use_rslora,
             "use_dora": finetuning_args.use_dora,
             "modules_to_save": finetuning_args.additional_target,
+            "target_parameters": target_parameters,
         }
     elif finetuning_args.finetuning_type == "oft":
         peft_kwargs = {
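The branch added above only takes effect when `lora_target` is `all`; everything in `peft_kwargs` is then expanded into peft's `LoraConfig` as before. For reference, a minimal standalone sketch of the resulting configuration, assuming a peft release that supports `target_parameters` in `LoraConfig`; the expert parameter name is a placeholder, not tied to any specific model:

    from peft import LoraConfig, TaskType, get_peft_model

    # roughly what _setup_lora_tuning builds for lora_target=all plus
    # lora_parameters=experts.gate_up_proj (placeholder nn.Parameter name)
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=8,
        lora_alpha=16,
        lora_dropout=0.0,
        target_modules=[],                           # left empty, as in the diff above
        target_parameters=["experts.gate_up_proj"],  # matched against nn.Parameter names
    )
    # model = get_peft_model(model, lora_config)     # applied the same way as module-targeted LoRA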
diff --git a/src/llamafactory/train/test_utils.py b/src/llamafactory/train/test_utils.py
index 6e4c4ffc28..94cc19b27b 100644
--- a/src/llamafactory/train/test_utils.py
+++ b/src/llamafactory/train/test_utils.py
@@ -44,10 +44,17 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k
 
 
-def check_lora_model(model: "LoraModel") -> tuple[set[str], set[str]]:
-    linear_modules, extra_modules = set(), set()
+def check_lora_model(model: "LoraModel") -> tuple[set[str], set[str], set[str]]:
+    linear_modules, linear_parameters, extra_modules = set(), set(), set()
     for name, param in model.named_parameters():
         if any(module in name for module in ["lora_A", "lora_B"]):
             linear_modules.add(name.split(".lora_", maxsplit=1)[0].split(".")[-1])
+            # also record "<module>.<parameter>" names (e.g. "q_proj.weight") for parameter-level checks
+            parts = name.split(".")
+            for i, part in enumerate(parts):
+                if "lora_" in part:
+                    short_name = parts[i - 1] + "." + parts[-1]
+                    linear_parameters.add(short_name)
+                    break
             assert param.requires_grad is True
             assert param.dtype == torch.float32
         elif "modules_to_save" in name:
@@ -58,8 +65,8 @@ def check_lora_model(model: "LoraModel") -> tuple[set[str], set[str]]:
             assert param.requires_grad is False
             assert param.dtype == torch.float16
 
-    return linear_modules, extra_modules
+    return linear_modules, linear_parameters, extra_modules
 
 
 def load_train_model(add_valuehead: bool = False, **kwargs) -> "PreTrainedModel":
     model_args, _, _, finetuning_args, _ = get_train_args(kwargs)
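To make the new bookkeeping concrete, the loop added to `check_lora_model` reduces a fully qualified PEFT parameter name to a `<module>.<parameter>` pair. A short standalone walk-through on a representative name (the name is illustrative, not captured from a real run):

    name = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"
    parts = name.split(".")
    for i, part in enumerate(parts):
        if "lora_" in part:
            # module that owns the adapter plus the trailing parameter name
            print(parts[i - 1] + "." + parts[-1])  # prints "q_proj.weight"
            break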
diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py
index 8b7aa6e946..fc62add0b2 100644
--- a/src/llamafactory/webui/components/train.py
+++ b/src/llamafactory/webui/components/train.py
@@ -179,6 +179,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
             use_pissa = gr.Checkbox()
             lora_target = gr.Textbox(scale=2)
             additional_target = gr.Textbox(scale=2)
+            lora_parameters = gr.Textbox(scale=2)
 
     input_elems.update(
         {
@@ -192,6 +193,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
             use_pissa,
             lora_target,
             additional_target,
+            lora_parameters,
         }
     )
     elem_dict.update(
@@ -207,6 +209,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
             use_pissa=use_pissa,
             lora_target=lora_target,
             additional_target=additional_target,
+            lora_parameters=lora_parameters,
         )
     )
 
diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py
index 7051b30e80..73f377120b 100644
--- a/src/llamafactory/webui/locales.py
+++ b/src/llamafactory/webui/locales.py
@@ -1323,6 +1323,28 @@
             "info": "LoRA 層以外の学習可能なモジュールの名前。複数のモジュールを区切るにはカンマを使用します。",
         },
     },
+    "lora_parameters": {
+        "en": {
+            "label": "LoRA parameters (optional)",
+            "info": "Name(s) of parameters to apply LoRA. Use commas to separate multiple parameters.",
+        },
+        "ru": {
+            "label": "Параметры LoRA (необязательно)",
+            "info": "Имя(ена) параметров для применения LoRA. Используйте запятые для разделения нескольких параметров.",
+        },
+        "zh": {
+            "label": "LoRA 参数(可选)",
+            "info": "要应用 LoRA 的参数名称。使用逗号分隔多个参数。",
+        },
+        "ko": {
+            "label": "LoRA 매개변수 (선택 사항)",
+            "info": "LoRA를 적용할 매개변수의 이름입니다. 여러 매개변수를 구분하려면 쉼표를 사용하십시오.",
+        },
+        "ja": {
+            "label": "LoRA パラメータ (オプション)",
+            "info": "LoRA を適用するパラメータの名前。複数のパラメータを区切るにはカンマを使用します。",
+        },
+    },
     "rlhf_tab": {
         "en": {
             "label": "RLHF configurations",
diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py
index 0a6fc7c9aa..0d48acdd23 100644
--- a/src/llamafactory/webui/runner.py
+++ b/src/llamafactory/webui/runner.py
@@ -212,6 +212,7 @@ def _parse_train_args(self, data: dict["Component", Any]) -> dict[str, Any]:
             args["pissa_convert"] = get("train.use_pissa")
             args["lora_target"] = get("train.lora_target") or "all"
             args["additional_target"] = get("train.additional_target") or None
+            args["lora_parameters"] = get("train.lora_parameters") or None
 
         if args["use_llama_pro"]:
             args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
diff --git a/tests/model/test_lora.py b/tests/model/test_lora.py
index 3d394c33dc..59c5f88ceb 100644
--- a/tests/model/test_lora.py
+++ b/tests/model/test_lora.py
@@ -63,19 +63,19 @@ def fix_valuehead_cpu_loading():
 
 def test_lora_train_qv_modules():
     model = load_train_model(lora_target="q_proj,v_proj", **TRAIN_ARGS)
-    linear_modules, _ = check_lora_model(model)
+    linear_modules, _, _ = check_lora_model(model)
     assert linear_modules == {"q_proj", "v_proj"}
 
 
 def test_lora_train_all_modules():
     model = load_train_model(lora_target="all", **TRAIN_ARGS)
-    linear_modules, _ = check_lora_model(model)
+    linear_modules, _, _ = check_lora_model(model)
     assert linear_modules == {"q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"}
 
 
 def test_lora_train_extra_modules():
     model = load_train_model(additional_target="embed_tokens,lm_head", **TRAIN_ARGS)
-    _, extra_modules = check_lora_model(model)
+    _, _, extra_modules = check_lora_model(model)
     assert extra_modules == {"embed_tokens", "lm_head"}
 
 
@@ -91,6 +91,19 @@ def test_lora_train_new_adapters():
     compare_model(
         model, ref_model, diff_keys=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"]
     )
+
+
+def test_lora_parameters():
+    model = load_train_model(lora_parameters="q_proj.weight, k_proj.weight", **TRAIN_ARGS)
+    _, injected_parameters, _ = check_lora_model(model)
+    assert injected_parameters == {"q_proj.weight", "k_proj.weight"}
+
+
+def test_lora_target_and_parameters_conflicts():
+    model = load_train_model(lora_parameters="q_proj.weight", lora_target="q_proj,v_proj", **TRAIN_ARGS)
+    linear_modules, injected_parameters, _ = check_lora_model(model)
+    assert injected_parameters == {"q_proj.weight", "v_proj.weight"}
+    assert linear_modules == {"q_proj", "v_proj"}
 
 
 @pytest.mark.usefixtures("fix_valuehead_cpu_loading")
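Putting the pieces together, a hypothetical end-to-end invocation of the new option; the argument names follow this diff, while the model path and parameter names are placeholders that have not been validated against a real MoE checkpoint:

    from llamafactory.train.tuner import run_exp

    run_exp(dict(
        stage="sft",
        do_train=True,
        model_name_or_path="path/to/moe-model",  # placeholder
        dataset="identity",
        template="default",
        finetuning_type="lora",
        lora_rank=8,
        lora_target="all",  # must stay "all" for lora_parameters to take effect
        lora_parameters="experts.gate_up_proj,experts.down_proj",  # placeholder nn.Parameter names
        output_dir="saves/lora-parameters-demo",
        per_device_train_batch_size=1,
        max_steps=10,
    ))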