
Commit

Improve BetterTransformer backward compatibility (huggingface#1314)
improve backward compatibility bettertransformer
fxmarty committed Aug 25, 2023
1 parent d7d17eb commit 0ee2dfa
Showing 5 changed files with 43 additions and 7 deletions.
9 changes: 8 additions & 1 deletion optimum/bettertransformer/models/__init__.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import warnings
 
-from .attention import _llama_prepare_decoder_attention_mask
+from ...utils.import_utils import check_if_transformers_greater
 from .decoder_models import (
     BarkAttentionLayerBetterTransformer,
     BartAttentionLayerBetterTransformer,
@@ -48,6 +48,13 @@
 )
 
 
+# TODO: remove once we are much higher than 4.31
+if check_if_transformers_greater("4.31"):
+    from .attention import _llama_prepare_decoder_attention_mask
+else:
+    from ...utils.dummy_bettertransformer_objects import _llama_prepare_decoder_attention_mask
+
+
 class BetterTransformerManager:
     MODEL_MAPPING = {
         "albert": {"AlbertLayer": AlbertLayerBetterTransformer},
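For reference, the `check_if_transformers_greater` gate used above can be implemented by comparing the installed transformers version against the target with `packaging`. The following is a minimal sketch of such a helper, not necessarily the exact code in `optimum/utils/import_utils.py`:

    # Minimal sketch of a version gate; the real helper in optimum/utils/import_utils.py may differ.
    import transformers
    from packaging import version


    def check_if_transformers_greater(target_version: str) -> bool:
        # True when the installed transformers release is at least target_version.
        return version.parse(transformers.__version__) >= version.parse(target_version)
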
15 changes: 12 additions & 3 deletions optimum/bettertransformer/models/attention.py
@@ -15,10 +15,19 @@
 from typing import Optional, Tuple
 
 import torch
-from transformers.models.llama.modeling_llama import _expand_mask as _llama_expand_mask
-from transformers.models.llama.modeling_llama import _make_causal_mask as _llama_make_causal_mask
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
 
+from ...utils.import_utils import check_if_transformers_greater
+
+
+# TODO: remove once we are much higher than 4.31
+if check_if_transformers_greater("4.31"):
+    from transformers.models.llama.modeling_llama import _expand_mask as _llama_expand_mask
+    from transformers.models.llama.modeling_llama import _make_causal_mask as _llama_make_causal_mask
+    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+else:
+    from ...utils.dummy_bettertransformer_objects import _expand_mask as _llama_expand_mask
+    from ...utils.dummy_bettertransformer_objects import _make_causal_mask as _llama_make_causal_mask
+    from ...utils.dummy_bettertransformer_objects import apply_rotary_pos_emb, repeat_kv
 
 # TODO (CRITICAL): Layer-wise attention scaling is broken for several archs (see a fix in gpt_bigcode_wrapped_scaled_dot_product).
 
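The `dummy_bettertransformer_objects` fallbacks imported in the `else` branch above are placeholders that keep the module importable on transformers<4.31 and only fail if the LLaMA-specific helpers are actually used. A hypothetical illustration of that pattern (the real module may be structured differently, e.g. around the `transformers_431` backend entry added below):

    # Hypothetical illustration of the dummy-object fallback pattern; not the actual
    # contents of optimum/utils/dummy_bettertransformer_objects.py.
    def _require_transformers_431(*args, **kwargs):
        raise ImportError(
            "This feature requires transformers>=4.31. "
            "Please update it with `pip install -U transformers`."
        )


    # Importable stand-ins that raise a clear error only when called.
    _expand_mask = _require_transformers_431
    _make_causal_mask = _require_transformers_431
    apply_rotary_pos_emb = _require_transformers_431
    repeat_kv = _require_transformers_431
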
16 changes: 14 additions & 2 deletions optimum/bettertransformer/models/decoder_models.py
@@ -15,7 +15,6 @@
 
 import torch
 import torch.nn as nn
-from transformers.models.bark.modeling_bark import BarkSelfAttention
 from transformers.models.bart.modeling_bart import BartAttention
 from transformers.models.blenderbot.modeling_blenderbot import BlenderbotAttention
 from transformers.models.bloom.modeling_bloom import BloomAttention
@@ -25,13 +24,26 @@
 from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoSelfAttention
 from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention
 from transformers.models.gptj.modeling_gptj import GPTJAttention
-from transformers.models.llama.modeling_llama import LlamaAttention
 from transformers.models.m2m_100.modeling_m2m_100 import M2M100Attention
 from transformers.models.marian.modeling_marian import MarianAttention
 from transformers.models.opt.modeling_opt import OPTAttention
 from transformers.models.pegasus.modeling_pegasus import PegasusAttention
 from transformers.models.t5.modeling_t5 import T5Attention
 
+from ...utils.import_utils import check_if_transformers_greater
+
+
+# TODO: remove once we are much higher than 4.31
+if check_if_transformers_greater("4.31"):
+    from transformers.models.llama.modeling_llama import LlamaAttention
+else:
+    from ...utils.dummy_bettertransformer_objects import LlamaAttention
+
+if check_if_transformers_greater("4.31"):
+    from transformers.models.bark.modeling_bark import BarkSelfAttention
+else:
+    from ...utils.dummy_bettertransformer_objects import BarkSelfAttention
+
 from .attention import (
     bark_wrapped_scaled_dot_product,
     bart_forward,
2 changes: 1 addition & 1 deletion optimum/bettertransformer/transformation.py
@@ -180,7 +180,7 @@ class BetterTransformer(object):
     """
 
     @check_if_pytorch_greater(
-        "1.13.0",
+        "1.13.99",
         "Please upgrade PyTorch following https://pytorch.org/get-started/locally/ in order to use BetterTransformer.",
     )
     def transform(
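Note that raising the threshold from "1.13.0" to "1.13.99" means stock PyTorch 1.13.x releases no longer pass the check, so `transform` effectively requires a 2.x (or sufficiently recent pre-release) build. A decorator with this behavior could be written roughly as follows; this is a sketch under that assumption, not the exact `check_if_pytorch_greater` shipped in optimum:

    # Sketch of a version-gating decorator similar in spirit to check_if_pytorch_greater.
    import functools

    import torch
    from packaging import version


    def check_if_pytorch_greater(target_version: str, message: str):
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                # Block the call when the installed torch is older than target_version.
                if version.parse(torch.__version__) < version.parse(target_version):
                    raise ImportError(
                        f"Found an incompatible version of PyTorch ({torch.__version__}). {message}"
                    )
                return func(*args, **kwargs)

            return wrapper

        return decorator
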
8 changes: 8 additions & 0 deletions optimum/utils/import_utils.py
@@ -173,9 +173,17 @@ def require_numpy_strictly_lower(version: str, message: str):
 diffusers`. Please note that you may need to restart your runtime after installation.
 """
 
+TRANSFORMERS_IMPORT_ERROR = """requires the transformers>={0} library but it was not found in your environment. You can install it with pip: `pip install
+-U transformers`. Please note that you may need to restart your runtime after installation.
+"""
+
 BACKENDS_MAPPING = OrderedDict(
     [
         ("diffusers", (is_diffusers_available, DIFFUSERS_IMPORT_ERROR)),
+        (
+            "transformers_431",
+            (lambda: check_if_transformers_greater("4.31"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.31")),
+        ),
     ]
 )
 
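The new `transformers_431` entry mirrors the existing `diffusers` one: a zero-argument availability check paired with an error template whose `{0}` placeholder receives the name of the object that needs the backend. A consumer in the spirit of transformers' `requires_backends` might look like this (illustrative only; the helper optimum actually uses may differ):

    # Illustrative consumer of BACKENDS_MAPPING; assumes the mapping defined above is in scope.
    def requires_backends(obj, backends):
        name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
        failed = [
            template.format(name)
            for is_available, template in (BACKENDS_MAPPING[b] for b in backends)
            if not is_available()
        ]
        if failed:
            raise ImportError("".join(failed))


    # Example: complain helpfully when transformers is older than 4.31.
    # requires_backends(LlamaAttention, ["transformers_431"])
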
