use flash_attn xentropy when available (#525)
* use flash_attn xentropy when available

* log when xentropy is not found
tmm1 authored Sep 4, 2023
1 parent 44454ae commit 5fe30b1
Showing 1 changed file with 17 additions and 0 deletions.
src/axolotl/monkeypatch/llama_attn_hijack_flash.py (17 additions, 0 deletions)

@@ -2,7 +2,9 @@

 # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
 
+import logging
 import warnings
+from functools import partial
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -33,6 +35,9 @@
 )
 
 
+LOG = logging.getLogger("axolotl")
+
+
 def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
         _prepare_decoder_attention_mask
@@ -44,6 +49,18 @@ def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
             llama_model_forward
         )
 
+    try:
+        from flash_attn.losses.cross_entropy import CrossEntropyLoss
+
+        LOG.info("patching with flash_attn.losses.cross_entropy")
+        transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
+            CrossEntropyLoss, inplace_backward=True
+        )
+    except ImportError:
+        LOG.info(
+            "optimized flash-attention CrossEntropyLoss not found (run `pip install git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy`)"
+        )
+
 
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
 # requires the attention mask to be the same as the key_padding_mask
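For context, flash-attn's xentropy loss is meant as a drop-in replacement for torch.nn.CrossEntropyLoss, and inplace_backward=True (the option the patch binds via functools.partial) lets the backward kernel reuse the logits buffer for the gradient. Below is a minimal sketch of using it directly; it is not part of this commit, and the batch size, vocab size, dtype, and device are illustrative assumptions.

# Minimal usage sketch (not part of this commit): flash-attn's fused cross-entropy
# as a drop-in for torch.nn.CrossEntropyLoss. Shapes, vocab size, and dtype are
# illustrative assumptions; the fused kernel runs on CUDA tensors.
import torch

try:
    from flash_attn.losses.cross_entropy import CrossEntropyLoss

    # inplace_backward=True overwrites the logits with their gradient during backward,
    # saving one (batch, vocab)-sized buffer; this is the option the patch binds via partial().
    loss_fn = CrossEntropyLoss(inplace_backward=True)
except ImportError:
    loss_fn = torch.nn.CrossEntropyLoss()

logits = torch.randn(8, 32000, device="cuda", dtype=torch.float16, requires_grad=True)
labels = torch.randint(0, 32000, (8,), device="cuda")
loss = loss_fn(logits, labels)
loss.backward()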

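And a hedged sketch of how the monkeypatch itself is typically activated: axolotl wires this up from its own configuration, so the direct call and the checkpoint name here are purely illustrative.

# Illustrative only: apply the monkeypatch, then load the model. The patch rewrites
# class/module attributes in transformers.models.llama.modeling_llama, which are
# resolved at forward time, so the patched CrossEntropyLoss is used when the loss
# is computed during training.
import torch
from transformers import LlamaForCausalLM

from axolotl.monkeypatch.llama_attn_hijack_flash import replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn(packed=False)

model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",   # illustrative checkpoint
    torch_dtype=torch.float16,    # flash attention expects fp16/bf16 inputs
).cuda()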