diff --git a/docs/source/package_reference/supported_models.mdx b/docs/source/package_reference/supported_models.mdx
index 6af449394..98064fb5d 100644
--- a/docs/source/package_reference/supported_models.mdx
+++ b/docs/source/package_reference/supported_models.mdx
@@ -21,6 +21,7 @@ limitations under the License.
 | Architecture               | Task |
 |---------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------|
 | ALBERT                     | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
+| AST                        | feature-extraction, audio-classification |
 | BERT                       | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
 | BLOOM                      | text-generation |
 | Beit                       | feature-extraction, image-classification |
@@ -39,6 +40,7 @@ limitations under the License.
 | ESM                        | feature-extraction, fill-mask, text-classification, token-classification |
 | FlauBERT                   | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
 | GPT2                       | text-generation |
+| Hubert                     | feature-extraction, automatic-speech-recognition, audio-classification |
 | Levit                      | feature-extraction, image-classification |
 | Llama, Llama 2, Llama 3    | text-generation |
 | Mistral                    | text-generation |
@@ -53,9 +55,12 @@ limitations under the License.
 | RoFormer                   | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
 | Swin                       | feature-extraction, image-classification |
 | T5                         | text2text-generation |
+| UniSpeech                  | feature-extraction, automatic-speech-recognition, audio-classification |
+| UniSpeech-SAT              | feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector |
+| ViT                        | feature-extraction, image-classification |
 | Wav2Vec2                   | feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector |
+| WavLM                      | feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector |
 | XLM                        | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
-| ViT                        | feature-extraction, image-classification |
 | XLM-RoBERTa                | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
 | Yolos                      | feature-extraction, object-detection |

diff --git a/optimum/exporters/neuron/model_configs.py b/optimum/exporters/neuron/model_configs.py
index e1fed6e48..3ff82e886 100644
--- a/optimum/exporters/neuron/model_configs.py
+++ b/optimum/exporters/neuron/model_configs.py
@@ -19,7 +19,12 @@

 import torch

-from ...neuron.utils import DummyBeamValuesGenerator, DummyControNetInputGenerator, DummyMaskedPosGenerator
+from ...neuron.utils import (
+    ASTDummyAudioInputGenerator,
+    DummyBeamValuesGenerator,
+    DummyControNetInputGenerator,
+    DummyMaskedPosGenerator,
+)
 from ...utils import (
     DummyInputGenerator,
     DummySeq2SeqDecoderTextInputGenerator,
@@ -423,11 +428,144 @@ def inputs(self) -> List[str]:
     @property
     def outputs(self) -> List[str]:
         common_outputs = super().outputs
+        if self.task == "feature-extraction":
+            common_outputs = ["last_hidden_state", "extract_features"]
         if self.task == "audio-xvector":
             common_outputs.append("embeddings")
         return common_outputs


+@register_in_tasks_manager(
+    "audio-spectrogram-transformer",
+    *[
+        "feature-extraction",
+        "audio-classification",
"audio-classification", + ], +) +class ASTNeuronConfig(AudioNeuronConfig): + NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( + num_mel_bins="num_mel_bins", max_length="max_length", allow_new=True + ) + DUMMY_INPUT_GENERATOR_CLASSES = (ASTDummyAudioInputGenerator,) + + @property + def inputs(self) -> List[str]: + return ["input_values"] + + +@register_in_tasks_manager( + "hubert", + *[ + "feature-extraction", + "automatic-speech-recognition", + "audio-classification", + ], +) +class HubertNeuronConfig(Wav2Vec2NeuronConfig): + @property + def outputs(self) -> List[str]: + common_outputs = super().outputs + if self.task == "feature-extraction": + common_outputs = ["last_hidden_state"] + return common_outputs + + +# TODO: compilation failed due to a bug in xla: https://github.com/pytorch/xla/issues/6398. +# @register_in_tasks_manager( +# "sew", +# *[ +# "feature-extraction", +# "automatic-speech-recognition", +# "audio-classification", +# ], +# ) +# class SEWNeuronConfig(Wav2Vec2NeuronConfig): +# pass + + +# TODO: compilation failed due to a bug in xla: https://github.com/pytorch/xla/issues/6398. +# @register_in_tasks_manager( +# "sew-d", +# *[ +# "feature-extraction", +# "automatic-speech-recognition", +# "audio-classification", +# ], +# ) +# class SEWDNeuronConfig(Wav2Vec2NeuronConfig): +# pass + + +@register_in_tasks_manager( + "unispeech", + *[ + "feature-extraction", + "automatic-speech-recognition", + "audio-classification", + ], +) +class UniSpeechNeuronConfig(Wav2Vec2NeuronConfig): + pass + + +@register_in_tasks_manager( + "unispeech-sat", + *[ + "feature-extraction", + "automatic-speech-recognition", + "audio-classification", + "audio-frame-classification", + "audio-xvector", + ], +) +class UniSpeechSATNeuronConfig(Wav2Vec2NeuronConfig): + pass + + +# TODO: compilation failed due to a bug in xla: https://github.com/pytorch/xla/issues/6398. +# @register_in_tasks_manager( +# "wav2vec2-bert", +# *[ +# "feature-extraction", +# "automatic-speech-recognition", +# "audio-classification", +# "audio-frame-classification", +# "audio-xvector", +# ], +# ) +# class Wav2Vec2BertNeuronConfig(Wav2Vec2NeuronConfig): +# pass + + +# TODO: compilation failed due to a bug in xla: https://github.com/pytorch/xla/issues/6398. 
+# @register_in_tasks_manager(
+#     "wav2vec2-conformer",
+#     *[
+#         "feature-extraction",
+#         "automatic-speech-recognition",
+#         "audio-classification",
+#         "audio-frame-classification",
+#         "audio-xvector",
+#     ],
+# )
+# class Wav2Vec2ConformerNeuronConfig(Wav2Vec2NeuronConfig):
+#     pass
+
+
 @register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
 class UNetNeuronConfig(VisionNeuronConfig):
     ATOL_FOR_VALIDATION = 1e-3
diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py
index 2f305947f..ce8283639 100644
--- a/optimum/neuron/utils/__init__.py
+++ b/optimum/neuron/utils/__init__.py
@@ -41,7 +41,12 @@
         "is_torch_xla_available",
         "is_transformers_neuronx_available",
     ],
-    "input_generators": ["DummyBeamValuesGenerator", "DummyMaskedPosGenerator", "DummyControNetInputGenerator"],
+    "input_generators": [
+        "DummyBeamValuesGenerator",
+        "DummyMaskedPosGenerator",
+        "DummyControNetInputGenerator",
+        "ASTDummyAudioInputGenerator",
+    ],
     "misc": [
         "DiffusersPretrainedConfig",
         "check_if_weights_replacable",
@@ -93,7 +98,12 @@
         is_torch_xla_available,
         is_transformers_neuronx_available,
     )
-    from .input_generators import DummyBeamValuesGenerator, DummyControNetInputGenerator, DummyMaskedPosGenerator
+    from .input_generators import (
+        ASTDummyAudioInputGenerator,
+        DummyBeamValuesGenerator,
+        DummyControNetInputGenerator,
+        DummyMaskedPosGenerator,
+    )
     from .misc import (
         DiffusersPretrainedConfig,
         check_if_weights_replacable,
diff --git a/optimum/neuron/utils/input_generators.py b/optimum/neuron/utils/input_generators.py
index 7fbdd38d1..c98cb28eb 100644
--- a/optimum/neuron/utils/input_generators.py
+++ b/optimum/neuron/utils/input_generators.py
@@ -20,6 +20,7 @@

 from ...utils import (
     DTYPE_MAPPER,
+    DummyAudioInputGenerator,
     DummyInputGenerator,
     NormalizedTextConfig,
     NormalizedVisionConfig,
@@ -163,3 +164,12 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
                 self.width // 2**num_cross_attn_blocks,
             )
         return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
+
+
+# copied from https://github.com/huggingface/optimum/blob/171020c775cec6ff77826c3f5f5e5c1498b23f81/optimum/exporters/onnx/model_configs.py#L1363C1-L1368C111
+class ASTDummyAudioInputGenerator(DummyAudioInputGenerator):
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        shape = [self.batch_size, self.normalized_config.max_length, self.normalized_config.num_mel_bins]
+        if input_name == "input_values":
+            return self.random_float_tensor(shape, min_value=-1, max_value=1, framework=framework, dtype=float_dtype)
+        return super().generate(input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype)
diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py
index 7dd5303b5..9969a892e 100644
--- a/tests/exporters/exporters_utils.py
+++ b/tests/exporters/exporters_utils.py
@@ -16,6 +16,7 @@

 EXPORT_MODELS_TINY = {
     "albert": "hf-internal-testing/tiny-random-AlbertModel",
+    "audio-spectrogram-transformer": "Ericwang/tiny-random-ast",
     "beit": "hf-internal-testing/tiny-random-BeitForImageClassification",
     "bert": "hf-internal-testing/tiny-random-BertModel",
     "camembert": "hf-internal-testing/tiny-random-camembert",
"hf-internal-testing/tiny-random-camembert", @@ -32,6 +33,7 @@ "electra": "hf-internal-testing/tiny-random-ElectraModel", "esm": "hf-internal-testing/tiny-random-EsmModel", "flaubert": "flaubert/flaubert_small_cased", + "hubert": "hf-internal-testing/tiny-random-HubertModel", "levit": "hf-internal-testing/tiny-random-LevitModel", "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel", "mobilenet-v2": "hf-internal-testing/tiny-random-MobileNetV2Model", @@ -40,8 +42,15 @@ "phi": "bumblebee-testing/tiny-random-PhiModel", "roberta": "hf-internal-testing/tiny-random-RobertaModel", "roformer": "hf-internal-testing/tiny-random-RoFormerModel", + # "sew": "hf-internal-testing/tiny-random-SEWModel", # blocked + # "sew-d": "hf-internal-testing/tiny-random-SEWDModel", # blocked "swin": "hf-internal-testing/tiny-random-SwinModel", + "unispeech": "hf-internal-testing/tiny-random-unispeech", + "unispeech-sat": "hf-internal-testing/tiny-random-unispeech-sat", "vit": "hf-internal-testing/tiny-random-vit", + "wav2vec2": "hf-internal-testing/tiny-random-Wav2Vec2Model", + # "wav2vec2-conformer": "hf-internal-testing/tiny-random-wav2vec2-conformer", # blocked + "wavlm": "hf-internal-testing/tiny-random-wavlm", "xlm": "hf-internal-testing/tiny-random-XLMModel", "xlm-roberta": "hf-internal-testing/tiny-xlm-roberta", "yolos": "hf-internal-testing/tiny-random-YolosModel", diff --git a/tests/exporters/test_export.py b/tests/exporters/test_export.py index 50fcce2dd..316485e7b 100644 --- a/tests/exporters/test_export.py +++ b/tests/exporters/test_export.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import os import random import unittest @@ -129,15 +128,16 @@ def _neuronx_export( if library_name == "sentence_transformers": model_class = TasksManager.get_model_class_for_task(task, framework="pt", library=library_name) model = model_class(model_name) + reference_model = model_class(model_name) if "clip" in model[0].__class__.__name__.lower(): config = model[0].model.config else: config = model[0].auto_model.config else: - model_class = TasksManager.get_model_class_for_task(task, framework="pt") + model_class = TasksManager.get_model_class_for_task(task, model_type=model_type, framework="pt") config = AutoConfig.from_pretrained(model_name) model = model_class.from_config(config) - reference_model = copy.deepcopy(model) + reference_model = model_class.from_config(config) mandatory_shapes = { name: DEFAULT_DUMMY_SHAPES.get(name) or EXTREA_DEFAULT_DUMMY_SHAPES.get(name) diff --git a/tests/inference/inference_utils.py b/tests/inference/inference_utils.py index 3accdf96b..f215a3912 100644 --- a/tests/inference/inference_utils.py +++ b/tests/inference/inference_utils.py @@ -30,6 +30,7 @@ MODEL_NAMES = { "albert": "hf-internal-testing/tiny-random-AlbertModel", + "audio-spectrogram-transformer": "Ericwang/tiny-random-ast", "beit": "hf-internal-testing/tiny-random-BeitForImageClassification", "bert": "hf-internal-testing/tiny-random-BertModel", "camembert": "hf-internal-testing/tiny-random-camembert", @@ -45,6 +46,7 @@ "dpt": "hf-internal-testing/tiny-random-DPTModel", "electra": "hf-internal-testing/tiny-random-ElectraModel", "flaubert": "flaubert/flaubert_small_cased", + "hubert": "hf-internal-testing/tiny-random-HubertModel", "gpt2": "hf-internal-testing/tiny-random-gpt2", "latent-consistency": "echarlaix/tiny-random-latent-consistency", "levit": "hf-internal-testing/tiny-random-LevitModel", @@ -55,12 +57,18 @@ "phi": 
"bumblebee-testing/tiny-random-PhiModel", "roberta": "hf-internal-testing/tiny-random-RobertaModel", "roformer": "hf-internal-testing/tiny-random-RoFormerModel", + # "sew": "hf-internal-testing/tiny-random-SEWModel", # blocked + # "sew-d": "hf-internal-testing/tiny-random-SEWDModel", # blocked "swin": "hf-internal-testing/tiny-random-SwinModel", - "vit": "hf-internal-testing/tiny-random-vit", "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", "stable-diffusion-ip2p": "asntr/tiny-stable-diffusion-pix2pix-torch", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", + "unispeech": "hf-internal-testing/tiny-random-unispeech", + "unispeech-sat": "hf-internal-testing/tiny-random-unispeech-sat", + "vit": "hf-internal-testing/tiny-random-vit", "wav2vec2": "hf-internal-testing/tiny-random-Wav2Vec2Model", + # "wav2vec2-conformer": "hf-internal-testing/tiny-random-wav2vec2-conformer", # blocked + "wavlm": "hf-internal-testing/tiny-random-wavlm", "xlm": "hf-internal-testing/tiny-random-XLMModel", "xlm-roberta": "hf-internal-testing/tiny-xlm-roberta", "yolos": "hf-internal-testing/tiny-random-YolosModel",