-
-
Notifications
You must be signed in to change notification settings - Fork 851
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add tests for merging lora and validating the dtype #1512
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,13 +1,16 @@ | ||||||||||||||
""" | ||||||||||||||
E2E tests for lora llama | ||||||||||||||
""" | ||||||||||||||
|
||||||||||||||
import json | ||||||||||||||
import logging | ||||||||||||||
import os | ||||||||||||||
import unittest | ||||||||||||||
from pathlib import Path | ||||||||||||||
|
||||||||||||||
from axolotl.cli import load_datasets | ||||||||||||||
from transformers.utils import is_torch_bf16_gpu_available | ||||||||||||||
|
||||||||||||||
from axolotl.cli import do_merge_lora, load_datasets | ||||||||||||||
from axolotl.cli.merge_lora import modify_cfg_for_merge | ||||||||||||||
from axolotl.common.cli import TrainerCliArgs | ||||||||||||||
from axolotl.train import train | ||||||||||||||
from axolotl.utils.config import normalize_config | ||||||||||||||
|
@@ -39,11 +42,6 @@ def test_lora(self, temp_dir): | |||||||||||||
"lora_dropout": 0.05, | ||||||||||||||
"lora_target_linear": True, | ||||||||||||||
"val_set_size": 0.1, | ||||||||||||||
"special_tokens": { | ||||||||||||||
"unk_token": "<unk>", | ||||||||||||||
"bos_token": "<s>", | ||||||||||||||
"eos_token": "</s>", | ||||||||||||||
}, | ||||||||||||||
"datasets": [ | ||||||||||||||
{ | ||||||||||||||
"path": "mhenrichsen/alpaca_2k_test", | ||||||||||||||
|
@@ -57,6 +55,7 @@ def test_lora(self, temp_dir): | |||||||||||||
"learning_rate": 0.00001, | ||||||||||||||
"optimizer": "adamw_torch", | ||||||||||||||
"lr_scheduler": "cosine", | ||||||||||||||
"max_steps": 10, | ||||||||||||||
} | ||||||||||||||
) | ||||||||||||||
normalize_config(cfg) | ||||||||||||||
|
@@ -65,3 +64,67 @@ def test_lora(self, temp_dir): | |||||||||||||
|
||||||||||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) | ||||||||||||||
assert (Path(temp_dir) / "adapter_model.bin").exists() | ||||||||||||||
|
||||||||||||||
@with_temp_dir | ||||||||||||||
def test_lora_merge(self, temp_dir): | ||||||||||||||
# pylint: disable=duplicate-code | ||||||||||||||
cfg = DictDefault( | ||||||||||||||
{ | ||||||||||||||
"base_model": "JackFram/llama-68m", | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, sometimes, this issue can occur for different model types. For ex, previous llama merge was fine, but mistral was not. Do we need to test this for other arch? |
||||||||||||||
"tokenizer_type": "LlamaTokenizer", | ||||||||||||||
"sequence_len": 1024, | ||||||||||||||
"load_in_8bit": True, | ||||||||||||||
"adapter": "lora", | ||||||||||||||
"lora_r": 32, | ||||||||||||||
"lora_alpha": 64, | ||||||||||||||
"lora_dropout": 0.05, | ||||||||||||||
"lora_target_linear": True, | ||||||||||||||
"val_set_size": 0.1, | ||||||||||||||
"datasets": [ | ||||||||||||||
{ | ||||||||||||||
"path": "mhenrichsen/alpaca_2k_test", | ||||||||||||||
"type": "alpaca", | ||||||||||||||
}, | ||||||||||||||
], | ||||||||||||||
"num_epochs": 2, | ||||||||||||||
"micro_batch_size": 8, | ||||||||||||||
"gradient_accumulation_steps": 1, | ||||||||||||||
"output_dir": temp_dir, | ||||||||||||||
"learning_rate": 0.00001, | ||||||||||||||
"optimizer": "adamw_torch", | ||||||||||||||
"lr_scheduler": "cosine", | ||||||||||||||
"max_steps": 10, | ||||||||||||||
"bf16": "auto", | ||||||||||||||
} | ||||||||||||||
) | ||||||||||||||
normalize_config(cfg) | ||||||||||||||
cli_args = TrainerCliArgs() | ||||||||||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) | ||||||||||||||
|
||||||||||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think you need to train a model, maybe a tiny adapter can be uploaded to HF which we use for merge? |
||||||||||||||
assert (Path(temp_dir) / "adapter_model.bin").exists() | ||||||||||||||
|
||||||||||||||
cfg.lora_model_dir = cfg.output_dir | ||||||||||||||
cfg.load_in_4bit = False | ||||||||||||||
cfg.load_in_8bit = False | ||||||||||||||
cfg.flash_attention = False | ||||||||||||||
cfg.deepspeed = None | ||||||||||||||
cfg.fsdp = None | ||||||||||||||
Comment on lines
+107
to
+112
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can be excluded as the modify_cfg_for_merge should've set it?
Suggested change
|
||||||||||||||
|
||||||||||||||
cfg = modify_cfg_for_merge(cfg) | ||||||||||||||
cfg.merge_lora = True | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's move this setting inside the modify_cfg function as well. |
||||||||||||||
|
||||||||||||||
cli_args = TrainerCliArgs(merge_lora=True) | ||||||||||||||
|
||||||||||||||
do_merge_lora(cfg=cfg, cli_args=cli_args) | ||||||||||||||
assert (Path(temp_dir) / "merged/pytorch_model.bin").exists() | ||||||||||||||
|
||||||||||||||
with open( | ||||||||||||||
Path(temp_dir) / "merged/config.json", "r", encoding="utf-8" | ||||||||||||||
) as f_handle: | ||||||||||||||
config = f_handle.read() | ||||||||||||||
config = json.loads(config) | ||||||||||||||
if is_torch_bf16_gpu_available(): | ||||||||||||||
assert config["torch_dtype"] == "bfloat16" | ||||||||||||||
else: | ||||||||||||||
assert config["torch_dtype"] == "float16" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the above section already sets these properties, is it necessary to set it again below?