From 967c6e2c1e5d0ea7b7b66f880b7ce8d033b91239 Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Mon, 20 Jan 2025 14:31:14 +0100
Subject: [PATCH] traceable cache

---
 optimum/exporters/onnx/_tensor_cache.py | 95 ------------------------
 optimum/exporters/onnx/model_patcher.py | 91 ++++++++++++++++++++++-
 2 files changed, 89 insertions(+), 97 deletions(-)
 delete mode 100644 optimum/exporters/onnx/_tensor_cache.py

diff --git a/optimum/exporters/onnx/_tensor_cache.py b/optimum/exporters/onnx/_tensor_cache.py
deleted file mode 100644
index 154b186ef4..0000000000
--- a/optimum/exporters/onnx/_tensor_cache.py
+++ /dev/null
@@ -1,95 +0,0 @@
-from typing import Any, Dict, Optional, Tuple
-
-import torch
-from transformers import logging
-
-
-logger = logging.get_logger(__name__)
-
-
-# The same as transformers.cache_utils.Cache but iherits from torch.Tensor instead of torch.nn.Module
-class Cache(torch.Tensor):
-    """
-    Base, abstract class for all caches. The actual data structure is specific to each subclass.
-    """
-
-    def __init__(self):
-        super().__init__()
-
-    def update(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
-                cache to be created.
-
-        Return:
-            A tuple containing the updated key and value states.
-        """
-        raise NotImplementedError("Make sure to implement `update` in a subclass.")
-
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        # TODO: deprecate this function in favor of `cache_position`
-        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
-
-    # Deprecate in favor of max-cache-shape because we want to be specifc by what we mean with "max_length"
-    # Prev some cache objects didn't have "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles
-    # infinite amount of tokens. In the codebase what we really need to check is the max capacity of certain cache instances, so
-    # we change naming to be more explicit
-    def get_max_length(self) -> Optional[int]:
-        logger.warning_once(
-            "`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. "
-            "Calling `get_max_cache()` will raise error from v4.48"
-        )
-        return self.get_max_cache_shape()
-
-    def get_max_cache_shape(self) -> Optional[int]:
-        """Returns the maximum sequence length (i.e. max capacity) of the cache object"""
-        raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.")
-
-    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
-        """Given the sequence length of the new inputs, returns the usable length of the cache."""
-        # Cache without size limit -> all cache is usable
-        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
-        #   length, we will need to evict part of the cache (and thus not all cache is usable)
-        max_length = self.get_max_cache_shape()
-        previous_seq_length = self.get_seq_length(layer_idx)
-        if max_length is not None and previous_seq_length + new_seq_length > max_length:
-            return max_length - new_seq_length
-        return previous_seq_length
-
-    def reorder_cache(self, beam_idx: torch.LongTensor):
-        """Reorders the cache for beam search, given the selected beam indices."""
-        for layer_idx in range(len(self.key_cache)):
-            if self.key_cache[layer_idx] != []:
-                device = self.key_cache[layer_idx].device
-                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            if self.value_cache[layer_idx] != []:
-                device = self.value_cache[layer_idx].device
-                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
-
-    @property
-    def seen_tokens(self):
-        logger.warning_once(
-            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
-            "model input instead."
-        )
-        if hasattr(self, "_seen_tokens"):
-            return self._seen_tokens
-        else:
-            return None
diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py
index a608e947d5..29d0561f99 100644
--- a/optimum/exporters/onnx/model_patcher.py
+++ b/optimum/exporters/onnx/model_patcher.py
@@ -27,7 +27,6 @@
 
 from ...configuration_utils import _transformers_version
 from ...utils import logging
-from ._tensor_cache import Cache as PatchedCache
 
 
 if _transformers_version > version.parse("4.34.99"):
@@ -155,8 +154,96 @@ def onnx_compatible_unfold(input_tensor, dimension, size, step):
     return result
 
 
+# The same as transformers.cache_utils.Cache but iherits from torch.Tensor instead of torch.nn.Module
+class TraceableCache(torch.Tensor):
+    """
+    Base, abstract class for all caches. The actual data structure is specific to each subclass.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
+                cache to be created.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        raise NotImplementedError("Make sure to implement `update` in a subclass.")
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # TODO: deprecate this function in favor of `cache_position`
+        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
+
+    # Deprecate in favor of max-cache-shape because we want to be specifc by what we mean with "max_length"
+    # Prev some cache objects didn't have "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles
+    # infinite amount of tokens. In the codebase what we really need to check is the max capacity of certain cache instances, so
+    # we change naming to be more explicit
+    def get_max_length(self) -> Optional[int]:
+        logger.warning_once(
+            "`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. "
+            "Calling `get_max_cache()` will raise error from v4.48"
+        )
+        return self.get_max_cache_shape()
+
+    def get_max_cache_shape(self) -> Optional[int]:
+        """Returns the maximum sequence length (i.e. max capacity) of the cache object"""
+        raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.")
+
+    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
+        """Given the sequence length of the new inputs, returns the usable length of the cache."""
+        # Cache without size limit -> all cache is usable
+        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
+        #   length, we will need to evict part of the cache (and thus not all cache is usable)
+        max_length = self.get_max_cache_shape()
+        previous_seq_length = self.get_seq_length(layer_idx)
+        if max_length is not None and previous_seq_length + new_seq_length > max_length:
+            return max_length - new_seq_length
+        return previous_seq_length
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        for layer_idx in range(len(self.key_cache)):
+            if self.key_cache[layer_idx] != []:
+                device = self.key_cache[layer_idx].device
+                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+            if self.value_cache[layer_idx] != []:
+                device = self.value_cache[layer_idx].device
+                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+    @property
+    def seen_tokens(self):
+        logger.warning_once(
+            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
+            "model input instead."
+        )
+        if hasattr(self, "_seen_tokens"):
+            return self._seen_tokens
+        else:
+            return None
+
+
 UNSUPPORTED_OPS_PATCHING_SPEC = [PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold)]
-CACHE_PATCHING_SPEC = [PatchingSpec(transformers.cache_utils, "Cache", PatchedCache, transformers.cache_utils.Cache)]
+CACHE_PATCHING_SPEC = [PatchingSpec(transformers.cache_utils, "Cache", TraceableCache, transformers.cache_utils.Cache)]
 
 
 class ModelPatcher:
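The `CACHE_PATCHING_SPEC` entry above records which module attribute to patch (`transformers.cache_utils.Cache`), what to patch it with (`TraceableCache`), and the original object to restore afterwards. Below is a minimal sketch of that swap-and-restore pattern, assuming only `torch` and `transformers` are installed; the `patched_cache_base` context manager is a hypothetical illustration, not optimum's actual `PatchingSpec`/`ModelPatcher` machinery.

    # A simplified sketch of the swap-and-restore idea encoded by CACHE_PATCHING_SPEC.
    # `patched_cache_base` is a hypothetical helper for illustration; optimum applies
    # and reverts such swaps through its PatchingSpec / ModelPatcher machinery instead.
    import contextlib

    import torch
    import transformers.cache_utils


    class TraceableCache(torch.Tensor):
        # Stand-in for the class added in this patch: same role as
        # transformers.cache_utils.Cache, but subclassing torch.Tensor
        # instead of torch.nn.Module.
        pass


    @contextlib.contextmanager
    def patched_cache_base():
        original = transformers.cache_utils.Cache        # keep the original to restore it later
        transformers.cache_utils.Cache = TraceableCache  # swap in the tensor-based class
        try:
            yield
        finally:
            transformers.cache_utils.Cache = original    # always restore the real Cache


    # Usage: run the tracing/export step inside the patched context, for example:
    # with patched_cache_base():
    #     torch.onnx.export(model, model_args, "model.onnx")

The `finally` block ensures the original `Cache` class is restored even if tracing raises, so the patched state never leaks outside the export step.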