feat: support str in target_modules for LoraConfig (#39)
* build: set dependency to peft>=0.8.0, which enables --target_modules=all-linear

Signed-off-by: Vassilis Vassiliadis <vassilis.vassiliadis@ibm.com>

* feat: support str for --target_modules (includes support for --target_modules="all-linear")

Supports setting LoraConfig(target_modules=None) when using the train() method programmatically (see the sketch after the changed-files summary below)

Signed-off-by: Vassilis Vassiliadis <vassilis.vassiliadis@ibm.com>

* fix: target_modules must be a List[] for HfArgumentParser to parse more than one command-line value

For more information, see https://github.com/huggingface/transformers/blob/2749e479f30ab13235b0b9b4a6bbcf4c3b29a081/src/transformers/hf_argparser.py#L206-L208

Signed-off-by: Vassilis Vassiliadis <vassilis.vassiliadis@ibm.com>

* refactor: move LoraConfig pre-processing from main() into get_hf_peft_config()

Signed-off-by: Vassilis Vassiliadis <vassilis.vassiliadis@ibm.com>

* docs: update docstring of target_modules to reflect new code

Signed-off-by: Vassilis Vassiliadis <vassilis.vassiliadis@ibm.com>

---------

Signed-off-by: Vassilis Vassiliadis <vassilis.vassiliadis@ibm.com>
VassilisVassiliadis authored Feb 13, 2024
1 parent 905248c commit 5b895c4
Showing 4 changed files with 11 additions and 4 deletions.
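
For the programmatic path mentioned in the commit message, a minimal sketch (hedged: only LoraConfig's import path is taken from this repository; the train() entry point and its remaining arguments are assumptions and are left commented out):

# Hedged sketch, not code from this commit: passing target_modules=None lets
# peft fall back to its built-in per-architecture defaults when the model is wrapped.
from tuning.config.peft_config import LoraConfig

lora_config = LoraConfig(target_modules=None)  # also accepts ["all-linear"] or ["q_proj", "v_proj"]
# train(..., lora_config)  # the repository's train() entry point; exact signature not shown in this commit
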
requirements.txt (2 changes: 1 addition & 1 deletion)

@@ -8,7 +8,7 @@ tokenizers>=0.13.3
 tqdm
 trl
 ninja
-peft
+peft>=0.8.0
 datasets>=2.15.0
 flash-attn
 fire

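The raised version floor matters because peft>=0.8.0 understands the special target_modules value "all-linear"; a hedged sketch against the upstream peft API (not code from this commit):

# Hedged sketch: since peft 0.8.0, the string "all-linear" targets every
# linear (and Conv1D) module except the output layer.
from peft import LoraConfig as PeftLoraConfig

peft_lora = PeftLoraConfig(r=8, lora_alpha=32, lora_dropout=0.05, target_modules="all-linear")
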
tuning/config/peft_config.py (6 changes: 5 additions & 1 deletion)

@@ -5,7 +5,11 @@
 class LoraConfig:
     r: int = 8
     lora_alpha: int = 32
-    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
+    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"], metadata={
+        "help": "The names of the modules to apply LORA to. LORA selects modules which either completely match or "
+                "end with one of the strings. If the value is [\"all-linear\"], then LORA selects all linear and Conv1D "
+                "modules except for the output layer."
+    })
     bias = "none"
     lora_dropout: float = 0.05
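The field stays a List[str] rather than a plain str because of how HfArgumentParser builds its CLI; a standalone sketch (not code from this repository) of that behaviour:

# Standalone sketch: a List[str] field makes HfArgumentParser emit an
# nargs="+" argument, so one --target_modules flag can collect several values.
from dataclasses import dataclass, field
from typing import List

from transformers import HfArgumentParser


@dataclass
class LoraArgs:
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])


parser = HfArgumentParser(LoraArgs)
(args,) = parser.parse_args_into_dataclasses(
    args=["--target_modules", "q_proj", "v_proj", "k_proj"]
)
print(args.target_modules)  # ['q_proj', 'v_proj', 'k_proj']
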

tuning/sft_trainer.py (2 changes: 1 addition & 1 deletion)

@@ -189,7 +189,7 @@ def main(**kwargs):
                                             peft_config.PromptTuningConfig))
     parser.add_argument('--peft_method', type=str.lower, choices=['pt', 'lora', None, 'none'], default="pt")
     model_args, data_args, training_args, lora_config, prompt_tuning_config, peft_method, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
-    if peft_method.peft_method =="lora":
+    if peft_method.peft_method == "lora":
         tune_config=lora_config
     elif peft_method.peft_method =="pt":
         tune_config=prompt_tuning_config

tuning/utils/config_utils.py (5 changes: 4 additions & 1 deletion)

@@ -51,7 +51,10 @@ def get_hf_peft_config(task_type, tuning_config):
     Return: HF PEFT config or None
     """
     if isinstance(tuning_config, peft_config.LoraConfig):
-        hf_peft_config = LoraConfig(task_type=task_type, **asdict(tuning_config))
+        lora_config = asdict(tuning_config)
+        if lora_config["target_modules"] == ["all-linear"]:
+            lora_config["target_modules"] = "all-linear"
+        hf_peft_config = LoraConfig(task_type=task_type, **lora_config)
     elif isinstance(tuning_config, peft_config.PromptTuningConfig):
         hf_peft_config = PromptTuningConfig(task_type=task_type, **asdict(tuning_config))
     else:

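Putting the pieces together, a usage sketch of the helper above (import paths assumed from this repository's layout; not code from this commit):

# Sketch of the behaviour added above: the CLI-friendly ["all-linear"] list is
# collapsed to the plain "all-linear" string that peft>=0.8.0 expects; any
# other value, including None, is passed through to peft unchanged.
from tuning.config import peft_config
from tuning.utils.config_utils import get_hf_peft_config

tune_config = peft_config.LoraConfig(target_modules=["all-linear"])
hf_config = get_hf_peft_config(task_type="CAUSAL_LM", tuning_config=tune_config)
print(hf_config.target_modules)  # all-linear
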
