From 5190280427e5599873f6c5a392c3fe867d3bdea7 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 16 Jan 2025 12:47:29 +0100 Subject: [PATCH 01/26] test --- setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index ec15277f18..37194d8953 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ "datasets>=1.2.1", "evaluate", "protobuf>=3.20.1", - "transformers>=4.36,<4.48.0", + "transformers>=4.36,<4.49.0", ], "onnxruntime-gpu": [ "onnx", @@ -59,13 +59,13 @@ "evaluate", "protobuf>=3.20.1", "accelerate", # ORTTrainer requires it. - "transformers>=4.36,<4.48.0", + "transformers>=4.36,<4.49.0", ], "exporters": [ "onnx", "onnxruntime", "timm", - "transformers>=4.36,<4.48.0", + "transformers>=4.36,<4.49.0", ], "exporters-gpu": [ "onnx", From 6a03d7657171f2233506b65fb9f5be776cbe46c2 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Jan 2025 10:24:35 +0100 Subject: [PATCH 02/26] testing tensor cache x) --- optimum/exporters/onnx/_tensor_cache.py | 92 +++++++++++++++++++++++++ optimum/exporters/onnx/model_patcher.py | 63 +++++++++++++---- 2 files changed, 141 insertions(+), 14 deletions(-) create mode 100644 optimum/exporters/onnx/_tensor_cache.py diff --git a/optimum/exporters/onnx/_tensor_cache.py b/optimum/exporters/onnx/_tensor_cache.py new file mode 100644 index 0000000000..43192e8408 --- /dev/null +++ b/optimum/exporters/onnx/_tensor_cache.py @@ -0,0 +1,92 @@ +from typing import Any, Dict, Optional, Tuple + +import torch +from transformers.cache_utils import logger + + +# The same as transformers.cache_utils.Cache but iherits from torch.Tensor instead of torch.nn.Module +class Cache(torch.Tensor): + """ + Base, abstract class for all caches. The actual data structure is specific to each subclass. + """ + + def __init__(self): + super().__init__() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache to be created. + + Return: + A tuple containing the updated key and value states. + """ + raise NotImplementedError("Make sure to implement `update` in a subclass.") + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + + # Deprecate in favor of max-cache-shape because we want to be specifc by what we mean with "max_length" + # Prev some cache objects didn't have "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles + # infinite amount of tokens. In the codebase what we really need to check is the max capacity of certain cache instances, so + # we change naming to be more explicit + def get_max_length(self) -> Optional[int]: + logger.warning_once( + "`get_max_cache()` is deprecated for all Cache classes. 
Use `get_max_cache_shape()` instead. " + "Calling `get_max_cache()` will raise error from v4.48" + ) + return self.get_max_cache_shape() + + def get_max_cache_shape(self) -> Optional[int]: + """Returns the maximum sequence length (i.e. max capacity) of the cache object""" + raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") + + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_cache_shape() + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx] != []: + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + if self.value_cache[layer_idx] != []: + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + @property + def seen_tokens(self): + logger.warning_once( + "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " + "model input instead." + ) + if hasattr(self, "_seen_tokens"): + return self._seen_tokens + else: + return None diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 80293e7b95..5a28452baa 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import transformers +import transformers.cache_utils from packaging import version from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet from transformers.utils import is_torch_available @@ -31,6 +32,7 @@ from ...configuration_utils import _transformers_version from ...utils import logging +from ._tensor_cache import Cache as PatchedCache if _transformers_version > version.parse("4.34.99"): @@ -49,6 +51,7 @@ from .base import OnnxConfig + logger = logging.get_logger(__name__) @@ -158,6 +161,7 @@ def onnx_compatible_unfold(input_tensor, dimension, size, step): UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)] +CACHE_PATCHING_SPEC = [PatchingSpec(transformers.cache_utils, "Cache", PatchedCache, transformers.cache_utils.Cache)] class ModelPatcher: @@ -171,6 +175,7 @@ def __init__( patching_specs = config.PATCHING_SPECS or [] patching_specs.extend(UNSUPPORTED_OPS_PATCHING_SPEC) + patching_specs.extend(CACHE_PATCHING_SPEC) self._patching_specs = [] for spec in patching_specs: @@ -194,11 +199,32 @@ def __init__( @functools.wraps(self.orig_forward) def patched_forward(*args, **kwargs): + from transformers.cache_utils import DynamicCache, EncoderDecoderCache + signature = inspect.signature(self.orig_forward) args, kwargs = override_arguments(args, 
kwargs, signature, model_kwargs=self.model_kwargs) + if kwargs.get("past_key_values") is not None: + if len(kwargs["past_key_values"][0]) == 2: + kwargs["past_key_values"] = DynamicCache.from_legacy_cache(kwargs["past_key_values"]) + elif len(kwargs["past_key_values"][0]) == 4: + kwargs["past_key_values"] = EncoderDecoderCache.from_legacy_cache(kwargs["past_key_values"]) + + elif any(isinstance(arg, (list, tuple)) for arg in args): + for i, arg in enumerate(args): + if isinstance(arg, (list, tuple)): + if len(arg[0]) == 2: + args[i] = DynamicCache.from_legacy_cache(arg) + elif len(arg[0]) == 4: + args[i] = EncoderDecoderCache.from_legacy_cache(arg) + outputs = self.orig_forward(*args, **kwargs) + if "past_key_values" in outputs and isinstance( + outputs["past_key_values"], (DynamicCache, EncoderDecoderCache) + ): + outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() + # This code block handles different cases of the filterd_outputs input to align it with the expected # format of outputs. It is common for the output type of a model to vary, such as tensor, list, # tuple, etc. For Transformers models, the output is encapsulated in a ModelOutput object that @@ -230,6 +256,7 @@ def patched_forward(*args, **kwargs): filterd_outputs[name] = outputs name = list(config.outputs.keys())[0] filterd_outputs[name] = outputs + return filterd_outputs self.patched_forward = patched_forward @@ -833,14 +860,22 @@ def patched_forward( class SentenceTransformersTransformerPatcher(ModelPatcher): def __enter__(self): super().__enter__() - if _transformers_version >= version.parse("4.42") and self.real_config._config.model_type == "mistral": + if ( + _transformers_version >= version.parse("4.42") + and _transformers_version < version.parse("4.48") + and self.real_config._config.model_type == "mistral" + ): self._model[0].auto_model._update_causal_mask = types.MethodType( _update_causal_mask_patched, self._model[0].auto_model ) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - if _transformers_version >= version.parse("4.42") and self.real_config._config.model_type == "mistral": + if ( + _transformers_version >= version.parse("4.42") + and _transformers_version < version.parse("4.48") + and self.real_config._config.model_type == "mistral" + ): self._model[0].auto_model._update_causal_mask = types.MethodType( self._update_causal_mask_original, self._model[0].auto_model ) @@ -1132,16 +1167,16 @@ def _update_causal_mask_patched( padding_mask, min_dtype ) - # if ( - # self.config._attn_implementation == "sdpa" - # and attention_mask is not None - # and attention_mask.device.type == "cuda" - # and not output_attentions - # ): - # # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # # Details: https://github.com/pytorch/pytorch/issues/110213 - # causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @@ -1161,7 +1196,7 @@ def __enter__(self): "_prepare_4d_causal_attention_mask_for_sdpa", _prepare_4d_causal_attention_mask_for_sdpa_patched ) - if _transformers_version >= version.parse("4.42"): + if _transformers_version >= version.parse("4.42") and _transformers_version < version.parse("4.48"): if hasattr(self._model, "model"): self._model.model._update_causal_mask = types.MethodType( _update_causal_mask_patched, self._model.model @@ -1183,7 +1218,7 @@ def __exit__(self, exc_type, exc_value, traceback): "_prepare_4d_causal_attention_mask_for_sdpa", self.original_prepare_4d_causal_attention_mask_for_sdpa ) - if _transformers_version >= version.parse("4.42"): + if _transformers_version >= version.parse("4.42") and _transformers_version < version.parse("4.48"): if hasattr(self._model, "model"): self._model.model._update_causal_mask = types.MethodType( self._update_causal_mask_original, self._model.model From 7207215f2b92773c859a788c66a9ca2c98500489 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Jan 2025 10:30:41 +0100 Subject: [PATCH 03/26] fix logger --- optimum/exporters/onnx/_tensor_cache.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/optimum/exporters/onnx/_tensor_cache.py b/optimum/exporters/onnx/_tensor_cache.py index 43192e8408..154b186ef4 100644 --- a/optimum/exporters/onnx/_tensor_cache.py +++ b/optimum/exporters/onnx/_tensor_cache.py @@ -1,7 +1,10 @@ from typing import Any, Dict, Optional, Tuple import torch -from transformers.cache_utils import logger +from transformers import logging + + +logger = logging.get_logger(__name__) # The same as transformers.cache_utils.Cache but iherits from torch.Tensor instead of torch.nn.Module From 6261094b2a008969c12697b37b71853bce1e1cec Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Jan 2025 10:44:30 +0100 Subject: [PATCH 04/26] condition cache class usage --- optimum/exporters/onnx/model_patcher.py | 38 +++++++++++++------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 5a28452baa..98b772851f 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -199,31 +199,33 @@ def __init__( @functools.wraps(self.orig_forward) def patched_forward(*args, **kwargs): - from transformers.cache_utils import DynamicCache, EncoderDecoderCache - signature = inspect.signature(self.orig_forward) args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs) - if kwargs.get("past_key_values") is not None: - if len(kwargs["past_key_values"][0]) == 2: - kwargs["past_key_values"] = DynamicCache.from_legacy_cache(kwargs["past_key_values"]) - elif len(kwargs["past_key_values"][0]) == 4: - kwargs["past_key_values"] = EncoderDecoderCache.from_legacy_cache(kwargs["past_key_values"]) + if _transformers_version >= version.parse("4.48"): + from transformers.cache_utils import DynamicCache, EncoderDecoderCache + + if isinstance(kwargs.get("past_key_values"), (list, tuple)): + if len(kwargs["past_key_values"][0]) == 2: + kwargs["past_key_values"] = DynamicCache.from_legacy_cache(kwargs["past_key_values"]) + elif len(kwargs["past_key_values"][0]) == 4: + kwargs["past_key_values"] = EncoderDecoderCache.from_legacy_cache(kwargs["past_key_values"]) - elif any(isinstance(arg, (list, 
tuple)) for arg in args): - for i, arg in enumerate(args): - if isinstance(arg, (list, tuple)): - if len(arg[0]) == 2: - args[i] = DynamicCache.from_legacy_cache(arg) - elif len(arg[0]) == 4: - args[i] = EncoderDecoderCache.from_legacy_cache(arg) + elif any(isinstance(arg, (list, tuple)) for arg in args): + for i, arg in enumerate(args): + if isinstance(arg, (list, tuple)): + if len(arg[0]) == 2: + args[i] = DynamicCache.from_legacy_cache(arg) + elif len(arg[0]) == 4: + args[i] = EncoderDecoderCache.from_legacy_cache(arg) outputs = self.orig_forward(*args, **kwargs) - if "past_key_values" in outputs and isinstance( - outputs["past_key_values"], (DynamicCache, EncoderDecoderCache) - ): - outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() + if _transformers_version >= version.parse("4.48"): + if "past_key_values" in outputs and isinstance( + outputs["past_key_values"], (DynamicCache, EncoderDecoderCache) + ): + outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() # This code block handles different cases of the filterd_outputs input to align it with the expected # format of outputs. It is common for the output type of a model to vary, such as tensor, list, From 822066d3be5329db85a26cf5c8f8bb00932205f4 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Jan 2025 11:40:35 +0100 Subject: [PATCH 05/26] update opset for beit and data2vec vision and skip flattened/fused pkv (e.g. gpt bigcode) --- optimum/exporters/onnx/model_configs.py | 13 ++++++------ optimum/exporters/onnx/model_patcher.py | 27 ++++++++++++++++++------- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 503f28d057..d10159dc9c 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -843,7 +843,7 @@ class DeiTOnnxConfig(ViTOnnxConfig): class BeitOnnxConfig(ViTOnnxConfig): - DEFAULT_ONNX_OPSET = 11 + DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. class ConvNextOnnxConfig(ViTOnnxConfig): @@ -1573,13 +1573,12 @@ class Data2VecTextOnnxConfig(DistilBertOnnxConfig): class Data2VecVisionOnnxConfig(ViTOnnxConfig): - DEFAULT_ONNX_OPSET = 11 + DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. class Data2VecAudioOnnxConfig(AudioOnnxConfig): - NORMALIZED_CONFIG_CLASS = NormalizedConfig - ATOL_FOR_VALIDATION = 1e-4 DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1. 
+ NORMALIZED_CONFIG_CLASS = NormalizedConfig class PerceiverDummyInputGenerator(DummyVisionInputGenerator): @@ -2307,9 +2306,9 @@ def outputs(self) -> Dict[str, Dict[int, str]]: # for Speech2text, we need to name the second axis as # encoder_sequence_length / 2 * self._config.num_conv_layers as the axis name is # used for dummy input generation - common_outputs["last_hidden_state"][ - 1 - ] = f"{common_outputs['last_hidden_state'][1]} / {(2 * self._config.num_conv_layers)}" + common_outputs["last_hidden_state"][1] = ( + f"{common_outputs['last_hidden_state'][1]} / {(2 * self._config.num_conv_layers)}" + ) return common_outputs diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 98b772851f..8faf412f19 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -205,19 +205,32 @@ def patched_forward(*args, **kwargs): if _transformers_version >= version.parse("4.48"): from transformers.cache_utils import DynamicCache, EncoderDecoderCache - if isinstance(kwargs.get("past_key_values"), (list, tuple)): + if isinstance(kwargs.get("past_key_values"), (list, tuple)) and isinstance( + kwargs["past_key_values"][0], (list, tuple) + ): + print("Transforming past_key_values") if len(kwargs["past_key_values"][0]) == 2: kwargs["past_key_values"] = DynamicCache.from_legacy_cache(kwargs["past_key_values"]) elif len(kwargs["past_key_values"][0]) == 4: kwargs["past_key_values"] = EncoderDecoderCache.from_legacy_cache(kwargs["past_key_values"]) + else: + raise ValueError( + f"past_key_values should have either 2 or 4 elements, but it has {len(kwargs['past_key_values'][0])} elements" + ) elif any(isinstance(arg, (list, tuple)) for arg in args): - for i, arg in enumerate(args): - if isinstance(arg, (list, tuple)): - if len(arg[0]) == 2: - args[i] = DynamicCache.from_legacy_cache(arg) - elif len(arg[0]) == 4: - args[i] = EncoderDecoderCache.from_legacy_cache(arg) + for i in range(len(args)): + if isinstance(args[i], (list, tuple)) and isinstance(args[i][0], (list, tuple)): + print("Transforming past_key_values") + if len(args[i]) == 2: + args[i] = DynamicCache.from_legacy_cache(args[i]) + elif len(args[i]) == 4: + args[i] = EncoderDecoderCache.from_legacy_cache(args[i]) + else: + raise ValueError( + f"past_key_values should have either 2 or 4 elements, but it has {len(args[i])} elements" + ) + break outputs = self.orig_forward(*args, **kwargs) From 3ab38fdbc955b0fc04f0d1e466a512c8a0ccad3b Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Jan 2025 11:43:44 +0100 Subject: [PATCH 06/26] style --- optimum/exporters/onnx/model_configs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index d10159dc9c..f765eb7042 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -2306,9 +2306,9 @@ def outputs(self) -> Dict[str, Dict[int, str]]: # for Speech2text, we need to name the second axis as # encoder_sequence_length / 2 * self._config.num_conv_layers as the axis name is # used for dummy input generation - common_outputs["last_hidden_state"][1] = ( - f"{common_outputs['last_hidden_state'][1]} / {(2 * self._config.num_conv_layers)}" - ) + common_outputs["last_hidden_state"][ + 1 + ] = f"{common_outputs['last_hidden_state'][1]} / {(2 * self._config.num_conv_layers)}" return common_outputs From d713e5a43b8253ddd7dcc06e626c70fc34eb5e16 Mon Sep 17 00:00:00 2001 From: 
IlyasMoutawwakil Date: Mon, 20 Jan 2025 12:09:33 +0100 Subject: [PATCH 07/26] fix args patcher --- optimum/exporters/onnx/model_patcher.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 8faf412f19..8415586dd6 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -208,7 +208,6 @@ def patched_forward(*args, **kwargs): if isinstance(kwargs.get("past_key_values"), (list, tuple)) and isinstance( kwargs["past_key_values"][0], (list, tuple) ): - print("Transforming past_key_values") if len(kwargs["past_key_values"][0]) == 2: kwargs["past_key_values"] = DynamicCache.from_legacy_cache(kwargs["past_key_values"]) elif len(kwargs["past_key_values"][0]) == 4: @@ -221,10 +220,9 @@ def patched_forward(*args, **kwargs): elif any(isinstance(arg, (list, tuple)) for arg in args): for i in range(len(args)): if isinstance(args[i], (list, tuple)) and isinstance(args[i][0], (list, tuple)): - print("Transforming past_key_values") - if len(args[i]) == 2: + if len(args[i][0]) == 2: args[i] = DynamicCache.from_legacy_cache(args[i]) - elif len(args[i]) == 4: + elif len(args[i][0]) == 4: args[i] = EncoderDecoderCache.from_legacy_cache(args[i]) else: raise ValueError( From bf4d1f3a8928a3cce927abe13a01b4adc5b20b3b Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Jan 2025 12:13:57 +0100 Subject: [PATCH 08/26] fix modernbert testing --- tests/exporters/exporters_utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index ee31397fd8..f43b70201c 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -13,6 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from optimum.utils import is_transformers_version + + VALIDATE_EXPORT_ON_SHAPES_SLOW = { "batch_size": [1, 3, 5], "sequence_length": [8, 33, 96, 154], @@ -125,7 +128,11 @@ "mobilenet-v2": "hf-internal-testing/tiny-random-MobileNetV2Model", "mobilenet-v1": "google/mobilenet_v1_0.75_192", "mobilevit": "hf-internal-testing/tiny-random-mobilevit", - "modernbert": "hf-internal-testing/tiny-random-ModernBertForMaskedLM", + **( + {"modernbert": "hf-internal-testing/tiny-random-ModernBertForMaskedLM"} + if is_transformers_version(">=", "4.48") + else {} + ), "mpnet": "hf-internal-testing/tiny-random-MPNetModel", "mpt": "hf-internal-testing/tiny-random-MptForCausalLM", "mt5": "lewtun/tiny-random-mt5", @@ -269,7 +276,7 @@ # "mobilenet_v1": "google/mobilenet_v1_0.75_192", # "mobilenet_v2": "google/mobilenet_v2_0.35_96", "mobilevit": "apple/mobilevit-small", - "modernbert": "answerdotai/ModernBERT-base", + **({"modernbert": "answerdotai/ModernBERT-base"} if is_transformers_version(">=", "4.48") else {}), "mpt": "mosaicml/mpt-7b", "mt5": "lewtun/tiny-random-mt5", # Not using google/mt5-small because it takes too much time for testing. 
"musicgen": "facebook/musicgen-small", From 230c3a00a7622abf77f0b832014d3ea4a12841a7 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Jan 2025 13:26:08 +0100 Subject: [PATCH 09/26] adaot to new whisper returned generation length --- tests/onnxruntime/test_modeling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index d92888a8dd..c341bd88a9 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -4612,7 +4612,9 @@ def test_compare_with_and_without_past_key_values(self, model_arch: str): self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv)) - if model_arch == "whisper" and is_transformers_version(">=", "4.43"): + if model_arch == "whisper" and is_transformers_version(">=", "4.48"): + gen_length = self.GENERATION_LENGTH + elif model_arch == "whisper" and is_transformers_version(">=", "4.43"): gen_length = self.GENERATION_LENGTH + 2 else: gen_length = self.GENERATION_LENGTH + 1 From 3d5d9c96bbc27344f017414b51f8b61fe96f8649 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Jan 2025 14:01:32 +0100 Subject: [PATCH 10/26] fix is_causal in transformers --- setup.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 37194d8953..3942ac365f 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,8 @@ "datasets>=1.2.1", "evaluate", "protobuf>=3.20.1", - "transformers>=4.36,<4.49.0", + # "transformers>=4.36,<4.49.0", + "transformers@git+https://github.com/huggingface/transformers.git@fix-gpt2-is-causal", ], "onnxruntime-gpu": [ "onnx", @@ -59,19 +60,22 @@ "evaluate", "protobuf>=3.20.1", "accelerate", # ORTTrainer requires it. - "transformers>=4.36,<4.49.0", + # "transformers>=4.36,<4.49.0", + "transformers@git+https://github.com/huggingface/transformers.git@fix-gpt2-is-causal", ], "exporters": [ "onnx", "onnxruntime", "timm", - "transformers>=4.36,<4.49.0", + # "transformers>=4.36,<4.49.0", + "transformers@git+https://github.com/huggingface/transformers.git@fix-gpt2-is-causal", ], "exporters-gpu": [ "onnx", "onnxruntime-gpu", "timm", - "transformers>=4.36,<4.48.0", + # "transformers>=4.36,<4.49.0", + "transformers@git+https://github.com/huggingface/transformers.git@fix-gpt2-is-causal", ], "exporters-tf": [ "tensorflow>=2.4,<=2.12.1", From 96e2714184e860c1c55f450f6d9169a17fc7da49 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Jan 2025 14:09:26 +0100 Subject: [PATCH 11/26] fix modernbert failures --- optimum/exporters/tasks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 3793a56068..0db04a82bf 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -897,9 +897,9 @@ class TasksManager: "feature-extraction", "fill-mask", "text-classification", - "multiple-choice", + # "multiple-choice", "token-classification", - "question-answering", + # "question-answering", onnx="ModernBertOnnxConfig", ), "mpnet": supported_tasks_mapping( From 78a2dba71cf787684f6aa15c4b7df05b9939180b Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Jan 2025 14:14:23 +0100 Subject: [PATCH 12/26] style --- optimum/exporters/onnx/model_patcher.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 8415586dd6..a608e947d5 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ 
b/optimum/exporters/onnx/model_patcher.py @@ -20,15 +20,10 @@ import types from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +import torch import transformers -import transformers.cache_utils from packaging import version from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet -from transformers.utils import is_torch_available - - -if is_torch_available(): - import torch from ...configuration_utils import _transformers_version from ...utils import logging From 967c6e2c1e5d0ea7b7b66f880b7ce8d033b91239 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 20 Jan 2025 14:31:14 +0100 Subject: [PATCH 13/26] traceable cache --- optimum/exporters/onnx/_tensor_cache.py | 95 ------------------------- optimum/exporters/onnx/model_patcher.py | 91 ++++++++++++++++++++++- 2 files changed, 89 insertions(+), 97 deletions(-) delete mode 100644 optimum/exporters/onnx/_tensor_cache.py diff --git a/optimum/exporters/onnx/_tensor_cache.py b/optimum/exporters/onnx/_tensor_cache.py deleted file mode 100644 index 154b186ef4..0000000000 --- a/optimum/exporters/onnx/_tensor_cache.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import Any, Dict, Optional, Tuple - -import torch -from transformers import logging - - -logger = logging.get_logger(__name__) - - -# The same as transformers.cache_utils.Cache but iherits from torch.Tensor instead of torch.nn.Module -class Cache(torch.Tensor): - """ - Base, abstract class for all caches. The actual data structure is specific to each subclass. - """ - - def __init__(self): - super().__init__() - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. These are specific to each subclass and allow new types of - cache to be created. - - Return: - A tuple containing the updated key and value states. - """ - raise NotImplementedError("Make sure to implement `update` in a subclass.") - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") - - # Deprecate in favor of max-cache-shape because we want to be specifc by what we mean with "max_length" - # Prev some cache objects didn't have "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles - # infinite amount of tokens. In the codebase what we really need to check is the max capacity of certain cache instances, so - # we change naming to be more explicit - def get_max_length(self) -> Optional[int]: - logger.warning_once( - "`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. " - "Calling `get_max_cache()` will raise error from v4.48" - ) - return self.get_max_cache_shape() - - def get_max_cache_shape(self) -> Optional[int]: - """Returns the maximum sequence length (i.e. 
max capacity) of the cache object""" - raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") - - def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: - """Given the sequence length of the new inputs, returns the usable length of the cache.""" - # Cache without size limit -> all cache is usable - # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache - # length, we will need to evict part of the cache (and thus not all cache is usable) - max_length = self.get_max_cache_shape() - previous_seq_length = self.get_seq_length(layer_idx) - if max_length is not None and previous_seq_length + new_seq_length > max_length: - return max_length - new_seq_length - return previous_seq_length - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx] != []: - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - if self.value_cache[layer_idx] != []: - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - @property - def seen_tokens(self): - logger.warning_once( - "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " - "model input instead." - ) - if hasattr(self, "_seen_tokens"): - return self._seen_tokens - else: - return None diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index a608e947d5..29d0561f99 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -27,7 +27,6 @@ from ...configuration_utils import _transformers_version from ...utils import logging -from ._tensor_cache import Cache as PatchedCache if _transformers_version > version.parse("4.34.99"): @@ -155,8 +154,96 @@ def onnx_compatible_unfold(input_tensor, dimension, size, step): return result +# The same as transformers.cache_utils.Cache but iherits from torch.Tensor instead of torch.nn.Module +class TraceableCache(torch.Tensor): + """ + Base, abstract class for all caches. The actual data structure is specific to each subclass. + """ + + def __init__(self): + super().__init__() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache to be created. + + Return: + A tuple containing the updated key and value states. + """ + raise NotImplementedError("Make sure to implement `update` in a subclass.") + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + + # Deprecate in favor of max-cache-shape because we want to be specifc by what we mean with "max_length" + # Prev some cache objects didn't have "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles + # infinite amount of tokens. In the codebase what we really need to check is the max capacity of certain cache instances, so + # we change naming to be more explicit + def get_max_length(self) -> Optional[int]: + logger.warning_once( + "`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. " + "Calling `get_max_cache()` will raise error from v4.48" + ) + return self.get_max_cache_shape() + + def get_max_cache_shape(self) -> Optional[int]: + """Returns the maximum sequence length (i.e. max capacity) of the cache object""" + raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") + + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_cache_shape() + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx] != []: + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + if self.value_cache[layer_idx] != []: + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + @property + def seen_tokens(self): + logger.warning_once( + "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " + "model input instead." 
+ ) + if hasattr(self, "_seen_tokens"): + return self._seen_tokens + else: + return None + + UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)] -CACHE_PATCHING_SPEC = [PatchingSpec(transformers.cache_utils, "Cache", PatchedCache, transformers.cache_utils.Cache)] +CACHE_PATCHING_SPEC = [PatchingSpec(transformers.cache_utils, "Cache", TraceableCache, transformers.cache_utils.Cache)] class ModelPatcher: From 1d743882060750ab7b6d0e6a6195593c417ecb49 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Jan 2025 12:54:40 +0100 Subject: [PATCH 14/26] use pkv index --- optimum/exporters/onnx/model_patcher.py | 54 +++++++++++++------------ 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 29d0561f99..db78d01f0e 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -39,6 +39,9 @@ if _transformers_version >= version.parse("4.42"): from transformers.cache_utils import SlidingWindowCache, StaticCache +if _transformers_version >= version.parse("4.48"): + from transformers.cache_utils import DynamicCache, EncoderDecoderCache + if TYPE_CHECKING: from transformers import PreTrainedModel, TFPreTrainedModel @@ -154,8 +157,8 @@ def onnx_compatible_unfold(input_tensor, dimension, size, step): return result -# The same as transformers.cache_utils.Cache but iherits from torch.Tensor instead of torch.nn.Module -class TraceableCache(torch.Tensor): +# removing the nn.Module, same as in https://github.com/huggingface/transformers/pull/35873 +class TraceableCache: """ Base, abstract class for all caches. The actual data structure is specific to each subclass. 
""" @@ -284,11 +287,25 @@ def patched_forward(*args, **kwargs): signature = inspect.signature(self.orig_forward) args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs) - if _transformers_version >= version.parse("4.48"): - from transformers.cache_utils import DynamicCache, EncoderDecoderCache - - if isinstance(kwargs.get("past_key_values"), (list, tuple)) and isinstance( - kwargs["past_key_values"][0], (list, tuple) + if "past_key_values" in signature.parameters: + pkv_index = list(signature.parameters.keys()).index("past_key_values") + if ( + pkv_index < len(args) # pkv is in args + and isinstance(args[pkv_index], (list, tuple)) + and isinstance(args[pkv_index][0], (list, tuple)) + ): + if len(args[pkv_index][0]) == 2: + args[pkv_index] = DynamicCache.from_legacy_cache(args[pkv_index]) + elif len(args[pkv_index][0]) == 4: + args[pkv_index] = EncoderDecoderCache.from_legacy_cache(args[pkv_index]) + else: + raise ValueError( + f"past_key_values should have either 2 or 4 elements, but it has {len(args[pkv_index][0])} elements" + ) + elif ( + "past_key_values" in kwargs # pkv is in kwargs + and isinstance(kwargs["past_key_values"], (list, tuple)) + and isinstance(kwargs["past_key_values"][0], (list, tuple)) ): if len(kwargs["past_key_values"][0]) == 2: kwargs["past_key_values"] = DynamicCache.from_legacy_cache(kwargs["past_key_values"]) @@ -299,27 +316,8 @@ def patched_forward(*args, **kwargs): f"past_key_values should have either 2 or 4 elements, but it has {len(kwargs['past_key_values'][0])} elements" ) - elif any(isinstance(arg, (list, tuple)) for arg in args): - for i in range(len(args)): - if isinstance(args[i], (list, tuple)) and isinstance(args[i][0], (list, tuple)): - if len(args[i][0]) == 2: - args[i] = DynamicCache.from_legacy_cache(args[i]) - elif len(args[i][0]) == 4: - args[i] = EncoderDecoderCache.from_legacy_cache(args[i]) - else: - raise ValueError( - f"past_key_values should have either 2 or 4 elements, but it has {len(args[i])} elements" - ) - break - outputs = self.orig_forward(*args, **kwargs) - if _transformers_version >= version.parse("4.48"): - if "past_key_values" in outputs and isinstance( - outputs["past_key_values"], (DynamicCache, EncoderDecoderCache) - ): - outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() - # This code block handles different cases of the filterd_outputs input to align it with the expected # format of outputs. It is common for the output type of a model to vary, such as tensor, list, # tuple, etc. 
For Transformers models, the output is encapsulated in a ModelOutput object that @@ -352,6 +350,10 @@ def patched_forward(*args, **kwargs): name = list(config.outputs.keys())[0] filterd_outputs[name] = outputs + if _transformers_version >= version.parse("4.48"): + if isinstance(filterd_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)): + filterd_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() + return filterd_outputs self.patched_forward = patched_forward From d452c46402a3c463b7012624ed364681d261f0a1 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Jan 2025 16:02:17 +0100 Subject: [PATCH 15/26] add version gard and clean up other model patcher version gards --- optimum/exporters/onnx/model_patcher.py | 161 ++++++++++-------------- 1 file changed, 67 insertions(+), 94 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index db78d01f0e..9724057202 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -22,26 +22,21 @@ import torch import transformers -from packaging import version from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet -from ...configuration_utils import _transformers_version -from ...utils import logging +from ...utils import is_transformers_version, logging -if _transformers_version > version.parse("4.34.99"): +if is_transformers_version(">=", "4.35"): from transformers.modeling_attn_mask_utils import AttentionMaskConverter -if _transformers_version >= version.parse("4.36"): +if is_transformers_version(">=", "4.36"): from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa -else: - _prepare_4d_causal_attention_mask_for_sdpa = None - AttentionMaskConverter = None - -if _transformers_version >= version.parse("4.42"): +if is_transformers_version(">=", "4.42"): from transformers.cache_utils import SlidingWindowCache, StaticCache -if _transformers_version >= version.parse("4.48"): +if is_transformers_version(">=", "4.48"): from transformers.cache_utils import DynamicCache, EncoderDecoderCache - +if is_transformers_version(">=", "4.43"): + from transformers.models.clip.modeling_clip import CLIPAttention, CLIPSdpaAttention if TYPE_CHECKING: from transformers import PreTrainedModel, TFPreTrainedModel @@ -287,34 +282,38 @@ def patched_forward(*args, **kwargs): signature = inspect.signature(self.orig_forward) args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs) - if "past_key_values" in signature.parameters: - pkv_index = list(signature.parameters.keys()).index("past_key_values") - if ( - pkv_index < len(args) # pkv is in args - and isinstance(args[pkv_index], (list, tuple)) - and isinstance(args[pkv_index][0], (list, tuple)) - ): - if len(args[pkv_index][0]) == 2: - args[pkv_index] = DynamicCache.from_legacy_cache(args[pkv_index]) - elif len(args[pkv_index][0]) == 4: - args[pkv_index] = EncoderDecoderCache.from_legacy_cache(args[pkv_index]) - else: - raise ValueError( - f"past_key_values should have either 2 or 4 elements, but it has {len(args[pkv_index][0])} elements" - ) - elif ( - "past_key_values" in kwargs # pkv is in kwargs - and isinstance(kwargs["past_key_values"], (list, tuple)) - and isinstance(kwargs["past_key_values"][0], (list, tuple)) - ): - if len(kwargs["past_key_values"][0]) == 2: - kwargs["past_key_values"] = DynamicCache.from_legacy_cache(kwargs["past_key_values"]) - elif 
len(kwargs["past_key_values"][0]) == 4: - kwargs["past_key_values"] = EncoderDecoderCache.from_legacy_cache(kwargs["past_key_values"]) - else: - raise ValueError( - f"past_key_values should have either 2 or 4 elements, but it has {len(kwargs['past_key_values'][0])} elements" - ) + if is_transformers_version(">=", "4.48"): + if "past_key_values" in signature.parameters: + pkv_index = list(signature.parameters.keys()).index("past_key_values") + + if ( + pkv_index < len(args) # pkv is in args + and isinstance(args[pkv_index], (list, tuple)) + and isinstance(args[pkv_index][0], (list, tuple)) + ): + if len(args[pkv_index][0]) == 2: + args[pkv_index] = DynamicCache.from_legacy_cache(args[pkv_index]) + elif len(args[pkv_index][0]) == 4: + args[pkv_index] = EncoderDecoderCache.from_legacy_cache(args[pkv_index]) + else: + raise ValueError( + f"past_key_values should have either 2 or 4 elements, but it has {len(args[pkv_index][0])} elements" + ) + elif ( + "past_key_values" in kwargs # pkv is in kwargs + and isinstance(kwargs["past_key_values"], (list, tuple)) + and isinstance(kwargs["past_key_values"][0], (list, tuple)) + ): + if len(kwargs["past_key_values"][0]) == 2: + kwargs["past_key_values"] = DynamicCache.from_legacy_cache(kwargs["past_key_values"]) + elif len(kwargs["past_key_values"][0]) == 4: + kwargs["past_key_values"] = EncoderDecoderCache.from_legacy_cache( + kwargs["past_key_values"] + ) + else: + raise ValueError( + f"past_key_values should have either 2 or 4 elements, but it has {len(kwargs['past_key_values'][0])} elements" + ) outputs = self.orig_forward(*args, **kwargs) @@ -350,7 +349,7 @@ def patched_forward(*args, **kwargs): name = list(config.outputs.keys())[0] filterd_outputs[name] = outputs - if _transformers_version >= version.parse("4.48"): + if is_transformers_version(">=", "4.48"): if isinstance(filterd_outputs.get("past_key_values"), (DynamicCache, EncoderDecoderCache)): filterd_outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache() @@ -492,7 +491,7 @@ def _make_causal_mask_patched( _make_causal_mask_patched_staticmethod = staticmethod(_make_causal_mask_patched) -if _transformers_version >= version.parse("4.39.0"): +if is_transformers_version(">=", "4.39"): _unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched) else: _unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched_legacy) @@ -536,28 +535,20 @@ def _prepare_4d_causal_attention_mask_for_sdpa_patched( class DecoderModelPatcher(ModelPatcher): def __enter__(self): super().__enter__() - if AttentionMaskConverter is not None: - # TODO: Remove this _make_causal_mask patch if once transformers if much above 4.35 - AttentionMaskConverter._make_causal_mask = _make_causal_mask_patched_staticmethod - - if _transformers_version >= version.parse("4.36"): - AttentionMaskConverter._unmask_unattended = _unmask_unattended_patched_staticmethod + if is_transformers_version(">=", "4.36"): + AttentionMaskConverter._unmask_unattended = _unmask_unattended_patched_staticmethod - if _transformers_version >= version.parse("4.36"): + if is_transformers_version(">=", "4.36"): patch_everywhere( "_prepare_4d_causal_attention_mask_for_sdpa", _prepare_4d_causal_attention_mask_for_sdpa_patched ) def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - if AttentionMaskConverter is not None: - # TODO: Remove this _make_causal_mask patch if once transformers if much above 4.35 - AttentionMaskConverter._make_causal_mask = 
staticmethod(self.original_make_causal) - - if _transformers_version >= version.parse("4.36"): - AttentionMaskConverter._unmask_unattended = staticmethod(self.original_unmask_unattended) + if is_transformers_version(">=", "4.36"): + AttentionMaskConverter._unmask_unattended = staticmethod(self.original_unmask_unattended) - if _transformers_version >= version.parse("4.36"): + if is_transformers_version(">=", "4.36"): patch_everywhere( "_prepare_4d_causal_attention_mask_for_sdpa", self.original_prepare_4d_causal_attention_mask_for_sdpa ) @@ -570,14 +561,10 @@ def __init__( ): super().__init__(config, model, model_kwargs) - if _transformers_version >= version.parse("4.36"): + if is_transformers_version(">=", "4.36"): self.original_prepare_4d_causal_attention_mask_for_sdpa = _prepare_4d_causal_attention_mask_for_sdpa self.original_unmask_unattended = AttentionMaskConverter._unmask_unattended - # TODO: Remove this if once transformers if much above 4.35 - if AttentionMaskConverter is not None: - self.original_make_causal = AttentionMaskConverter._make_causal_mask - def falcon_build_alibi_tensor_patched( attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype @@ -958,8 +945,8 @@ class SentenceTransformersTransformerPatcher(ModelPatcher): def __enter__(self): super().__enter__() if ( - _transformers_version >= version.parse("4.42") - and _transformers_version < version.parse("4.48") + is_transformers_version(">=", "4.42") + and is_transformers_version("<", "4.48") and self.real_config._config.model_type == "mistral" ): self._model[0].auto_model._update_causal_mask = types.MethodType( @@ -969,8 +956,8 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if ( - _transformers_version >= version.parse("4.42") - and _transformers_version < version.parse("4.48") + is_transformers_version(">=", "4.42") + and is_transformers_version("<", "4.48") and self.real_config._config.model_type == "mistral" ): self._model[0].auto_model._update_causal_mask = types.MethodType( @@ -985,7 +972,11 @@ def __init__( ): super().__init__(config, model, model_kwargs) - if _transformers_version >= version.parse("4.42") and self.real_config._config.model_type == "mistral": + if ( + is_transformers_version(">=", "4.42") + and is_transformers_version("<", "4.48") + and self.real_config._config.model_type == "mistral" + ): self._update_causal_mask_original = self._model[0].auto_model._update_causal_mask def patched_forward(input_ids, attention_mask): @@ -1281,19 +1272,14 @@ def _update_causal_mask_patched( class MistralModelPatcher(ModelPatcher): def __enter__(self): super().__enter__() - if AttentionMaskConverter is not None: - # TODO: Remove this _make_causal_mask patch if once transformers if much above 4.35 - AttentionMaskConverter._make_causal_mask = _make_causal_mask_patched_staticmethod - - if _transformers_version >= version.parse("4.36"): - AttentionMaskConverter._unmask_unattended = _unmask_unattended_patched_staticmethod - if _transformers_version >= version.parse("4.36"): + if is_transformers_version(">=", "4.36"): + AttentionMaskConverter._unmask_unattended = _unmask_unattended_patched_staticmethod patch_everywhere( "_prepare_4d_causal_attention_mask_for_sdpa", _prepare_4d_causal_attention_mask_for_sdpa_patched ) - if _transformers_version >= version.parse("4.42") and _transformers_version < version.parse("4.48"): + if is_transformers_version(">=", "4.42") and is_transformers_version("<", "4.48"): if hasattr(self._model, "model"): 
self._model.model._update_causal_mask = types.MethodType( _update_causal_mask_patched, self._model.model @@ -1303,19 +1289,14 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - if AttentionMaskConverter is not None: - # TODO: Remove this _make_causal_mask patch if once transformers if much above 4.35 - AttentionMaskConverter._make_causal_mask = staticmethod(self.original_make_causal) - - if _transformers_version >= version.parse("4.36"): - AttentionMaskConverter._unmask_unattended = staticmethod(self.original_unmask_unattended) - if _transformers_version >= version.parse("4.36"): + if is_transformers_version(">=", "4.36"): + AttentionMaskConverter._unmask_unattended = staticmethod(self.original_unmask_unattended) patch_everywhere( "_prepare_4d_causal_attention_mask_for_sdpa", self.original_prepare_4d_causal_attention_mask_for_sdpa ) - if _transformers_version >= version.parse("4.42") and _transformers_version < version.parse("4.48"): + if is_transformers_version(">=", "4.42") and is_transformers_version("<", "4.48"): if hasattr(self._model, "model"): self._model.model._update_causal_mask = types.MethodType( self._update_causal_mask_original, self._model.model @@ -1331,15 +1312,11 @@ def __init__( ): super().__init__(config, model, model_kwargs) - if _transformers_version >= version.parse("4.36"): + if is_transformers_version(">=", "4.36"): self.original_prepare_4d_causal_attention_mask_for_sdpa = _prepare_4d_causal_attention_mask_for_sdpa self.original_unmask_unattended = AttentionMaskConverter._unmask_unattended - # TODO: Remove this if once transformers if much above 4.35 - if AttentionMaskConverter is not None: - self.original_make_causal = AttentionMaskConverter._make_causal_mask - - if _transformers_version >= version.parse("4.42"): + if is_transformers_version(">=", "4.42") and is_transformers_version("<", "4.48"): if hasattr(self._model, "model"): self._update_causal_mask_original = self._model.model._update_causal_mask else: @@ -1350,14 +1327,10 @@ class CLIPModelPatcher(ModelPatcher): def __enter__(self): super().__enter__() - if _transformers_version >= version.parse("4.43"): - from transformers.models.clip.modeling_clip import CLIPAttention, CLIPSdpaAttention - + if is_transformers_version(">=", "4.43"): self.original_sdpa_forward, CLIPSdpaAttention.forward = CLIPSdpaAttention.forward, CLIPAttention.forward def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - if _transformers_version >= version.parse("4.43"): - from transformers.models.clip.modeling_clip import CLIPSdpaAttention - + if is_transformers_version(">=", "4.43"): CLIPSdpaAttention.forward = self.original_sdpa_forward From 5dcab7f1ba003c88f703a03693ff1fa9c4430cf4 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Jan 2025 16:47:24 +0100 Subject: [PATCH 16/26] patch sdpa attention in optimum for now --- optimum/exporters/onnx/model_patcher.py | 65 ++++++++++++++++++++++--- setup.py | 12 ++--- 2 files changed, 63 insertions(+), 14 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 9724057202..90615e71bc 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -31,12 +31,15 @@ from transformers.modeling_attn_mask_utils import AttentionMaskConverter if is_transformers_version(">=", "4.36"): from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa 
+if is_transformers_version(">=", "4.43"): + from transformers.models.clip.modeling_clip import CLIPAttention, CLIPSdpaAttention if is_transformers_version(">=", "4.42"): from transformers.cache_utils import SlidingWindowCache, StaticCache if is_transformers_version(">=", "4.48"): + import transformers.integrations.sdpa_attention from transformers.cache_utils import DynamicCache, EncoderDecoderCache -if is_transformers_version(">=", "4.43"): - from transformers.models.clip.modeling_clip import CLIPAttention, CLIPSdpaAttention + from transformers.integrations.sdpa_attention import repeat_kv + if TYPE_CHECKING: from transformers import PreTrainedModel, TFPreTrainedModel @@ -532,27 +535,74 @@ def _prepare_4d_causal_attention_mask_for_sdpa_patched( return attention_mask +def patched_sdpa_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + is_causal: Optional[bool] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + if hasattr(module, "num_key_value_groups"): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions + # Reference: https://github.com/pytorch/pytorch/issues/112577. + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + if is_causal is None: + is_causal = causal_mask is None and query.shape[2] > 1 + + # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. + # We convert it to a bool for the SDPA kernel that only accepts bools. 
+ if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor): + is_causal = is_causal.item() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=dropout, + scale=scaling, + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, None + + class DecoderModelPatcher(ModelPatcher): def __enter__(self): super().__enter__() if is_transformers_version(">=", "4.36"): AttentionMaskConverter._unmask_unattended = _unmask_unattended_patched_staticmethod - - if is_transformers_version(">=", "4.36"): patch_everywhere( "_prepare_4d_causal_attention_mask_for_sdpa", _prepare_4d_causal_attention_mask_for_sdpa_patched ) + if is_transformers_version(">=", "4.48"): + transformers.integrations.sdpa_attention.sdpa_attention_forward = patched_sdpa_attention_forward + def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) if is_transformers_version(">=", "4.36"): AttentionMaskConverter._unmask_unattended = staticmethod(self.original_unmask_unattended) - - if is_transformers_version(">=", "4.36"): patch_everywhere( "_prepare_4d_causal_attention_mask_for_sdpa", self.original_prepare_4d_causal_attention_mask_for_sdpa ) + if is_transformers_version(">=", "4.48"): + transformers.integrations.sdpa_attention.sdpa_attention_forward = self.original_sdpa_attention_forward + def __init__( self, config: "OnnxConfig", @@ -565,6 +615,9 @@ def __init__( self.original_prepare_4d_causal_attention_mask_for_sdpa = _prepare_4d_causal_attention_mask_for_sdpa self.original_unmask_unattended = AttentionMaskConverter._unmask_unattended + if is_transformers_version(">=", "4.48"): + self.original_sdpa_attention_forward = transformers.integrations.sdpa_attention.sdpa_attention_forward + def falcon_build_alibi_tensor_patched( attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype diff --git a/setup.py b/setup.py index 3942ac365f..4c4d9c43a0 100644 --- a/setup.py +++ b/setup.py @@ -50,8 +50,7 @@ "datasets>=1.2.1", "evaluate", "protobuf>=3.20.1", - # "transformers>=4.36,<4.49.0", - "transformers@git+https://github.com/huggingface/transformers.git@fix-gpt2-is-causal", + "transformers>=4.36,<4.49.0", ], "onnxruntime-gpu": [ "onnx", @@ -60,22 +59,19 @@ "evaluate", "protobuf>=3.20.1", "accelerate", # ORTTrainer requires it. 
- # "transformers>=4.36,<4.49.0", - "transformers@git+https://github.com/huggingface/transformers.git@fix-gpt2-is-causal", + "transformers>=4.36,<4.49.0", ], "exporters": [ "onnx", "onnxruntime", "timm", - # "transformers>=4.36,<4.49.0", - "transformers@git+https://github.com/huggingface/transformers.git@fix-gpt2-is-causal", + "transformers>=4.36,<4.49.0", ], "exporters-gpu": [ "onnx", "onnxruntime-gpu", "timm", - # "transformers>=4.36,<4.49.0", - "transformers@git+https://github.com/huggingface/transformers.git@fix-gpt2-is-causal", + "transformers>=4.36,<4.49.0", ], "exporters-tf": [ "tensorflow>=2.4,<=2.12.1", From 656941a47f7deb8e1fbdbafffe0ea05b4b3aa3a8 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Jan 2025 16:48:58 +0100 Subject: [PATCH 17/26] remove modernbert condition --- tests/exporters/exporters_utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index f43b70201c..3aefc6a061 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -128,11 +128,7 @@ "mobilenet-v2": "hf-internal-testing/tiny-random-MobileNetV2Model", "mobilenet-v1": "google/mobilenet_v1_0.75_192", "mobilevit": "hf-internal-testing/tiny-random-mobilevit", - **( - {"modernbert": "hf-internal-testing/tiny-random-ModernBertForMaskedLM"} - if is_transformers_version(">=", "4.48") - else {} - ), + "modernbert": "hf-internal-testing/tiny-random-ModernBertForMaskedLM", "mpnet": "hf-internal-testing/tiny-random-MPNetModel", "mpt": "hf-internal-testing/tiny-random-MptForCausalLM", "mt5": "lewtun/tiny-random-mt5", @@ -276,7 +272,7 @@ # "mobilenet_v1": "google/mobilenet_v1_0.75_192", # "mobilenet_v2": "google/mobilenet_v2_0.35_96", "mobilevit": "apple/mobilevit-small", - **({"modernbert": "answerdotai/ModernBERT-base"} if is_transformers_version(">=", "4.48") else {}), + "modernbert": "answerdotai/ModernBERT-base", "mpt": "mosaicml/mpt-7b", "mt5": "lewtun/tiny-random-mt5", # Not using google/mt5-small because it takes too much time for testing. "musicgen": "facebook/musicgen-small", From 1bcb38f3de31c5d541be4402a45d3c15ae0684d3 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Jan 2025 16:52:38 +0100 Subject: [PATCH 18/26] style --- tests/exporters/exporters_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index 3aefc6a061..2b9bca7a73 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from optimum.utils import is_transformers_version - VALIDATE_EXPORT_ON_SHAPES_SLOW = { "batch_size": [1, 3, 5], From 23fa20ebbe59e7ab7d8dabaf87548ac965484ff4 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Jan 2025 17:17:44 +0100 Subject: [PATCH 19/26] fix MistralModelPatcher --- optimum/exporters/onnx/model_patcher.py | 26 +++---------------------- 1 file changed, 3 insertions(+), 23 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 90615e71bc..07b994aaf0 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -1025,11 +1025,7 @@ def __init__( ): super().__init__(config, model, model_kwargs) - if ( - is_transformers_version(">=", "4.42") - and is_transformers_version("<", "4.48") - and self.real_config._config.model_type == "mistral" - ): + if is_transformers_version(">=", "4.42") and self.real_config._config.model_type == "mistral": self._update_causal_mask_original = self._model[0].auto_model._update_causal_mask def patched_forward(input_ids, attention_mask): @@ -1322,16 +1318,10 @@ def _update_causal_mask_patched( return causal_mask -class MistralModelPatcher(ModelPatcher): +class MistralModelPatcher(DecoderModelPatcher): def __enter__(self): super().__enter__() - if is_transformers_version(">=", "4.36"): - AttentionMaskConverter._unmask_unattended = _unmask_unattended_patched_staticmethod - patch_everywhere( - "_prepare_4d_causal_attention_mask_for_sdpa", _prepare_4d_causal_attention_mask_for_sdpa_patched - ) - if is_transformers_version(">=", "4.42") and is_transformers_version("<", "4.48"): if hasattr(self._model, "model"): self._model.model._update_causal_mask = types.MethodType( @@ -1343,12 +1333,6 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) - if is_transformers_version(">=", "4.36"): - AttentionMaskConverter._unmask_unattended = staticmethod(self.original_unmask_unattended) - patch_everywhere( - "_prepare_4d_causal_attention_mask_for_sdpa", self.original_prepare_4d_causal_attention_mask_for_sdpa - ) - if is_transformers_version(">=", "4.42") and is_transformers_version("<", "4.48"): if hasattr(self._model, "model"): self._model.model._update_causal_mask = types.MethodType( @@ -1365,11 +1349,7 @@ def __init__( ): super().__init__(config, model, model_kwargs) - if is_transformers_version(">=", "4.36"): - self.original_prepare_4d_causal_attention_mask_for_sdpa = _prepare_4d_causal_attention_mask_for_sdpa - self.original_unmask_unattended = AttentionMaskConverter._unmask_unattended - - if is_transformers_version(">=", "4.42") and is_transformers_version("<", "4.48"): + if is_transformers_version(">=", "4.42"): if hasattr(self._model, "model"): self._update_causal_mask_original = self._model.model._update_causal_mask else: From 24c8f4b883139ded4b114588ce163e05d28ca655 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Jan 2025 18:38:23 +0100 Subject: [PATCH 20/26] correctly patch gpt2 in vision encoder decoder --- optimum/exporters/onnx/model_patcher.py | 160 +++++++++++++----------- 1 file changed, 84 insertions(+), 76 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 07b994aaf0..9640497042 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -36,9 +36,9 @@ if is_transformers_version(">=", "4.42"): from transformers.cache_utils import SlidingWindowCache, 
StaticCache if is_transformers_version(">=", "4.48"): - import transformers.integrations.sdpa_attention from transformers.cache_utils import DynamicCache, EncoderDecoderCache - from transformers.integrations.sdpa_attention import repeat_kv + from transformers.integrations.sdpa_attention import repeat_kv, sdpa_attention_forward + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS if TYPE_CHECKING: @@ -436,7 +436,62 @@ def patched_forward(*args, **kwargs): self.patched_forward = patched_forward +def patched_sdpa_attention_forward( + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + is_causal: Optional[bool] = None, + **kwargs, +) -> Tuple[torch.Tensor, None]: + if hasattr(module, "num_key_value_groups"): + key = repeat_kv(key, module.num_key_value_groups) + value = repeat_kv(value, module.num_key_value_groups) + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key.shape[-2]] + # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions + # Reference: https://github.com/pytorch/pytorch/issues/112577. + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + if is_causal is None: + is_causal = causal_mask is None and query.shape[2] > 1 + + # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. + # We convert it to a bool for the SDPA kernel that only accepts bools. 
+ if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor): + is_causal = is_causal.item() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=causal_mask, + dropout_p=dropout, + scale=scaling, + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, None + + class VisionEncoderDecoderPatcher(Seq2SeqModelPatcher): + def __enter__(self): + super().__enter__() + if is_transformers_version(">=", "4.48"): + ALL_ATTENTION_FUNCTIONS["sdpa"] = patched_sdpa_attention_forward + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + if is_transformers_version(">=", "4.48"): + ALL_ATTENTION_FUNCTIONS["sdpa"] = sdpa_attention_forward + def __init__( self, config: "OnnxConfig", @@ -450,14 +505,16 @@ def __init__( model.decoder.model.decoder.config.use_cache = True -def _unmask_unattended_patched_legacy( - expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float] -): - return expanded_mask +if is_transformers_version(">=", "4.39"): + def _unmask_unattended_patched(expanded_mask: torch.Tensor, min_dtype: float): + return expanded_mask +else: -def _unmask_unattended_patched(expanded_mask: torch.Tensor, min_dtype: float): - return expanded_mask + def _unmask_unattended_patched( + expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float] + ): + return expanded_mask def _make_causal_mask_patched( @@ -492,14 +549,6 @@ def _make_causal_mask_patched( return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) -_make_causal_mask_patched_staticmethod = staticmethod(_make_causal_mask_patched) - -if is_transformers_version(">=", "4.39"): - _unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched) -else: - _unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched_legacy) - - # Adapted from _prepare_4d_causal_attention_mask def _prepare_4d_causal_attention_mask_for_sdpa_patched( attention_mask: Optional[torch.Tensor], @@ -535,74 +584,29 @@ def _prepare_4d_causal_attention_mask_for_sdpa_patched( return attention_mask -def patched_sdpa_attention_forward( - module: torch.nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Optional[torch.Tensor], - dropout: float = 0.0, - scaling: Optional[float] = None, - is_causal: Optional[bool] = None, - **kwargs, -) -> Tuple[torch.Tensor, None]: - if hasattr(module, "num_key_value_groups"): - key = repeat_kv(key, module.num_key_value_groups) - value = repeat_kv(value, module.num_key_value_groups) - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] - # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions - # Reference: https://github.com/pytorch/pytorch/issues/112577. - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - if is_causal is None: - is_causal = causal_mask is None and query.shape[2] > 1 - - # Shapes (e.g. 
query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. - # We convert it to a bool for the SDPA kernel that only accepts bools. - if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor): - is_causal = is_causal.item() - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=causal_mask, - dropout_p=dropout, - scale=scaling, - is_causal=is_causal, - ) - attn_output = attn_output.transpose(1, 2).contiguous() - return attn_output, None - - class DecoderModelPatcher(ModelPatcher): def __enter__(self): super().__enter__() + if is_transformers_version(">=", "4.35"): + AttentionMaskConverter._make_causal_mask = staticmethod(_make_causal_mask_patched) + if is_transformers_version(">=", "4.36"): - AttentionMaskConverter._unmask_unattended = _unmask_unattended_patched_staticmethod + AttentionMaskConverter._unmask_unattended = staticmethod(_unmask_unattended_patched) patch_everywhere( "_prepare_4d_causal_attention_mask_for_sdpa", _prepare_4d_causal_attention_mask_for_sdpa_patched ) - if is_transformers_version(">=", "4.48"): - transformers.integrations.sdpa_attention.sdpa_attention_forward = patched_sdpa_attention_forward - def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) + if is_transformers_version(">=", "4.35"): + AttentionMaskConverter._make_causal_mask = staticmethod(self.original_make_causal_mask) + if is_transformers_version(">=", "4.36"): AttentionMaskConverter._unmask_unattended = staticmethod(self.original_unmask_unattended) patch_everywhere( "_prepare_4d_causal_attention_mask_for_sdpa", self.original_prepare_4d_causal_attention_mask_for_sdpa ) - if is_transformers_version(">=", "4.48"): - transformers.integrations.sdpa_attention.sdpa_attention_forward = self.original_sdpa_attention_forward - def __init__( self, config: "OnnxConfig", @@ -611,12 +615,12 @@ def __init__( ): super().__init__(config, model, model_kwargs) + if is_transformers_version(">=", "4.35"): + self.original_make_causal_mask = AttentionMaskConverter._make_causal_mask + if is_transformers_version(">=", "4.36"): - self.original_prepare_4d_causal_attention_mask_for_sdpa = _prepare_4d_causal_attention_mask_for_sdpa self.original_unmask_unattended = AttentionMaskConverter._unmask_unattended - - if is_transformers_version(">=", "4.48"): - self.original_sdpa_attention_forward = transformers.integrations.sdpa_attention.sdpa_attention_forward + self.original_prepare_4d_causal_attention_mask_for_sdpa = _prepare_4d_causal_attention_mask_for_sdpa def falcon_build_alibi_tensor_patched( @@ -1025,7 +1029,11 @@ def __init__( ): super().__init__(config, model, model_kwargs) - if is_transformers_version(">=", "4.42") and self.real_config._config.model_type == "mistral": + if ( + is_transformers_version(">=", "4.42") + and is_transformers_version("<", "4.48") + and self.real_config._config.model_type == "mistral" + ): self._update_causal_mask_original = self._model[0].auto_model._update_causal_mask def patched_forward(input_ids, attention_mask): @@ -1349,7 +1357,7 @@ def __init__( ): super().__init__(config, model, model_kwargs) - if is_transformers_version(">=", "4.42"): + if is_transformers_version(">=", "4.42") and is_transformers_version("<", "4.48"): if hasattr(self._model, "model"): self._update_causal_mask_original = self._model.model._update_causal_mask else: @@ -1359,9 +1367,9 @@ def __init__( class CLIPModelPatcher(ModelPatcher): def __enter__(self): super().__enter__() - if 
is_transformers_version(">=", "4.43"): - self.original_sdpa_forward, CLIPSdpaAttention.forward = CLIPSdpaAttention.forward, CLIPAttention.forward + self.original_sdpa_forward = CLIPSdpaAttention.forward + CLIPSdpaAttention.forward = CLIPAttention.forward def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) From 3694ea4ea26bb3bc1a9701ca054e7a021ec4a8ea Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Sun, 26 Jan 2025 14:01:59 +0100 Subject: [PATCH 21/26] patch sdpa attention forward everywhere --- optimum/exporters/onnx/model_patcher.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 9640497042..aa068c46af 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -374,10 +374,16 @@ def __enter__(self): self.patch_ops() setattr(self._model, self.orig_forward_name, self.patched_forward) + if is_transformers_version(">=", "4.48"): + ALL_ATTENTION_FUNCTIONS["sdpa"] = patched_sdpa_attention_forward + def __exit__(self, exc_type, exc_value, traceback): self.restore_ops() setattr(self._model, self.orig_forward_name, self.orig_forward) + if is_transformers_version(">=", "4.48"): + ALL_ATTENTION_FUNCTIONS["sdpa"] = sdpa_attention_forward + def __call__(self, *args, **kwargs): if getattr(self._model, self.orig_forward_name) is self.orig_forward: logger.warning("Running the non-patched model") @@ -482,16 +488,6 @@ def patched_sdpa_attention_forward( class VisionEncoderDecoderPatcher(Seq2SeqModelPatcher): - def __enter__(self): - super().__enter__() - if is_transformers_version(">=", "4.48"): - ALL_ATTENTION_FUNCTIONS["sdpa"] = patched_sdpa_attention_forward - - def __exit__(self, exc_type, exc_value, traceback): - super().__exit__(exc_type, exc_value, traceback) - if is_transformers_version(">=", "4.48"): - ALL_ATTENTION_FUNCTIONS["sdpa"] = sdpa_attention_forward - def __init__( self, config: "OnnxConfig", @@ -509,6 +505,7 @@ def __init__( def _unmask_unattended_patched(expanded_mask: torch.Tensor, min_dtype: float): return expanded_mask + else: def _unmask_unattended_patched( From 3d7d586957c22a6f99050f967f37d691525acef7 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Sun, 26 Jan 2025 14:12:29 +0100 Subject: [PATCH 22/26] fix gpt2 cross attention in seq2seq as well --- optimum/exporters/onnx/model_patcher.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index aa068c46af..0444fdfec0 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -374,16 +374,10 @@ def __enter__(self): self.patch_ops() setattr(self._model, self.orig_forward_name, self.patched_forward) - if is_transformers_version(">=", "4.48"): - ALL_ATTENTION_FUNCTIONS["sdpa"] = patched_sdpa_attention_forward - def __exit__(self, exc_type, exc_value, traceback): self.restore_ops() setattr(self._model, self.orig_forward_name, self.orig_forward) - if is_transformers_version(">=", "4.48"): - ALL_ATTENTION_FUNCTIONS["sdpa"] = sdpa_attention_forward - def __call__(self, *args, **kwargs): if getattr(self._model, self.orig_forward_name) is self.orig_forward: logger.warning("Running the non-patched model") @@ -391,6 +385,18 @@ def __call__(self, *args, **kwargs): class Seq2SeqModelPatcher(ModelPatcher): + def __enter__(self): + super().__enter__() + if 
is_transformers_version(">=", "4.48"): + # this is required when gpt2 is used as decoder in any + # encoder-decoder model with cross attention blocks + ALL_ATTENTION_FUNCTIONS["sdpa"] = patched_sdpa_attention_forward + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + if is_transformers_version(">=", "4.48"): + ALL_ATTENTION_FUNCTIONS["sdpa"] = sdpa_attention_forward + def __init__( self, config: "OnnxConfig", From 10833d8f5cca95f95cd8855a79b7d054d124cc1f Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 29 Jan 2025 10:48:53 +0100 Subject: [PATCH 23/26] moved traceable cache to a file for simplicity of model patcher --- optimum/exporters/onnx/_traceable_cache.py | 92 ++++++++++++++++++++++ optimum/exporters/onnx/model_patcher.py | 87 +------------------- 2 files changed, 93 insertions(+), 86 deletions(-) create mode 100644 optimum/exporters/onnx/_traceable_cache.py diff --git a/optimum/exporters/onnx/_traceable_cache.py b/optimum/exporters/onnx/_traceable_cache.py new file mode 100644 index 0000000000..50d95b935c --- /dev/null +++ b/optimum/exporters/onnx/_traceable_cache.py @@ -0,0 +1,92 @@ +from typing import Any, Dict, Optional, Tuple + +import torch +from transformers.cache_utils import logger + + +# Simply removing the nn.Module, same as in https://github.com/huggingface/transformers/pull/35873 +class TraceableCache: + """ + Base, abstract class for all caches. The actual data structure is specific to each subclass. + """ + + def __init__(self): + super().__init__() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache to be created. + + Return: + A tuple containing the updated key and value states. + """ + raise NotImplementedError("Make sure to implement `update` in a subclass.") + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` + raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + + # Deprecate in favor of max-cache-shape because we want to be specifc by what we mean with "max_length" + # Prev some cache objects didn't have "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles + # infinite amount of tokens. In the codebase what we really need to check is the max capacity of certain cache instances, so + # we change naming to be more explicit + def get_max_length(self) -> Optional[int]: + logger.warning_once( + "`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. " + "Calling `get_max_cache()` will raise error from v4.48" + ) + return self.get_max_cache_shape() + + def get_max_cache_shape(self) -> Optional[int]: + """Returns the maximum sequence length (i.e. 
max capacity) of the cache object""" + raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") + + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_cache_shape() + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx] != []: + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + if self.value_cache[layer_idx] != []: + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + @property + def seen_tokens(self): + logger.warning_once( + "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " + "model input instead." + ) + if hasattr(self, "_seen_tokens"): + return self._seen_tokens + else: + return None diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 0444fdfec0..16ee85d464 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -25,6 +25,7 @@ from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet from ...utils import is_transformers_version, logging +from ._traceable_cache import TraceableCache if is_transformers_version(">=", "4.35"): @@ -155,92 +156,6 @@ def onnx_compatible_unfold(input_tensor, dimension, size, step): return result -# removing the nn.Module, same as in https://github.com/huggingface/transformers/pull/35873 -class TraceableCache: - """ - Base, abstract class for all caches. The actual data structure is specific to each subclass. - """ - - def __init__(self): - super().__init__() - - def update( - self, - key_states: torch.Tensor, - value_states: torch.Tensor, - layer_idx: int, - cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. - - Parameters: - key_states (`torch.Tensor`): - The new key states to cache. - value_states (`torch.Tensor`): - The new value states to cache. - layer_idx (`int`): - The index of the layer to cache the states for. - cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. These are specific to each subclass and allow new types of - cache to be created. - - Return: - A tuple containing the updated key and value states. - """ - raise NotImplementedError("Make sure to implement `update` in a subclass.") - - def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" - # TODO: deprecate this function in favor of `cache_position` - raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") - - # Deprecate in favor of max-cache-shape because we want to be specifc by what we mean with "max_length" - # Prev some cache objects didn't have "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles - # infinite amount of tokens. In the codebase what we really need to check is the max capacity of certain cache instances, so - # we change naming to be more explicit - def get_max_length(self) -> Optional[int]: - logger.warning_once( - "`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. " - "Calling `get_max_cache()` will raise error from v4.48" - ) - return self.get_max_cache_shape() - - def get_max_cache_shape(self) -> Optional[int]: - """Returns the maximum sequence length (i.e. max capacity) of the cache object""" - raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") - - def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: - """Given the sequence length of the new inputs, returns the usable length of the cache.""" - # Cache without size limit -> all cache is usable - # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache - # length, we will need to evict part of the cache (and thus not all cache is usable) - max_length = self.get_max_cache_shape() - previous_seq_length = self.get_seq_length(layer_idx) - if max_length is not None and previous_seq_length + new_seq_length > max_length: - return max_length - new_seq_length - return previous_seq_length - - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - if self.key_cache[layer_idx] != []: - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - if self.value_cache[layer_idx] != []: - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - @property - def seen_tokens(self): - logger.warning_once( - "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` " - "model input instead." 
- ) - if hasattr(self, "_seen_tokens"): - return self._seen_tokens - else: - return None UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)] From 9491d17f16eb3f3c27e98dbca07aba47e34f48de Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Wed, 29 Jan 2025 10:49:50 +0100 Subject: [PATCH 24/26] Apply suggestions from code review --- optimum/exporters/tasks.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index 0db04a82bf..1a216ce6e8 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -897,9 +897,7 @@ class TasksManager: "feature-extraction", "fill-mask", "text-classification", - # "multiple-choice", "token-classification", - # "question-answering", onnx="ModernBertOnnxConfig", ), "mpnet": supported_tasks_mapping( From 2b731297c246296b4d775acf40a4efbcbffa74e8 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 29 Jan 2025 10:51:06 +0100 Subject: [PATCH 25/26] style --- optimum/exporters/onnx/model_patcher.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 16ee85d464..53476ff206 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -156,8 +156,6 @@ def onnx_compatible_unfold(input_tensor, dimension, size, step): return result - - UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)] CACHE_PATCHING_SPEC = [PatchingSpec(transformers.cache_utils, "Cache", TraceableCache, transformers.cache_utils.Cache)] From dea98a04f363b58a8fde2ac2380784f4749547c2 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 29 Jan 2025 10:56:58 +0100 Subject: [PATCH 26/26] fix --- optimum/exporters/onnx/_traceable_cache.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/optimum/exporters/onnx/_traceable_cache.py b/optimum/exporters/onnx/_traceable_cache.py index 50d95b935c..052cb04b12 100644 --- a/optimum/exporters/onnx/_traceable_cache.py +++ b/optimum/exporters/onnx/_traceable_cache.py @@ -1,7 +1,10 @@ +import logging from typing import Any, Dict, Optional, Tuple import torch -from transformers.cache_utils import logger + + +logger = logging.getLogger(__name__) # Simply removing the nn.Module, same as in https://github.com/huggingface/transformers/pull/35873
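The mechanism exercised by patches 16 through 26 is a temporary swap-and-restore: the SDPA forward registered under ALL_ATTENTION_FUNCTIONS["sdpa"] converts the traced `is_causal` tensor back to a Python bool, and CACHE_PATCHING_SPEC replaces `transformers.cache_utils.Cache` with the plain-Python `TraceableCache` for the duration of the export, restoring the original afterwards. A minimal sketch of that swap-and-restore pattern follows; `dummy_cache_utils`, `OriginalCache`, and `patch_attribute` are illustrative stand-ins, not the real transformers module or Optimum's PatchingSpec machinery.

import types
from contextlib import contextmanager


# Stand-ins for transformers.cache_utils and its Cache class (illustrative only).
dummy_cache_utils = types.SimpleNamespace()


class OriginalCache:
    """Plays the role of transformers.cache_utils.Cache (nn.Module-based upstream)."""


class TraceableCache:
    """Plays the role of optimum.exporters.onnx._traceable_cache.TraceableCache."""


dummy_cache_utils.Cache = OriginalCache


@contextmanager
def patch_attribute(obj, name, replacement):
    # Swap obj.<name> for `replacement` and restore the original on exit,
    # mirroring what a patching spec entry does around an export.
    original = getattr(obj, name)
    setattr(obj, name, replacement)
    try:
        yield
    finally:
        setattr(obj, name, original)


if __name__ == "__main__":
    assert dummy_cache_utils.Cache is OriginalCache
    with patch_attribute(dummy_cache_utils, "Cache", TraceableCache):
        # An export (e.g. torch.onnx.export) would run here and see the traceable class.
        assert dummy_cache_utils.Cache is TraceableCache
    assert dummy_cache_utils.Cache is OriginalCache  # restored once the export is done

The real ModelPatcher applies the same idea through its list of patching specs, installing them in __enter__ (patch_ops) and undoing them in __exit__ (restore_ops).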