support for llama multipack using updated code/patches (#1754)
* support for llama multipack using updated code/patches

* also support unsloth patches

* incorrect arg

* add config validation for unsloth

* add missing return to validation

* add another missing return to validation
winglian committed Jul 16, 2024
1 parent cfc533a commit 5f58555
Showing 4 changed files with 95 additions and 21 deletions.
50 changes: 29 additions & 21 deletions src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -78,6 +78,33 @@ def replace_llama_qkv_with_fused(model):
set_module_name(model, name, qkv)


def patch_llama_cross_entropy():
from flash_attn.losses.cross_entropy import CrossEntropyLoss

LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)


def patch_llama_rms_norm():
try:
from flash_attn.ops.rms_norm import RMSNorm

class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""

def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)

LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.warning(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)


def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False,
@@ -104,30 +131,11 @@ def replace_llama_attn_with_flash_attn(

# skip only if explicitly disabled
if cross_entropy:
from flash_attn.losses.cross_entropy import CrossEntropyLoss

LOG.info("patching with flash_attn.losses.cross_entropy")
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
CrossEntropyLoss, inplace_backward=True
)
patch_llama_cross_entropy()

# skip only if explicitly disabled
if rms_norm:
try:
from flash_attn.ops.rms_norm import RMSNorm

class LlamaRMSNorm(RMSNorm):
"""Patched LLamaRMSNorm"""

def __init__(self, hidden_size, eps=1e-6):
super().__init__(hidden_size, eps=eps)

LOG.info("patching with flash_attn.ops.rms_norm")
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
except ImportError:
LOG.warning(
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
)
patch_llama_rms_norm()


class FusedAttention(LlamaAttention):
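The diff above extracts the cross-entropy and RMSNorm patches into standalone helpers, so they can be applied without also replacing llama attention. A minimal sketch of calling them directly (module path as shown in this diff; flash-attn must be installed for either helper to do anything useful):

from axolotl.monkeypatch.llama_attn_hijack_flash import (
    patch_llama_cross_entropy,
    patch_llama_rms_norm,
)

# swap transformers' llama CrossEntropyLoss for flash-attn's inplace-backward version
patch_llama_cross_entropy()
# swap LlamaRMSNorm for the flash-attn kernel; per the except branch above, this only
# logs a warning and leaves LlamaRMSNorm unpatched if the optimized kernel is missing
patch_llama_rms_norm()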
5 changes: 5 additions & 0 deletions src/axolotl/monkeypatch/multipack.py
@@ -10,6 +10,7 @@
from axolotl.monkeypatch.utils import get_unpad_data

SUPPORTED_MULTIPACK_MODEL_TYPES = [
"llama",
"mixtral",
"qwen2",
"qwen2_moe",
@@ -30,6 +31,10 @@ def patch_for_multipack(model_type, model_name=None):
)
if is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
elif model_type == "llama":
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
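With "llama" added to SUPPORTED_MULTIPACK_MODEL_TYPES, sample packing for llama goes through the same route as mixtral and qwen2: patch_for_multipack swaps transformers' _get_unpad_data for axolotl's packing-aware version. A minimal sketch, assuming a transformers version where _get_unpad_data is still a module-level function in modeling_llama (which is what this patch replaces):

import transformers

from axolotl.monkeypatch.multipack import patch_for_multipack
from axolotl.monkeypatch.utils import get_unpad_data

patch_for_multipack("llama")

# llama's flash-attention path now un-pads packed batches with the multipack-aware helper
assert (
    transformers.models.llama.modeling_llama._get_unpad_data is get_unpad_data
)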
40 changes: 40 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -1112,6 +1112,31 @@ def check_dataset_or_pretraining_dataset(cls, data):
raise ValueError("either datasets or pretraining_dataset is required")
return data

@model_validator(mode="before")
@classmethod
def check_xentropy_patch_conflicts(cls, data):
if data.get("flash_attn_cross_entropy") and data.get(
"unsloth_cross_entropy_loss"
):
raise ValueError(
"flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled"
)
return data

@model_validator(mode="before")
@classmethod
def check_qlora_unsloth(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
if data.get("adapter") == "lora" or data.get("load_in_8bit"):
raise ValueError(
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA"
)
return data


class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""
@@ -1163,3 +1188,18 @@ def check_fsdp_deepspeed(cls, data):
if data.get("deepspeed") and data.get("fsdp"):
raise ValueError("deepspeed and fsdp cannot be used together.")
return data

@model_validator(mode="before")
@classmethod
def check_multigpu_unsloth(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
capabilities = data.get("capabilities")
if capabilities and capabilities.get("num_gpus") > 1:
raise ValueError(
"unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training."
)
return data
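The three new validators encode: the flash-attn and unsloth cross-entropy patches are mutually exclusive; the unsloth LoRA patches require QLoRA rather than plain LoRA or 8-bit; and the unsloth LoRA patches are single-GPU only. A standalone sketch of the same rules applied to a plain config dict (illustrative only, not a call into the pydantic models above):

def validate_unsloth_flags(data: dict) -> None:
    """Illustrative stand-in for the three validators added in this commit."""
    if data.get("flash_attn_cross_entropy") and data.get("unsloth_cross_entropy_loss"):
        raise ValueError(
            "flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot both be enabled"
        )
    uses_unsloth_lora = (
        data.get("unsloth_lora_mlp")
        or data.get("unsloth_lora_qkv")
        or data.get("unsloth_lora_o")
    )
    if uses_unsloth_lora:
        if data.get("adapter") == "lora" or data.get("load_in_8bit"):
            raise ValueError("unsloth LoRA patches require QLoRA, not LoRA/8-bit")
        capabilities = data.get("capabilities") or {}
        if capabilities.get("num_gpus", 0) > 1:
            raise ValueError("unsloth LoRA patches do not support multi-GPU training")


# passes: QLoRA on a single GPU
validate_unsloth_flags(
    {"adapter": "qlora", "unsloth_lora_qkv": True, "capabilities": {"num_gpus": 1}}
)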
21 changes: 21 additions & 0 deletions src/axolotl/utils/models.py
@@ -347,6 +347,27 @@ def load_model(
and cfg.sample_packing
):
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)

if cfg.is_llama_derived_model:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
patch_llama_cross_entropy,
patch_llama_rms_norm,
)

if cfg.flash_attn_cross_entropy:
patch_llama_cross_entropy()
if cfg.flash_attn_rms_norm:
patch_llama_rms_norm()
if cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import (
integrate_cross_entropy_loss_patch,
)

integrate_cross_entropy_loss_patch()
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

patch_self_attn_lora()
elif cfg.is_llama_derived_model:
# Modify all llama derived models in one block

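In load_model, the llama-specific patches are now gated on individual config flags rather than applied as one block for llama-derived models. A rough sketch of the flags involved (key names taken from this hunk; the values are illustrative, not a recommended configuration):

# illustrative excerpt of the cfg keys that drive the new llama patch calls in load_model
cfg_excerpt = {
    "sample_packing": True,               # with flash attention, llama now routes through patch_for_multipack
    "flash_attn_cross_entropy": True,     # calls patch_llama_cross_entropy()
    "flash_attn_rms_norm": True,          # calls patch_llama_rms_norm()
    "unsloth_cross_entropy_loss": False,  # mutually exclusive with flash_attn_cross_entropy
    "unsloth_lora_qkv": False,            # with unsloth_lora_o, triggers patch_self_attn_lora()
}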
