Skip to content

Commit

Permalink
Fix: modelloader handling of model_kwargs load_in*bit (#1999)
Browse files Browse the repository at this point in the history
* fix: load_in_*bit not properly read

* fix: load_*bit check

* fix: typo

* refactor: load * bit handling

* feat: add test dpo lora multi-gpu

* fix: turn off sample packing for dpo

* fix: missing warmup_steps

* fix: test to load in 8bit for lora

* skip 8bit lora on h100, add 4bit lora on h100 to multi gpu tests

* chore: reduce max_steps

---------

Co-authored-by: Wing Lian <[email protected]>
  • Loading branch information
NanoCode012 and winglian authored Oct 30, 2024
1 parent 74db2a1 commit 5c7e891
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 42 deletions.
33 changes: 7 additions & 26 deletions src/axolotl/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,9 +640,7 @@ def set_quantization_config(self) -> None:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
elif self.cfg.adapter == "qlora" and (
"load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
):
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
bnb_config = {
"load_in_4bit": True,
"llm_int8_threshold": 6.0,
Expand All @@ -665,9 +663,7 @@ def set_quantization_config(self) -> None:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)
elif self.cfg.adapter == "lora" and (
"load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
):
elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]:
bnb_config = {
"load_in_8bit": True,
}
Expand All @@ -680,10 +676,8 @@ def set_quantization_config(self) -> None:

# no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
if "load_in_8bit" in self.model_kwargs:
del self.model_kwargs["load_in_8bit"]
if "load_in_4bit" in self.model_kwargs:
del self.model_kwargs["load_in_4bit"]
self.model_kwargs.pop("load_in_8bit", None)
self.model_kwargs.pop("load_in_4bit", None)

def set_attention_config(self) -> None:
"""
Expand Down Expand Up @@ -968,17 +962,10 @@ def prepare_model(self, qlora_fsdp) -> None:
if is_deepspeed_zero3_enabled():
skip_prepare_model_for_kbit_training = True

is_load_in_8bit = (
"load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
)
is_load_in_4bit = (
"load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
)

if (
not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"]
and (is_load_in_8bit or is_load_in_4bit)
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
self.model = prepare_model_for_kbit_training(
Expand Down Expand Up @@ -1116,16 +1103,10 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
# ---------------------------------------------------------
# put model to accelerator
# ---------------------------------------------------------
is_load_in_8bit = (
"load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
)
is_load_in_4bit = (
"load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
)
if (
self.cfg.ddp
and not is_load_in_8bit
and not (self.cfg.rl and is_load_in_4bit)
and not self.cfg.load_in_8bit
and not (self.cfg.rl and self.cfg.load_in_4bit)
and not skip_move_to_device
):
# TODO revaldate this conditional
Expand Down
156 changes: 148 additions & 8 deletions tests/e2e/multigpu/test_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir
from ..utils import is_hopper, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
Expand Down Expand Up @@ -59,7 +59,7 @@ def test_lora_ddp(self, temp_dir):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
Expand Down Expand Up @@ -116,7 +116,7 @@ def test_lora_ddp_packed(self, temp_dir):
},
],
"num_epochs": 1,
"max_steps": 50,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
Expand Down Expand Up @@ -144,6 +144,146 @@ def test_lora_ddp_packed(self, temp_dir):
]
)

@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
@with_temp_dir
def test_dpo_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 2048,
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"rl": "dpo",
"chat_template": "llama3",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
},
],
"num_epochs": 1,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
}
)

# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)

@with_temp_dir
def test_dpo_qlora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM-135M",
"sequence_len": 2048,
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"rl": "dpo",
"chat_template": "chatml",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
},
],
"num_epochs": 1,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
}
)

# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))

execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)

@with_temp_dir
def test_fsdp(self, temp_dir):
# pylint: disable=duplicate-code
Expand All @@ -165,7 +305,7 @@ def test_fsdp(self, temp_dir):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
Expand Down Expand Up @@ -231,7 +371,7 @@ def test_fsdp_packed(self, temp_dir):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
Expand Down Expand Up @@ -307,7 +447,7 @@ def test_fsdp_qlora_prequant_packed(self, temp_dir):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
Expand Down Expand Up @@ -373,7 +513,7 @@ def test_ds_zero3_packed(self, temp_dir):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
Expand Down Expand Up @@ -432,7 +572,7 @@ def test_ds_zero3_qlora_packed(self, temp_dir):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/multigpu/test_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_qlora_fsdp_dpo(self, temp_dir):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"warmup_steps": 20,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/patched/test_4d_multipack_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import require_torch_2_1_1, with_temp_dir
from ..utils import require_torch_2_3_1, with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
Expand All @@ -24,7 +24,7 @@ class Test4dMultipackLlama(unittest.TestCase):
Test case for Llama models using 4d attention with multipack
"""

@require_torch_2_1_1
@require_torch_2_3_1
@with_temp_dir
def test_sdp_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code
Expand Down
17 changes: 12 additions & 5 deletions tests/e2e/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from importlib.metadata import version
from pathlib import Path

import torch


def with_temp_dir(test_func):
@wraps(test_func)
Expand All @@ -35,13 +37,18 @@ def most_recent_subdir(path):
return subdir


def require_torch_2_1_1(test_case):
def require_torch_2_3_1(test_case):
"""
Decorator marking a test that requires torch >= 2.1.1
Decorator marking a test that requires torch >= 2.3.1
"""

def is_min_2_1_1():
def is_min_2_3_1():
torch_version = version("torch")
return torch_version >= "2.1.1"
return torch_version >= "2.3.1"

return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)


return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)
def is_hopper():
compute_capability = torch.cuda.get_device_capability()
return compute_capability == (9, 0)

0 comments on commit 5c7e891

Please sign in to comment.