diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index e2bcd7fe20..26202e889b 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -279,7 +279,7 @@ class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
 
 
 class Qwen2OnnxConfig(LlamaOnnxConfig):
-    pass
+    MIN_TRANSFORMERS_VERSION = version.parse("4.37.0")
 
 
 class GemmaOnnxConfig(LlamaOnnxConfig):
@@ -291,6 +291,7 @@ class GemmaOnnxConfig(LlamaOnnxConfig):
 class PhiOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     DEFAULT_ONNX_OPSET = 14  # Phi now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+    MIN_TRANSFORMERS_VERSION = version.parse("4.36.0")
 
 
 class Phi3OnnxConfig(PhiOnnxConfig):
@@ -299,6 +300,7 @@ class Phi3OnnxConfig(PhiOnnxConfig):
     ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfigWithGQA
+    MIN_TRANSFORMERS_VERSION = version.parse("4.41.0")
 
 
 class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
@@ -1173,7 +1175,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
 
 
 class OwlV2OnnxConfig(OwlViTOnnxConfig):
-    pass
+    MIN_TRANSFORMERS_VERSION = version.parse("4.35.0")
 
 
 class LayoutLMOnnxConfig(TextAndVisionOnnxConfig):
diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py
index 1f873e4e71..4c1f845893 100644
--- a/optimum/exporters/onnx/model_patcher.py
+++ b/optimum/exporters/onnx/model_patcher.py
@@ -276,13 +276,16 @@ def __init__(
             model.decoder.model.decoder.config.use_cache = True
 
 
-def _unmask_unattended_patched(
-    expanded_mask: torch.Tensor,
-    min_dtype: float,
+def _unmask_unattended_patched_legacy(
+    expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
 ):
     return expanded_mask
 
 
+def _unmask_unattended_patched(expanded_mask: torch.Tensor, min_dtype: float):
+    return expanded_mask
+
+
 def _make_causal_mask_patched(
     input_ids_shape: torch.Size,
     dtype: torch.dtype,
@@ -316,7 +319,11 @@ def _make_causal_mask_patched(
 
 
 _make_causal_mask_patched_staticmethod = staticmethod(_make_causal_mask_patched)
-_unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched)
+
+if _transformers_version >= version.parse("4.39.0"):
+    _unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched)
+else:
+    _unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched_legacy)
 
 
 # Adapted from _prepare_4d_causal_attention_mask
diff --git a/setup.py b/setup.py
index 41598aeba5..ce7e537330 100644
--- a/setup.py
+++ b/setup.py
@@ -15,7 +15,7 @@
 REQUIRED_PKGS = [
     "coloredlogs",
     "sympy",
-    "transformers[sentencepiece]>=4.26.0,<4.43.0",
+    "transformers[sentencepiece]>=4.29.0,<4.43.0",
     "torch>=1.11",
     "packaging",
     "numpy<2.0",  # transformers requires numpy<2.0 https://github.com/huggingface/transformers/pull/31569
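The sketch below illustrates the version-gating pattern the `model_patcher.py` hunks rely on: two no-op replacements whose signatures mirror `AttentionMaskConverter._unmask_unattended` before and after the transformers 4.39 signature change, with the active one selected at import time. It is a hedged, self-contained illustration, not a copy of optimum's patcher: applying and restoring the patch inline here stands in for what optimum's `ModelPatcher` context manager does around the export, and it assumes a transformers version recent enough (roughly >= 4.36) to expose `AttentionMaskConverter._unmask_unattended`.

```python
# Hedged sketch of the version-gated patching pattern in the diff above.
from typing import Union

import torch
from packaging import version
from transformers import __version__ as _transformers_version_str
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

_transformers_version = version.parse(_transformers_version_str)


def _unmask_unattended_patched_legacy(
    expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
):
    # Signature used by transformers < 4.39: return the expanded mask untouched.
    return expanded_mask


def _unmask_unattended_patched(expanded_mask: torch.Tensor, min_dtype: float):
    # Signature used by transformers >= 4.39 (min_dtype replaced attention_mask/unmasked_value).
    return expanded_mask


# Select the replacement whose signature matches the installed transformers API.
if _transformers_version >= version.parse("4.39.0"):
    _unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched)
else:
    _unmask_unattended_patched_staticmethod = staticmethod(_unmask_unattended_patched_legacy)

# Illustrative use only: a patcher would install the no-op for the duration of the
# export and restore the original afterwards, so the data-dependent mask edits of
# the real method are not traced into the ONNX graph.
_original_unmask_unattended = AttentionMaskConverter._unmask_unattended
AttentionMaskConverter._unmask_unattended = _unmask_unattended_patched_staticmethod
try:
    pass  # the ONNX export (e.g. torch.onnx.export) would run here
finally:
    AttentionMaskConverter._unmask_unattended = _original_unmask_unattended
```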