From 0ee2dfacd6e18c40ea0e24d293442aaefc9c4608 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Sat, 26 Aug 2023 00:23:21 +0900
Subject: [PATCH] Improve BetterTransformer backward compatibility (#1314)

improve backward compatibility bettertransformer
---
 optimum/bettertransformer/models/__init__.py    |  9 ++++++++-
 optimum/bettertransformer/models/attention.py   | 15 ++++++++++++---
 .../bettertransformer/models/decoder_models.py  | 16 ++++++++++++++--
 optimum/bettertransformer/transformation.py     |  2 +-
 optimum/utils/import_utils.py                   |  8 ++++++++
 5 files changed, 43 insertions(+), 7 deletions(-)

diff --git a/optimum/bettertransformer/models/__init__.py b/optimum/bettertransformer/models/__init__.py
index 35d106b9f5..7ef029bbdd 100644
--- a/optimum/bettertransformer/models/__init__.py
+++ b/optimum/bettertransformer/models/__init__.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import warnings

-from .attention import _llama_prepare_decoder_attention_mask
+from ...utils.import_utils import check_if_transformers_greater
 from .decoder_models import (
     BarkAttentionLayerBetterTransformer,
     BartAttentionLayerBetterTransformer,
@@ -48,6 +48,13 @@
 )


+# TODO: remove once we are much higher than 4.31
+if check_if_transformers_greater("4.31"):
+    from .attention import _llama_prepare_decoder_attention_mask
+else:
+    from ...utils.dummy_bettertransformer_objects import _llama_prepare_decoder_attention_mask
+
+
 class BetterTransformerManager:
     MODEL_MAPPING = {
         "albert": {"AlbertLayer": AlbertLayerBetterTransformer},
diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py
index 462f1b46a8..86ed1cce6e 100644
--- a/optimum/bettertransformer/models/attention.py
+++ b/optimum/bettertransformer/models/attention.py
@@ -15,10 +15,19 @@
 from typing import Optional, Tuple

 import torch
-from transformers.models.llama.modeling_llama import _expand_mask as _llama_expand_mask
-from transformers.models.llama.modeling_llama import _make_causal_mask as _llama_make_causal_mask
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

+from ...utils.import_utils import check_if_transformers_greater
+
+
+# TODO: remove once we are much higher than 4.31
+if check_if_transformers_greater("4.31"):
+    from transformers.models.llama.modeling_llama import _expand_mask as _llama_expand_mask
+    from transformers.models.llama.modeling_llama import _make_causal_mask as _llama_make_causal_mask
+    from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+else:
+    from ...utils.dummy_bettertransformer_objects import _expand_mask as _llama_expand_mask
+    from ...utils.dummy_bettertransformer_objects import _make_causal_mask as _llama_make_causal_mask
+    from ...utils.dummy_bettertransformer_objects import apply_rotary_pos_emb, repeat_kv

 # TODO (CRITICAL): Layer-wise attention scaling is broken for several archs (see a fix in gpt_bigcode_wrapped_scaled_dot_product).
diff --git a/optimum/bettertransformer/models/decoder_models.py b/optimum/bettertransformer/models/decoder_models.py
index ab09d6af8c..fc23e1b9b2 100644
--- a/optimum/bettertransformer/models/decoder_models.py
+++ b/optimum/bettertransformer/models/decoder_models.py
@@ -15,7 +15,6 @@

 import torch
 import torch.nn as nn
-from transformers.models.bark.modeling_bark import BarkSelfAttention
 from transformers.models.bart.modeling_bart import BartAttention
 from transformers.models.blenderbot.modeling_blenderbot import BlenderbotAttention
 from transformers.models.bloom.modeling_bloom import BloomAttention
@@ -25,13 +24,26 @@
 from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoSelfAttention
 from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention
 from transformers.models.gptj.modeling_gptj import GPTJAttention
-from transformers.models.llama.modeling_llama import LlamaAttention
 from transformers.models.m2m_100.modeling_m2m_100 import M2M100Attention
 from transformers.models.marian.modeling_marian import MarianAttention
 from transformers.models.opt.modeling_opt import OPTAttention
 from transformers.models.pegasus.modeling_pegasus import PegasusAttention
 from transformers.models.t5.modeling_t5 import T5Attention

+from ...utils.import_utils import check_if_transformers_greater
+
+
+# TODO: remove once we are much higher than 4.31
+if check_if_transformers_greater("4.31"):
+    from transformers.models.llama.modeling_llama import LlamaAttention
+else:
+    from ...utils.dummy_bettertransformer_objects import LlamaAttention
+
+if check_if_transformers_greater("4.31"):
+    from transformers.models.bark.modeling_bark import BarkSelfAttention
+else:
+    from ...utils.dummy_bettertransformer_objects import BarkSelfAttention
+
 from .attention import (
     bark_wrapped_scaled_dot_product,
     bart_forward,
diff --git a/optimum/bettertransformer/transformation.py b/optimum/bettertransformer/transformation.py
index 54573d253e..2bb4224f08 100644
--- a/optimum/bettertransformer/transformation.py
+++ b/optimum/bettertransformer/transformation.py
@@ -180,7 +180,7 @@ class BetterTransformer(object):
     """

     @check_if_pytorch_greater(
-        "1.13.0",
+        "1.13.99",
         "Please upgrade PyTorch following https://pytorch.org/get-started/locally/ in order to use BetterTransformer.",
     )
     def transform(
diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py
index a08bb1af19..7221baaf0b 100644
--- a/optimum/utils/import_utils.py
+++ b/optimum/utils/import_utils.py
@@ -173,9 +173,17 @@ def require_numpy_strictly_lower(version: str, message: str):
 diffusers`. Please note that you may need to restart your runtime after installation.
 """

+TRANSFORMERS_IMPORT_ERROR = """requires the transformers>={0} library but it was not found in your environment. You can install it with pip: `pip install
+-U transformers`. Please note that you may need to restart your runtime after installation.
+"""
+
 BACKENDS_MAPPING = OrderedDict(
     [
         ("diffusers", (is_diffusers_available, DIFFUSERS_IMPORT_ERROR)),
+        (
+            "transformers_431",
+            (lambda: check_if_transformers_greater("4.31"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.31")),
+        ),
     ]
 )
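
Note: every file in this patch gates its transformers-specific imports behind check_if_transformers_greater("4.31") from optimum/utils/import_utils.py. The helper's body is not shown in the diff; the following is a minimal sketch, assuming it simply compares the installed transformers release against the requested version with packaging.version, plus a usage line mirroring the gate used above.

# Sketch only: the real helper lives in optimum/utils/import_utils.py and may differ in detail.
import transformers
from packaging import version


def check_if_transformers_greater(target_version: str) -> bool:
    # True when the installed transformers release is at least `target_version`.
    return version.parse(transformers.__version__) >= version.parse(target_version)


# Usage mirroring the patch: take the real symbol on recent transformers,
# otherwise fall back to a stand-in (see the fallback-module sketch below).
if check_if_transformers_greater("4.31"):
    from transformers.models.llama.modeling_llama import LlamaAttention  # noqa: F401
else:
    LlamaAttention = None  # the patch instead imports a dummy stand-in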
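The else branches import same-named stand-ins from optimum/utils/dummy_bettertransformer_objects.py, a file that is not part of this diff. The sketch below is an assumption of what such a fallback module could look like: the names mirror the imports in the patch, and each stand-in only raises when it is actually used, so merely importing it on transformers < 4.31 stays harmless and unrelated architectures keep working.

# dummy_bettertransformer_objects.py -- illustrative sketch, not the actual module.
# The error text echoes the TRANSFORMERS_IMPORT_ERROR template added to import_utils.py above.

_MESSAGE = (
    "{0} requires the transformers>=4.31 library but it was not found in your "
    "environment. You can install it with pip: `pip install -U transformers`."
)


def _llama_prepare_decoder_attention_mask(*args, **kwargs):
    raise ImportError(_MESSAGE.format("_llama_prepare_decoder_attention_mask"))


def _expand_mask(*args, **kwargs):
    raise ImportError(_MESSAGE.format("_expand_mask"))


def _make_causal_mask(*args, **kwargs):
    raise ImportError(_MESSAGE.format("_make_causal_mask"))


def apply_rotary_pos_emb(*args, **kwargs):
    raise ImportError(_MESSAGE.format("apply_rotary_pos_emb"))


def repeat_kv(*args, **kwargs):
    raise ImportError(_MESSAGE.format("repeat_kv"))


class LlamaAttention:
    # Fails loudly only on instantiation, never at import time.
    def __init__(self, *args, **kwargs):
        raise ImportError(_MESSAGE.format("LlamaAttention"))


class BarkSelfAttention:
    def __init__(self, *args, **kwargs):
        raise ImportError(_MESSAGE.format("BarkSelfAttention"))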