swaps to use newer sample packing for mistral (#1773)
* swaps to use newer sample packing for mistral

* fix multipack patch test

* patch the common fa utils

* update for refactor of flash attn unpad

* remove unneeded attention mask drop for mistral

* bump transformers to main to pick up latest mistral fix for 12b and refactor of fa2

* update test
winglian authored Jul 23, 2024
1 parent 985819d commit 87455e7
Showing 7 changed files with 86 additions and 70 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.11.1
transformers==4.42.4
transformers @ git+https://github.com/huggingface/transformers.git@0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf
tokenizers==0.19.1
bitsandbytes==0.43.1
accelerate==0.32.0
10 changes: 10 additions & 0 deletions src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
@@ -2,6 +2,7 @@
# pylint: disable=duplicate-code

import logging
from functools import partial
from typing import List, Optional, Tuple, Union

import torch
@@ -45,6 +46,15 @@ def replace_mistral_attn_with_flash_attn(
    )


def patch_mistral_cross_entropy():
    from flash_attn.losses.cross_entropy import CrossEntropyLoss

    LOG.info("patching with flash_attn.losses.cross_entropy")
    transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial(
        CrossEntropyLoss, inplace_backward=True
    )


@torch.jit.script
def _make_sliding_window_causal_mask(
    bsz: int,
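For context: the new patch_mistral_cross_entropy rebinds the CrossEntropyLoss symbol that modeling_mistral resolves at call time, so every loss the Mistral forward pass constructs becomes flash-attn's fused kernel with inplace_backward=True already bound via functools.partial. A minimal sketch of that mechanism (FusedCrossEntropyLoss is a hypothetical stand-in for flash_attn.losses.cross_entropy.CrossEntropyLoss, not code from this commit):

# Illustrative sketch only: shows how functools.partial pre-binds a keyword so
# existing `CrossEntropyLoss()` call sites pick up the fused variant transparently.
from functools import partial

import torch.nn as nn


class FusedCrossEntropyLoss(nn.CrossEntropyLoss):
    """Hypothetical stand-in for flash_attn's fused cross-entropy loss."""

    def __init__(self, *args, inplace_backward: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self.inplace_backward = inplace_backward


# What the monkeypatch effectively does to the modeling module's namespace:
CrossEntropyLoss = partial(FusedCrossEntropyLoss, inplace_backward=True)

loss_fct = CrossEntropyLoss()  # behaves like FusedCrossEntropyLoss(inplace_backward=True)
assert loss_fct.inplace_backward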
32 changes: 23 additions & 9 deletions src/axolotl/monkeypatch/multipack.py
@@ -11,6 +11,7 @@

SUPPORTED_MULTIPACK_MODEL_TYPES = [
"llama",
"mistral",
"mixtral",
"qwen2",
"qwen2_moe",
@@ -25,16 +26,35 @@


def patch_for_multipack(model_type, model_name=None):
    if model_type == "gemmoe":
        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
    elif model_type == "deepseek_v2":
        patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
    elif hasattr(transformers, "modeling_flash_attention_utils"):
        transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
        if model_type == "mixtral" and is_deepspeed_zero3_enabled():
            patch_mixtral_moe_forward_zero3()
        return

    # retain for legacy
    if model_type == "mixtral":
        transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
        if is_deepspeed_zero3_enabled():
            patch_mixtral_moe_forward_zero3()
    elif model_type == "llama":
        transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
        if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
            transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
                get_unpad_data
            )
elif model_type == "mistral":
if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
@@ -63,12 +83,6 @@ def patch_for_multipack(model_type, model_name=None):
        transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    elif model_type == "gemmoe":
        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
    elif model_type == "jamba":
        patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
    elif model_type == "deepseek_v2":
        patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")


def patch_remote(model_name, config_name, modeling_name):
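The point of routing the supported architectures through transformers.modeling_flash_attention_utils is that newer transformers releases share a single _get_unpad_data helper across models, so replacing that one function (as this commit does) makes flash-attn respect packed-sample boundaries everywhere. A simplified sketch of the idea behind axolotl's get_unpad_data replacement, not its exact implementation: the attention mask carries a per-token sample id (1, 2, ... within a packed row, 0 for padding), and cu_seqlens is built from per-sample token counts so attention never crosses sample boundaries.

# Simplified sketch (assumed behavior, not axolotl's exact code) of an unpad
# helper that honors packed samples: the mask stores a sample id per token.
import torch
import torch.nn.functional as F


def get_unpad_data_packed(attention_mask: torch.Tensor):
    # per-(row, sample-id) token counts, dropping padding (id 0) and empty slots
    max_id = int(attention_mask.max().item())
    counts = torch.stack(
        [(attention_mask == i).sum(dim=-1) for i in range(1, max_id + 1)], dim=-1
    ).flatten()
    seqlens = counts[counts > 0].to(torch.int32)

    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens.max())
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch


# One row packing a 3-token and a 2-token sample plus one pad token:
mask = torch.tensor([[1, 1, 1, 2, 2, 0]])
_, cu_seqlens, max_len = get_unpad_data_packed(mask)
print(cu_seqlens.tolist())  # [0, 3, 5]
print(max_len)              # 3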
79 changes: 41 additions & 38 deletions src/axolotl/monkeypatch/unsloth_.py
@@ -99,48 +99,51 @@ def check_self_attn_is_patchable() -> bool:
    return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv


def integrate_cross_entropy_loss_patch():
    forward = get_forward_code()
    LlamaForCausalLM._original_forward = forward  # pylint: disable=protected-access
    forward, _ = detab_code(forward)
    assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
    if model_type == "llama":
        forward = get_forward_code()
        LlamaForCausalLM._original_forward = forward  # pylint: disable=protected-access
        forward, _ = detab_code(forward)
        assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"

    forward = forward.replace(
        "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
    )
    forward = forward.replace(
        "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
        "",
    )
    forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
    forward = forward.replace(
        "def forward(",
        "def fast_cross_entropy_loss_forward(",
        1,
    )

        forward = forward.replace(
            "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
        )
        forward = forward.replace(
            "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
            "",
        )
        forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
        forward = forward.replace(
            "def forward(",
            "def fast_cross_entropy_loss_forward(",
            1,
        )
    # load imports necessary
    import transformers.models.llama.modeling_llama

        # load imports necessary
        import transformers.models.llama.modeling_llama
    items_to_import = []
    for item in dir(transformers.models.llama.modeling_llama):
        if item in forward:
            items_to_import.append(item)

        items_to_import = []
        for item in dir(transformers.models.llama.modeling_llama):
            if item in forward:
                items_to_import.append(item)

    exec(  # pylint: disable=exec-used # nosec B102
        "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
        globals(),
    )
        exec(  # pylint: disable=exec-used # nosec B102
            "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
            globals(),
        )

    exec(  # pylint: disable=exec-used # nosec B102
        "from transformers.models.llama.modeling_llama import ("
        + ", ".join(x for x in items_to_import)
        + ")",
        globals(),
    )
    exec(forward, globals())  # pylint: disable=exec-used # nosec B102
    LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
    LlamaForCausalLM.forward = fast_cross_entropy_loss_forward  # pylint: disable=undefined-variable # noqa: F821
        exec(  # pylint: disable=exec-used # nosec B102
            "from transformers.models.llama.modeling_llama import ("
            + ", ".join(x for x in items_to_import)
            + ")",
            globals(),
        )
        exec(forward, globals())  # pylint: disable=exec-used # nosec B102
        LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
        LlamaForCausalLM.forward = fast_cross_entropy_loss_forward  # pylint: disable=undefined-variable # noqa: F821
    else:
        raise ValueError("Unsupported model type")


def detab_code(code: str) -> Tuple[str, str]:
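With the model_type parameter threaded through, callers of the unsloth cross-entropy patch now state which architecture they expect, and anything unsupported fails fast instead of silently patching LlamaForCausalLM. A short usage sketch, matching how models.py wires it up below (the mistral call only illustrates the failure path and is not something this commit does):

from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch

# what load_model() now does for llama-derived models
integrate_cross_entropy_loss_patch(model_type="llama")

# any other architecture is rejected explicitly
try:
    integrate_cross_entropy_loss_patch(model_type="mistral")
except ValueError:
    print("unsloth CE patch only supports llama here")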
21 changes: 5 additions & 16 deletions src/axolotl/utils/models.py
@@ -367,7 +367,7 @@ def load_model(
                integrate_cross_entropy_loss_patch,
            )

            integrate_cross_entropy_loss_patch()
            integrate_cross_entropy_loss_patch(model_type="llama")
        if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
            from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -424,31 +424,20 @@ def load_model(
    if cfg.unsloth_cross_entropy_loss:
        from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch

        integrate_cross_entropy_loss_patch()
        integrate_cross_entropy_loss_patch(model_type="llama")

    if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
        from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

        patch_self_attn_lora()

    # Modify mistral derived models
    if (
        cfg.model_config_type == "mistral"
        and cfg.flash_attention
        and cfg.sample_packing
    ):
    if cfg.model_config_type == "mistral" and cfg.flash_attn_cross_entropy_loss:
        from axolotl.monkeypatch.mistral_attn_hijack_flash import (
            replace_mistral_attn_with_flash_attn,
            patch_mistral_cross_entropy,
        )

        LOG.info("patching mistral with flash attention")
        replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)

    if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
        from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask

        LOG.info("patching _expand_mask")
        hijack_expand_mask()
        patch_mistral_cross_entropy()

    model_kwargs: Dict[str, Any] = {}

4 changes: 1 addition & 3 deletions src/axolotl/utils/trainer.py
@@ -189,9 +189,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
    max_input_len = np.max(get_dataset_lengths(train_dataset))
    LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)

    if (
        cfg.is_mistral_derived_model and cfg.flash_attention
    ) or cfg.model_config_type == "mamba":
    if cfg.model_config_type == "mamba":
        LOG.info("dropping attention_mask column")
        train_dataset = train_dataset.remove_columns("attention_mask")
        if eval_dataset:
8 changes: 5 additions & 3 deletions tests/e2e/patched/test_model_patches.py
@@ -4,6 +4,8 @@

import unittest

import transformers

from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
@@ -87,9 +89,9 @@ def test_mistral_multipack(self, temp_dir):
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        tokenizer = load_tokenizer(cfg)
        model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
        load_model(cfg, tokenizer, inference=cli_args.inference)

        assert (
            "axolotl.monkeypatch.mistral_attn_hijack_flash"
            in model.model.layers[0].self_attn.forward.__module__
            "torch.jit"
            in transformers.modeling_flash_attention_utils._get_unpad_data.__module__  # pylint: disable=protected-access
        )
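The reworked assertion no longer inspects a patched attention module on the loaded model; it checks that transformers' shared unpad helper has been swapped for axolotl's replacement, which appears to be torch.jit-scripted given the __module__ the test expects. Roughly the same check outside the test harness (a sketch, assuming the patched transformers install from this commit):

import transformers
from axolotl.monkeypatch.multipack import patch_for_multipack

patch_for_multipack("mistral")
unpad = transformers.modeling_flash_attention_utils._get_unpad_data  # pylint: disable=protected-access
print(unpad.__module__)  # expected to contain "torch.jit" once the replacement is installed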
