Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transformers 4.48 #2158

Merged
merged 26 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
5190280
test
IlyasMoutawwakil Jan 16, 2025
6a03d76
testing tensor cache x)
IlyasMoutawwakil Jan 20, 2025
7207215
fix logger
IlyasMoutawwakil Jan 20, 2025
6261094
condition cache class usage
IlyasMoutawwakil Jan 20, 2025
822066d
update opset for beit and data2vec vision and skip flattened/fused pk…
IlyasMoutawwakil Jan 20, 2025
3ab38fd
style
IlyasMoutawwakil Jan 20, 2025
d713e5a
fix args patcher
IlyasMoutawwakil Jan 20, 2025
bf4d1f3
fix modernbert testing
IlyasMoutawwakil Jan 20, 2025
230c3a0
adaot to new whisper returned generation length
IlyasMoutawwakil Jan 20, 2025
3d5d9c9
fix is_causal in transformers
IlyasMoutawwakil Jan 20, 2025
96e2714
fix modernbert failures
IlyasMoutawwakil Jan 20, 2025
78a2dba
style
IlyasMoutawwakil Jan 20, 2025
967c6e2
traceable cache
IlyasMoutawwakil Jan 20, 2025
1d74388
use pkv index
IlyasMoutawwakil Jan 24, 2025
d452c46
add version gard and clean up other model patcher version gards
IlyasMoutawwakil Jan 24, 2025
5dcab7f
patch sdpa attention in optimum for now
IlyasMoutawwakil Jan 24, 2025
656941a
remove modernbert condition
IlyasMoutawwakil Jan 24, 2025
1bcb38f
style
IlyasMoutawwakil Jan 24, 2025
23fa20e
fix MistralModelPatcher
IlyasMoutawwakil Jan 24, 2025
24c8f4b
correctly patch gpt2 in vision encoder decoder
IlyasMoutawwakil Jan 24, 2025
3694ea4
patch sdpa attention forward everywhere
IlyasMoutawwakil Jan 26, 2025
3d7d586
fix gpt2 cross attention in seq2seq as well
IlyasMoutawwakil Jan 26, 2025
10833d8
moved traceable cache to a file for simplicity of model patcher
IlyasMoutawwakil Jan 29, 2025
9491d17
Apply suggestions from code review
IlyasMoutawwakil Jan 29, 2025
2b73129
style
IlyasMoutawwakil Jan 29, 2025
dea98a0
fix
IlyasMoutawwakil Jan 29, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions optimum/exporters/onnx/_tensor_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from typing import Any, Dict, Optional, Tuple

import torch
from transformers import logging


logger = logging.get_logger(__name__)


# The same as transformers.cache_utils.Cache but iherits from torch.Tensor instead of torch.nn.Module
class Cache(torch.Tensor):
"""
Base, abstract class for all caches. The actual data structure is specific to each subclass.
"""

def __init__(self):
super().__init__()

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
cache to be created.

Return:
A tuple containing the updated key and value states.
"""
raise NotImplementedError("Make sure to implement `update` in a subclass.")

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

# Deprecate in favor of max-cache-shape because we want to be specifc by what we mean with "max_length"
# Prev some cache objects didn't have "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles
# infinite amount of tokens. In the codebase what we really need to check is the max capacity of certain cache instances, so
# we change naming to be more explicit
def get_max_length(self) -> Optional[int]:
logger.warning_once(
"`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. "
"Calling `get_max_cache()` will raise error from v4.48"
)
return self.get_max_cache_shape()

def get_max_cache_shape(self) -> Optional[int]:
"""Returns the maximum sequence length (i.e. max capacity) of the cache object"""
raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.")

def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
# Cache without size limit -> all cache is usable
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
# length, we will need to evict part of the cache (and thus not all cache is usable)
max_length = self.get_max_cache_shape()
previous_seq_length = self.get_seq_length(layer_idx)
if max_length is not None and previous_seq_length + new_seq_length > max_length:
return max_length - new_seq_length
return previous_seq_length

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
if self.key_cache[layer_idx] != []:
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
if self.value_cache[layer_idx] != []:
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

@property
def seen_tokens(self):
logger.warning_once(
"The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
"model input instead."
)
if hasattr(self, "_seen_tokens"):
return self._seen_tokens
else:
return None
7 changes: 3 additions & 4 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ class DeiTOnnxConfig(ViTOnnxConfig):


class BeitOnnxConfig(ViTOnnxConfig):
DEFAULT_ONNX_OPSET = 11
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class ConvNextOnnxConfig(ViTOnnxConfig):
Expand Down Expand Up @@ -1573,13 +1573,12 @@ class Data2VecTextOnnxConfig(DistilBertOnnxConfig):


class Data2VecVisionOnnxConfig(ViTOnnxConfig):
DEFAULT_ONNX_OPSET = 11
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class Data2VecAudioOnnxConfig(AudioOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedConfig
ATOL_FOR_VALIDATION = 1e-4
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
NORMALIZED_CONFIG_CLASS = NormalizedConfig


class PerceiverDummyInputGenerator(DummyVisionInputGenerator):
Expand Down
76 changes: 62 additions & 14 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import transformers
import transformers.cache_utils
from packaging import version
from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet
from transformers.utils import is_torch_available
Expand All @@ -31,6 +32,7 @@

from ...configuration_utils import _transformers_version
from ...utils import logging
from ._tensor_cache import Cache as PatchedCache


if _transformers_version > version.parse("4.34.99"):
Expand All @@ -49,6 +51,7 @@

from .base import OnnxConfig


logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -158,6 +161,7 @@ def onnx_compatible_unfold(input_tensor, dimension, size, step):


UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)]
CACHE_PATCHING_SPEC = [PatchingSpec(transformers.cache_utils, "Cache", PatchedCache, transformers.cache_utils.Cache)]


class ModelPatcher:
Expand All @@ -171,6 +175,7 @@ def __init__(

patching_specs = config.PATCHING_SPECS or []
patching_specs.extend(UNSUPPORTED_OPS_PATCHING_SPEC)
patching_specs.extend(CACHE_PATCHING_SPEC)

self._patching_specs = []
for spec in patching_specs:
Expand All @@ -197,8 +202,42 @@ def patched_forward(*args, **kwargs):
signature = inspect.signature(self.orig_forward)
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)

if _transformers_version >= version.parse("4.48"):
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

if isinstance(kwargs.get("past_key_values"), (list, tuple)) and isinstance(
kwargs["past_key_values"][0], (list, tuple)
):
if len(kwargs["past_key_values"][0]) == 2:
kwargs["past_key_values"] = DynamicCache.from_legacy_cache(kwargs["past_key_values"])
elif len(kwargs["past_key_values"][0]) == 4:
kwargs["past_key_values"] = EncoderDecoderCache.from_legacy_cache(kwargs["past_key_values"])
else:
raise ValueError(
f"past_key_values should have either 2 or 4 elements, but it has {len(kwargs['past_key_values'][0])} elements"
)

elif any(isinstance(arg, (list, tuple)) for arg in args):
for i in range(len(args)):
if isinstance(args[i], (list, tuple)) and isinstance(args[i][0], (list, tuple)):
if len(args[i][0]) == 2:
args[i] = DynamicCache.from_legacy_cache(args[i])
elif len(args[i][0]) == 4:
args[i] = EncoderDecoderCache.from_legacy_cache(args[i])
else:
raise ValueError(
f"past_key_values should have either 2 or 4 elements, but it has {len(args[i])} elements"
)
break

outputs = self.orig_forward(*args, **kwargs)

if _transformers_version >= version.parse("4.48"):
if "past_key_values" in outputs and isinstance(
outputs["past_key_values"], (DynamicCache, EncoderDecoderCache)
):
outputs["past_key_values"] = outputs["past_key_values"].to_legacy_cache()

# This code block handles different cases of the filterd_outputs input to align it with the expected
# format of outputs. It is common for the output type of a model to vary, such as tensor, list,
# tuple, etc. For Transformers models, the output is encapsulated in a ModelOutput object that
Expand Down Expand Up @@ -230,6 +269,7 @@ def patched_forward(*args, **kwargs):
filterd_outputs[name] = outputs
name = list(config.outputs.keys())[0]
filterd_outputs[name] = outputs

return filterd_outputs

self.patched_forward = patched_forward
Expand Down Expand Up @@ -833,14 +873,22 @@ def patched_forward(
class SentenceTransformersTransformerPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()
if _transformers_version >= version.parse("4.42") and self.real_config._config.model_type == "mistral":
if (
_transformers_version >= version.parse("4.42")
and _transformers_version < version.parse("4.48")
and self.real_config._config.model_type == "mistral"
):
self._model[0].auto_model._update_causal_mask = types.MethodType(
_update_causal_mask_patched, self._model[0].auto_model
)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if _transformers_version >= version.parse("4.42") and self.real_config._config.model_type == "mistral":
if (
_transformers_version >= version.parse("4.42")
and _transformers_version < version.parse("4.48")
and self.real_config._config.model_type == "mistral"
):
self._model[0].auto_model._update_causal_mask = types.MethodType(
self._update_causal_mask_original, self._model[0].auto_model
)
Expand Down Expand Up @@ -1132,16 +1180,16 @@ def _update_causal_mask_patched(
padding_mask, min_dtype
)

# if (
# self.config._attn_implementation == "sdpa"
# and attention_mask is not None
# and attention_mask.device.type == "cuda"
# and not output_attentions
# ):
# # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# # Details: https://github.com/pytorch/pytorch/issues/110213
# causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

return causal_mask

Expand All @@ -1161,7 +1209,7 @@ def __enter__(self):
"_prepare_4d_causal_attention_mask_for_sdpa", _prepare_4d_causal_attention_mask_for_sdpa_patched
)

if _transformers_version >= version.parse("4.42"):
if _transformers_version >= version.parse("4.42") and _transformers_version < version.parse("4.48"):
if hasattr(self._model, "model"):
self._model.model._update_causal_mask = types.MethodType(
_update_causal_mask_patched, self._model.model
Expand All @@ -1183,7 +1231,7 @@ def __exit__(self, exc_type, exc_value, traceback):
"_prepare_4d_causal_attention_mask_for_sdpa", self.original_prepare_4d_causal_attention_mask_for_sdpa
)

if _transformers_version >= version.parse("4.42"):
if _transformers_version >= version.parse("4.42") and _transformers_version < version.parse("4.48"):
if hasattr(self._model, "model"):
self._model.model._update_causal_mask = types.MethodType(
self._update_causal_mask_original, self._model.model
Expand Down
4 changes: 2 additions & 2 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,9 +897,9 @@ class TasksManager:
"feature-extraction",
"fill-mask",
"text-classification",
"multiple-choice",
# "multiple-choice",
"token-classification",
"question-answering",
# "question-answering",
onnx="ModernBertOnnxConfig",
),
"mpnet": supported_tasks_mapping(
Expand Down
12 changes: 8 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@
"datasets>=1.2.1",
"evaluate",
"protobuf>=3.20.1",
"transformers>=4.36,<4.48.0",
# "transformers>=4.36,<4.49.0",
"transformers@git+https://github.com/huggingface/transformers.git@fix-gpt2-is-causal",
],
"onnxruntime-gpu": [
"onnx",
Expand All @@ -59,19 +60,22 @@
"evaluate",
"protobuf>=3.20.1",
"accelerate", # ORTTrainer requires it.
"transformers>=4.36,<4.48.0",
# "transformers>=4.36,<4.49.0",
"transformers@git+https://github.com/huggingface/transformers.git@fix-gpt2-is-causal",
],
"exporters": [
"onnx",
"onnxruntime",
"timm",
"transformers>=4.36,<4.48.0",
# "transformers>=4.36,<4.49.0",
"transformers@git+https://github.com/huggingface/transformers.git@fix-gpt2-is-causal",
],
"exporters-gpu": [
"onnx",
"onnxruntime-gpu",
"timm",
"transformers>=4.36,<4.48.0",
# "transformers>=4.36,<4.49.0",
"transformers@git+https://github.com/huggingface/transformers.git@fix-gpt2-is-causal",
],
"exporters-tf": [
"tensorflow>=2.4,<=2.12.1",
Expand Down
11 changes: 9 additions & 2 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from optimum.utils import is_transformers_version


VALIDATE_EXPORT_ON_SHAPES_SLOW = {
"batch_size": [1, 3, 5],
"sequence_length": [8, 33, 96, 154],
Expand Down Expand Up @@ -125,7 +128,11 @@
"mobilenet-v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
"mobilenet-v1": "google/mobilenet_v1_0.75_192",
"mobilevit": "hf-internal-testing/tiny-random-mobilevit",
"modernbert": "hf-internal-testing/tiny-random-ModernBertForMaskedLM",
**(
{"modernbert": "hf-internal-testing/tiny-random-ModernBertForMaskedLM"}
if is_transformers_version(">=", "4.48")
else {}
),
"mpnet": "hf-internal-testing/tiny-random-MPNetModel",
"mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
"mt5": "lewtun/tiny-random-mt5",
Expand Down Expand Up @@ -269,7 +276,7 @@
# "mobilenet_v1": "google/mobilenet_v1_0.75_192",
# "mobilenet_v2": "google/mobilenet_v2_0.35_96",
"mobilevit": "apple/mobilevit-small",
"modernbert": "answerdotai/ModernBERT-base",
**({"modernbert": "answerdotai/ModernBERT-base"} if is_transformers_version(">=", "4.48") else {}),
"mpt": "mosaicml/mpt-7b",
"mt5": "lewtun/tiny-random-mt5", # Not using google/mt5-small because it takes too much time for testing.
"musicgen": "facebook/musicgen-small",
Expand Down
4 changes: 3 additions & 1 deletion tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4612,7 +4612,9 @@ def test_compare_with_and_without_past_key_values(self, model_arch: str):

self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))

if model_arch == "whisper" and is_transformers_version(">=", "4.43"):
if model_arch == "whisper" and is_transformers_version(">=", "4.48"):
gen_length = self.GENERATION_LENGTH
elif model_arch == "whisper" and is_transformers_version(">=", "4.43"):
gen_length = self.GENERATION_LENGTH + 2
else:
gen_length = self.GENERATION_LENGTH + 1
Expand Down
Loading