diff --git a/.github/workflows/test_inf1_export.yml b/.github/workflows/test_inf1_export.yml index 8db5ea7ef..3512a5b1e 100644 --- a/.github/workflows/test_inf1_export.yml +++ b/.github/workflows/test_inf1_export.yml @@ -1,16 +1,7 @@ name: Optimum neuron / Test INF1 partial export on: - push: - branches: [ main ] - paths: - - "setup.py" - - "optimum/**.py" - pull_request: - branches: [ main ] - paths: - - "setup.py" - - "optimum/**.py" + workflow_dispatch concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} diff --git a/.github/workflows/test_inf1_full_export.yml b/.github/workflows/test_inf1_full_export.yml index a7f6deb84..fbaf70727 100644 --- a/.github/workflows/test_inf1_full_export.yml +++ b/.github/workflows/test_inf1_full_export.yml @@ -1,14 +1,7 @@ name: Optimum neuron / Test INF1 full export on: - push: - branches: [ main ] - paths: - - "optimum/exporters/neuron/*.py" - pull_request: - branches: [ main ] - paths: - - "optimum/exporters/neuron/*.py" + workflow_dispatch concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} diff --git a/.github/workflows/test_inf1_inference.yml b/.github/workflows/test_inf1_inference.yml index ccb71a8f7..b587f8cd3 100644 --- a/.github/workflows/test_inf1_inference.yml +++ b/.github/workflows/test_inf1_inference.yml @@ -1,16 +1,7 @@ name: Optimum neuron / Test INF1 inference on: - push: - branches: [ main ] - paths: - - "setup.py" - - "optimum/**.py" - pull_request: - branches: [ main ] - paths: - - "setup.py" - - "optimum/**.py" + workflow_dispatch concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} diff --git a/.github/workflows/test_inf1_pipelines.yml b/.github/workflows/test_inf1_pipelines.yml index d953c2179..7351f8dcf 100644 --- a/.github/workflows/test_inf1_pipelines.yml +++ b/.github/workflows/test_inf1_pipelines.yml @@ -1,14 +1,7 @@ name: Optimum neuron / Test INF1 pipelines on: - push: - branches: [ main ] - paths: - - "optimum/neuron/pipelines/**.py" - pull_request: - branches: [ main ] - paths: - - "optimum/neuron/pipelines/**.py" + workflow_dispatch concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} diff --git a/docs/source/package_reference/modeling.mdx b/docs/source/package_reference/modeling.mdx index 819c68082..f0258b772 100644 --- a/docs/source/package_reference/modeling.mdx +++ b/docs/source/package_reference/modeling.mdx @@ -82,6 +82,22 @@ The following Neuron model classes are available for computer vision tasks. ### NeuronModelForObjectDetection [[autodoc]] modeling.NeuronModelForObjectDetection +## Audio + +The following auto classes are available for the following audio tasks. + +### NeuronModelForAudioClassification +[[autodoc]] modeling.NeuronModelForAudioClassification + +### NeuronModelForAudioFrameClassification +[[autodoc]] modeling.NeuronModelForAudioFrameClassification + +### NeuronModelForCTC +[[autodoc]] modeling.NeuronModelForCTC + +### NeuronModelForXVector +[[autodoc]] modeling.NeuronModelForXVector + ## Stable Diffusion The following Neuron model classes are available for stable diffusion tasks. diff --git a/docs/source/package_reference/supported_models.mdx b/docs/source/package_reference/supported_models.mdx index a1b6d0243..6af449394 100644 --- a/docs/source/package_reference/supported_models.mdx +++ b/docs/source/package_reference/supported_models.mdx @@ -53,6 +53,7 @@ limitations under the License. 
| RoFormer | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification | | Swin | feature-extraction, image-classification | | T5 | text2text-generation | +| Wav2Vec2 | feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector | | XLM | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification | | ViT | feature-extraction, image-classification | | XLM-RoBERTa | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification | diff --git a/optimum/commands/export/neuronx.py b/optimum/commands/export/neuronx.py index fd717ce1f..2d42565ee 100644 --- a/optimum/commands/export/neuronx.py +++ b/optimum/commands/export/neuronx.py @@ -249,6 +249,11 @@ def parse_args_neuronx(parser: "ArgumentParser"): default=1, help=f"Stable diffusion only. Number of images per prompt {doc_input}", ) + input_group.add_argument( + "--audio_sequence_length", + type=int, + help=f"Audio tasks only. Audio sequence length {doc_input}", + ) level_group = parser.add_mutually_exclusive_group() level_group.add_argument( diff --git a/optimum/exporters/neuron/config.py b/optimum/exporters/neuron/config.py index 33e680ef3..a5dd4202b 100644 --- a/optimum/exporters/neuron/config.py +++ b/optimum/exporters/neuron/config.py @@ -19,6 +19,7 @@ from typing import List from ...utils import ( + DummyAudioInputGenerator, DummyBboxInputGenerator, DummyInputGenerator, DummySeq2SeqDecoderTextInputGenerator, @@ -59,6 +60,15 @@ class TextAndVisionNeuronConfig(NeuronDefaultConfig): DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyVisionInputGenerator, DummyBboxInputGenerator) +class AudioNeuronConfig(NeuronDefaultConfig): + """ + Handles audio architectures. + """ + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyAudioInputGenerator, DummyTextInputGenerator) + INPUT_ARGS = ("batch_size", "audio_sequence_length") + + class TextNeuronDecoderConfig(NeuronDecoderConfig): """ Handles text decoder architectures. 
diff --git a/optimum/exporters/neuron/model_configs.py b/optimum/exporters/neuron/model_configs.py index f65792853..e1fed6e48 100644 --- a/optimum/exporters/neuron/model_configs.py +++ b/optimum/exporters/neuron/model_configs.py @@ -36,6 +36,7 @@ ) from ..tasks import TasksManager from .config import ( + AudioNeuronConfig, TextAndVisionNeuronConfig, TextEncoderNeuronConfig, TextNeuronDecoderConfig, @@ -402,6 +403,31 @@ def outputs(self) -> List[str]: return common_outputs +@register_in_tasks_manager( + "wav2vec2", + *[ + "feature-extraction", + "automatic-speech-recognition", + "audio-classification", + "audio-frame-classification", + "audio-xvector", + ], +) +class Wav2Vec2NeuronConfig(AudioNeuronConfig): + NORMALIZED_CONFIG_CLASS = NormalizedConfig + + @property + def inputs(self) -> List[str]: + return ["input_values"] + + @property + def outputs(self) -> List[str]: + common_outputs = super().outputs + if self.task == "audio-xvector": + common_outputs.append("embeddings") + return common_outputs + + @register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers") class UNetNeuronConfig(VisionNeuronConfig): ATOL_FOR_VALIDATION = 1e-3 diff --git a/optimum/neuron/__init__.py b/optimum/neuron/__init__.py index 8cd1328b1..9f989a6c2 100644 --- a/optimum/neuron/__init__.py +++ b/optimum/neuron/__init__.py @@ -42,6 +42,10 @@ "NeuronModelForImageClassification", "NeuronModelForSemanticSegmentation", "NeuronModelForObjectDetection", + "NeuronModelForCTC", + "NeuronModelForAudioClassification", + "NeuronModelForAudioFrameClassification", + "NeuronModelForXVector", ], "modeling_diffusion": [ "NeuronStableDiffusionPipelineBase", @@ -71,7 +75,10 @@ from .accelerate import ModelParallelismPlugin, NeuronAccelerator, NeuronAcceleratorState, NeuronPartialState from .hf_argparser import NeuronHfArgumentParser from .modeling import ( + NeuronModelForAudioClassification, + NeuronModelForAudioFrameClassification, NeuronModelForCausalLM, + NeuronModelForCTC, NeuronModelForFeatureExtraction, NeuronModelForImageClassification, NeuronModelForMaskedLM, @@ -82,6 +89,7 @@ NeuronModelForSentenceTransformers, NeuronModelForSequenceClassification, NeuronModelForTokenClassification, + NeuronModelForXVector, ) from .modeling_decoder import NeuronDecoderModel from .modeling_diffusion import ( diff --git a/optimum/neuron/modeling.py b/optimum/neuron/modeling.py index 35b26a990..864bfdb5a 100644 --- a/optimum/neuron/modeling.py +++ b/optimum/neuron/modeling.py @@ -21,7 +21,11 @@ import torch from transformers import ( AutoModel, + AutoModelForAudioClassification, + AutoModelForAudioFrameClassification, + AutoModelForAudioXVector, AutoModelForCausalLM, + AutoModelForCTC, AutoModelForImageClassification, AutoModelForMaskedLM, AutoModelForMultipleChoice, @@ -37,6 +41,7 @@ ) from transformers.modeling_outputs import ( BaseModelOutputWithPooling, + CausalLMOutput, ImageClassifierOutput, MaskedLMOutput, ModelOutput, @@ -45,6 +50,7 @@ SemanticSegmenterOutput, SequenceClassifierOutput, TokenClassifierOutput, + XVectorOutput, ) from .generation import TokenSelector @@ -65,6 +71,7 @@ _TOKENIZER_FOR_DOC = "AutoTokenizer" _PROCESSOR_FOR_IMAGE = "AutoImageProcessor" +_GENERIC_PROCESSOR = "AutoProcessor" NEURON_MODEL_START_DOCSTRING = r""" This model inherits from [`~neuron.modeling.NeuronTracedModel`]. 
Check the superclass documentation for the generic methods the @@ -101,7 +108,14 @@ Args: pixel_values (`Union[torch.Tensor, None]` of shape `({0})`, defaults to `None`): Pixel values corresponding to the images in the current batch. - Pixel values can be obtained from encoded images using [`AutoFeatureExtractor`](https://huggingface.co/docs/transformers/autoclass_tutorial#autofeatureextractor). + Pixel values can be obtained from encoded images using [`AutoImageProcessor`](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoImageProcessor). +""" + +NEURON_AUDIO_INPUTS_DOCSTRING = r""" + Args: + input_values (`torch.Tensor` of shape `({0})`): + Float values of input raw speech waveform.. + Input values can be obtained from audio file loaded into an array using [`AutoProcessor`](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoProcessor). """ FEATURE_EXTRACTION_EXAMPLE = r""" @@ -546,7 +560,7 @@ def forward( >>> from optimum.neuron import {model_class} >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}") - >>> model = {model_class}.from_pretrained("{checkpoint}", export=True) + >>> model = {model_class}.from_pretrained("{checkpoint}") >>> num_choices = 4 >>> first_sentence = ["Members of the procession walk down the street holding small horn brass instruments."] * num_choices @@ -858,6 +872,308 @@ def forward( return ModelOutput(logits=logits, pred_boxes=pred_boxes, last_hidden_state=last_hidden_state) +AUDIO_CLASSIFICATION_EXAMPLE = r""" + Example of audio classification: + + ```python + >>> from transformers import {processor_class} + >>> from optimum.neuron import {model_class} + >>> from datasets import load_dataset + >>> import torch + + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> # audio file is decoded on the fly + >>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") + + >>> logits = model(**inputs).logits + >>> predicted_class_ids = torch.argmax(logits, dim=-1).item() + >>> predicted_label = model.config.id2label[predicted_class_ids] + ``` + Example using `transformers.pipeline`: + + ```python + >>> from transformers import {processor_class}, pipeline + >>> from optimum.neuron import {model_class} + + >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}") + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> dataset = dataset.sort("id") + + >>> model = {model_class}.from_pretrained("{checkpoint}") + >>> ac = pipeline("audio-classification", model=model, feature_extractor=feature_extractor) + + >>> pred = ac(dataset[0]["audio"]["array"]) + ``` +""" + + +@add_start_docstrings( + """ + Neuron Model with an audio classification head. + """, + NEURON_MODEL_START_DOCSTRING, +) +class NeuronModelForAudioClassification(NeuronTracedModel): + """ + Neuron Model for audio-classification, with a sequence classification head on top (a linear layer over the pooled output) for tasks like + SUPERB Keyword Spotting. 
+ """ + + auto_model_class = AutoModelForAudioClassification + + @add_start_docstrings_to_model_forward( + NEURON_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length") + + AUDIO_CLASSIFICATION_EXAMPLE.format( + processor_class=_GENERIC_PROCESSOR, + model_class="NeuronModelForAudioClassification", + checkpoint="Jingya/wav2vec2-large-960h-lv60-self-neuronx-audio-classification", + ) + ) + def forward( + self, + input_values: torch.Tensor, + **kwargs, + ): + neuron_inputs = {"input_values": input_values} + + # run inference + with self.neuron_padding_manager(neuron_inputs) as inputs: + outputs = self.model(*inputs) # shape: [batch_size, num_labels] + outputs = self.remove_padding( + outputs, dims=[0], indices=[input_values.shape[0]] + ) # Remove padding on batch_size(0) + + logits = outputs[0] + + return SequenceClassifierOutput(logits=logits) + + +AUDIO_FRAME_CLASSIFICATION_EXAMPLE = r""" + Example of audio frame classification: + + ```python + >>> from transformers import {processor_class} + >>> from optimum.neuron import {model_class} + >>> from datasets import load_dataset + >>> import torch + + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate) + >>> logits = model(**inputs).logits + + >>> probabilities = torch.sigmoid(logits[0]) + >>> labels = (probabilities > 0.5).long() + >>> labels[0].tolist() + ``` +""" + + +@add_start_docstrings( + """ + Neuron Model with an audio frame classification head. + """, + NEURON_MODEL_START_DOCSTRING, +) +class NeuronModelForAudioFrameClassification(NeuronTracedModel): + """ + Neuron Model with a frame classification head on top for tasks like Speaker Diarization. 
+ """ + + auto_model_class = AutoModelForAudioFrameClassification + + @add_start_docstrings_to_model_forward( + NEURON_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length") + + AUDIO_FRAME_CLASSIFICATION_EXAMPLE.format( + processor_class=_GENERIC_PROCESSOR, + model_class="NeuronModelForAudioFrameClassification", + checkpoint="Jingya/wav2vec2-base-superb-sd-neuronx", + ) + ) + def forward( + self, + input_values: torch.Tensor, + **kwargs, + ): + neuron_inputs = {"input_values": input_values} + + # run inference + with self.neuron_padding_manager(neuron_inputs) as inputs: + outputs = self.model(*inputs) # shape: [batch_size, num_labels] + outputs = self.remove_padding( + outputs, dims=[0], indices=[input_values.shape[0]] + ) # Remove padding on batch_size(0) + + logits = outputs[0] + + return TokenClassifierOutput(logits=logits) + + +CTC_EXAMPLE = r""" + Example of CTC: + + ```python + >>> from transformers import {processor_class}, Wav2Vec2ForCTC + >>> from optimum.neuron import {model_class} + >>> from datasets import load_dataset + >>> import torch + + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> processor = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> # audio file is decoded on the fly + >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") + >>> logits = model(**inputs).logits + >>> predicted_ids = torch.argmax(logits, dim=-1) + + >>> transcription = processor.batch_decode(predicted_ids) + ``` + Example using `transformers.pipeline`: + + ```python + >>> from transformers import {processor_class}, pipeline + >>> from optimum.neuron import {model_class} + + >>> processor = {processor_class}.from_pretrained("{checkpoint}") + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> dataset = dataset.sort("id") + + >>> model = {model_class}.from_pretrained("{checkpoint}") + >>> asr = pipeline("automatic-speech-recognition", model=model, feature_extractor=processor.feature_extractor, tokenizer=processor.tokenizer) + ``` +""" + + +@add_start_docstrings( + """ + Neuron Model with a connectionist temporal classification head. + """, + NEURON_MODEL_START_DOCSTRING, +) +class NeuronModelForCTC(NeuronTracedModel): + """ + Neuron Model with a language modeling head on top for Connectionist Temporal Classification (CTC). 
+ """ + + auto_model_class = AutoModelForCTC + main_input_name = "input_values" + + @add_start_docstrings_to_model_forward( + NEURON_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length") + + CTC_EXAMPLE.format( + processor_class=_GENERIC_PROCESSOR, + model_class="NeuronModelForCTC", + checkpoint="Jingya/wav2vec2-large-960h-lv60-self-neuronx-ctc", + ) + ) + def forward( + self, + input_values: torch.Tensor, + **kwargs, + ): + neuron_inputs = {"input_values": input_values} + + # run inference + with self.neuron_padding_manager(neuron_inputs) as inputs: + outputs = self.model(*inputs) # shape: [batch_size, sequence_length] + outputs = self.remove_padding( + outputs, dims=[0], indices=[input_values.shape[0]] + ) # Remove padding on batch_size(0) + + logits = outputs[0] + + return CausalLMOutput(logits=logits) + + +AUDIO_XVECTOR_EXAMPLE = r""" + Example of Audio XVector: + + ```python + >>> from transformers import {processor_class} + >>> from optimum.neuron import {model_class} + >>> from datasets import load_dataset + >>> import torch + + >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation") + >>> dataset = dataset.sort("id") + >>> sampling_rate = dataset.features["audio"].sampling_rate + + >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}") + >>> model = {model_class}.from_pretrained("{checkpoint}") + + >>> inputs = feature_extractor( + ... [d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True + ... ) + >>> embeddings = model(**inputs).embeddings + + >>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1) + + >>> cosine_sim = torch.nn.CosineSimilarity(dim=-1) + >>> similarity = cosine_sim(embeddings[0], embeddings[1]) + >>> threshold = 0.7 + >>> if similarity < threshold: + ... print("Speakers are not the same!") + >>> round(similarity.item(), 2) + ``` +""" + + +@add_start_docstrings( + """ + Neuron Model with an XVector feature extraction head on top for tasks like Speaker Verification. + """, + NEURON_MODEL_START_DOCSTRING, +) +class NeuronModelForXVector(NeuronTracedModel): + """ + Neuron Model with an XVector feature extraction head on top for tasks like Speaker Verification. + """ + + auto_model_class = AutoModelForAudioXVector + + @add_start_docstrings_to_model_forward( + NEURON_AUDIO_INPUTS_DOCSTRING.format("batch_size, sequence_length") + + AUDIO_XVECTOR_EXAMPLE.format( + processor_class=_GENERIC_PROCESSOR, + model_class="NeuronModelForXVector", + checkpoint="Jingya/wav2vec2-base-superb-sv-neuronx", + ) + ) + def forward( + self, + input_values: torch.Tensor, + **kwargs, + ): + neuron_inputs = {"input_values": input_values} + + # run inference + with self.neuron_padding_manager(neuron_inputs) as inputs: + outputs = self.model(*inputs) # shape: [batch_size, num_labels] + outputs = self.remove_padding( + outputs, dims=[0], indices=[input_values.shape[0]] + ) # Remove padding on batch_size(0) + + logits = outputs[0] + embeddings = outputs[1] + + return XVectorOutput(logits=logits, embeddings=embeddings) + + NEURON_CAUSALLM_MODEL_START_DOCSTRING = r""" This model inherits from [`~neuron.modeling.NeuronDecoderModel`]. 
Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving) diff --git a/optimum/neuron/pipelines/transformers/base.py b/optimum/neuron/pipelines/transformers/base.py index 865457f1b..8bf80c019 100644 --- a/optimum/neuron/pipelines/transformers/base.py +++ b/optimum/neuron/pipelines/transformers/base.py @@ -18,7 +18,9 @@ from typing import Any, Dict, Optional, Union from transformers import ( + AudioClassificationPipeline, AutoConfig, + AutomaticSpeechRecognitionPipeline, FillMaskPipeline, ImageClassificationPipeline, ImageSegmentationPipeline, @@ -44,7 +46,9 @@ ) from ...modeling import ( + NeuronModelForAudioClassification, NeuronModelForCausalLM, + NeuronModelForCTC, NeuronModelForFeatureExtraction, NeuronModelForImageClassification, NeuronModelForMaskedLM, @@ -114,6 +118,18 @@ "default": "apple/deeplabv3-mobilevit-small", "type": "image", }, + "automatic-speech-recognition": { + "impl": AutomaticSpeechRecognitionPipeline, + "class": (NeuronModelForCTC,), + "default": "facebook/wav2vec2-large-960h-lv60-self", + "type": "multimodal", + }, + "audio-classification": { + "impl": AudioClassificationPipeline, + "class": (NeuronModelForAudioClassification,), + "default": "facebook/wav2vec2-large-960h-lv60-self", + "type": "audio", + }, } diff --git a/setup.py b/setup.py index 0d4759b21..e29fdf22d 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,8 @@ "peft", "compel", "rjieba", + "soundfile", + "librosa", "opencv-python-headless", ] @@ -55,6 +57,7 @@ "neuron-cc[tensorflow]==1.22.0.0", "protobuf", "torchvision", + "numpy==1.22.3", ], "neuronx": [ "wheel", diff --git a/tests/inference/inference_utils.py b/tests/inference/inference_utils.py index b37e434b9..3accdf96b 100644 --- a/tests/inference/inference_utils.py +++ b/tests/inference/inference_utils.py @@ -60,6 +60,7 @@ "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", "stable-diffusion-ip2p": "asntr/tiny-stable-diffusion-pix2pix-torch", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", + "wav2vec2": "hf-internal-testing/tiny-random-Wav2Vec2Model", "xlm": "hf-internal-testing/tiny-random-XLMModel", "xlm-roberta": "hf-internal-testing/tiny-xlm-roberta", "yolos": "hf-internal-testing/tiny-random-YolosModel", diff --git a/tests/inference/test_modeling.py b/tests/inference/test_modeling.py index 987c36064..21a88ccab 100644 --- a/tests/inference/test_modeling.py +++ b/tests/inference/test_modeling.py @@ -20,13 +20,17 @@ import requests import torch +from datasets import load_dataset from huggingface_hub.constants import default_cache_path from parameterized import parameterized from PIL import Image from sentence_transformers import SentenceTransformer, util from transformers import ( + AutoFeatureExtractor, AutoImageProcessor, AutoModel, + AutoModelForAudioClassification, + AutoModelForCTC, AutoModelForImageClassification, AutoModelForMaskedLM, AutoModelForMultipleChoice, @@ -35,14 +39,17 @@ AutoModelForSemanticSegmentation, AutoModelForSequenceClassification, AutoModelForTokenClassification, + AutoProcessor, AutoTokenizer, - CLIPProcessor, PretrainedConfig, set_seed, ) from transformers.onnx.utils import get_preprocessor from optimum.neuron import ( + NeuronModelForAudioClassification, + NeuronModelForAudioFrameClassification, + NeuronModelForCTC, NeuronModelForFeatureExtraction, NeuronModelForImageClassification, NeuronModelForMaskedLM, @@ -53,6 +60,7 @@ NeuronModelForSentenceTransformers, NeuronModelForSequenceClassification, 
NeuronModelForTokenClassification, + NeuronModelForXVector, NeuronTracedModel, pipeline, ) @@ -432,7 +440,7 @@ def test_sentence_transformers_clip(self, model_arch): "two_dogs_in_snow.jpg", ) - processor = CLIPProcessor.from_pretrained(model_id, subfolder="0_CLIPModel") + processor = AutoProcessor.from_pretrained(model_id, subfolder="0_CLIPModel") inputs = processor(text=texts, images=Image.open("two_dogs_in_snow.jpg"), return_tensors="pt", padding=True) outputs = neuron_model(**inputs) self.assertIn("image_embeds", outputs) @@ -1722,3 +1730,411 @@ def test_pipeline_model(self): _ = pipe(url) gc.collect() + + +@is_inferentia_test +class NeuronModelForCTCIntegrationTest(NeuronModelTestMixin): + NEURON_MODEL_CLASS = NeuronModelForCTC + TASK = "automatic-speech-recognition" + STATIC_INPUTS_SHAPES = {"batch_size": 1, "audio_sequence_length": 100000} + if is_neuron_available(): + ATOL_FOR_VALIDATION = 1e-3 + SUPPORTED_ARCHITECTURES = [] + elif is_neuronx_available(): + ATOL_FOR_VALIDATION = 1e-3 + SUPPORTED_ARCHITECTURES = ["wav2vec2"] + else: + ATOL_FOR_VALIDATION = 1e-5 + SUPPORTED_ARCHITECTURES = [] + + def _load_neuron_model_and_processor(self, model_arch, suffix): + model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] + neuron_model = NeuronModelForCTC.from_pretrained(self.neuron_model_dirs[model_arch + suffix]) + preprocessor = AutoProcessor.from_pretrained(model_id) + self.assertIsInstance(neuron_model.model, torch.jit._script.ScriptModule) + self.assertIsInstance(neuron_model.config, PretrainedConfig) + return neuron_model, preprocessor + + def _load_transformers_model(self, model_arch): + model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] + set_seed(SEED) + transformers_model = AutoModelForCTC.from_pretrained(model_id) + return transformers_model + + def _prepare_inputs(self, processor, batch_size=1): + dataset = load_dataset( + "hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True + ) + dataset = dataset.sort("id") + sampling_rate = dataset.features["audio"].sampling_rate + inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") + if batch_size > 1: + for name, tensor in inputs.items(): + inputs[name] = torch.cat(batch_size * [tensor]) + return inputs + + def _validate_outputs(self, model_arch, suffix): + neuron_model, preprocessor = self._load_neuron_model_and_processor(model_arch, suffix) + inputs = self._prepare_inputs(preprocessor) + neuron_outputs = neuron_model(**inputs) + self.assertIn("logits", neuron_outputs) + self.assertIsInstance(neuron_outputs.logits, torch.Tensor) + + @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) + @requires_neuronx + def test_compare_to_transformers_dyn_bs(self, model_arch): + # Neuron model with dynamic batching + model_args = { + "test_name": model_arch + "_dyn_bs_true", + "model_arch": model_arch, + "dynamic_batch_size": True, + } + self._setup(model_args) + self._validate_outputs(model_arch, "_dyn_bs_true") + + gc.collect() + + @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) + def test_compare_to_transformers_non_dyn_bs(self, model_arch): + model_args = { + "test_name": model_arch + "_dyn_bs_false", + "model_arch": model_arch, + "dynamic_batch_size": False, + } + self._setup(model_args) + self._validate_outputs(model_arch, "_dyn_bs_false") + + gc.collect() + + def 
test_non_dyn_bs_neuron_model_on_false_batch_size(self): + model_arch = "wav2vec2" + model_args = { + "test_name": model_arch + "_dyn_bs_false", + "model_arch": model_arch, + "dynamic_batch_size": False, + } + self._setup(model_args) + neuron_model, preprocessor = self._load_neuron_model_and_processor(model_arch, "_dyn_bs_false") + inputs = self._prepare_inputs(preprocessor, batch_size=2) + + with self.assertRaises(Exception) as context: + _ = neuron_model(**inputs) + + self.assertIn("set `dynamic_batch_size=True` during the compilation", str(context.exception)) + + def test_pipeline_model(self): + model_arch = "wav2vec2" + model_args = {"test_name": model_arch + "_dyn_bs_false", "model_arch": model_arch} + self._setup(model_args) + + neuron_model, processor = self._load_neuron_model_and_processor(model_arch, "_dyn_bs_false") + + pipe = pipeline( + "automatic-speech-recognition", + model=neuron_model, + feature_extractor=processor.feature_extractor, + tokenizer=processor.tokenizer, + ) + dataset = load_dataset( + "hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True + ) + dataset = dataset.sort("id") + _ = pipe(dataset[0]["audio"]["array"]) + + gc.collect() + + +@is_inferentia_test +class NeuronModelForAudioClassificationIntegrationTest(NeuronModelTestMixin): + NEURON_MODEL_CLASS = NeuronModelForAudioClassification + TASK = "audio-classification" + STATIC_INPUTS_SHAPES = {"batch_size": 1, "audio_sequence_length": 100000} + if is_neuron_available(): + ATOL_FOR_VALIDATION = 1e-3 + SUPPORTED_ARCHITECTURES = [] + elif is_neuronx_available(): + ATOL_FOR_VALIDATION = 1e-3 + SUPPORTED_ARCHITECTURES = ["wav2vec2"] + else: + ATOL_FOR_VALIDATION = 1e-5 + SUPPORTED_ARCHITECTURES = [] + + def _load_neuron_model_and_processor(self, model_arch, suffix): + model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] + neuron_model = NeuronModelForAudioClassification.from_pretrained(self.neuron_model_dirs[model_arch + suffix]) + preprocessor = AutoProcessor.from_pretrained(model_id) + self.assertIsInstance(neuron_model.model, torch.jit._script.ScriptModule) + self.assertIsInstance(neuron_model.config, PretrainedConfig) + return neuron_model, preprocessor + + def _load_transformers_model(self, model_arch): + model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] + set_seed(SEED) + transformers_model = AutoModelForAudioClassification.from_pretrained(model_id) + return transformers_model + + def _prepare_inputs(self, processor, batch_size=1): + dataset = load_dataset( + "hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True + ) + dataset = dataset.sort("id") + sampling_rate = dataset.features["audio"].sampling_rate + inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") + if batch_size > 1: + for name, tensor in inputs.items(): + inputs[name] = torch.cat(batch_size * [tensor]) + return inputs + + def _validate_outputs(self, model_arch, suffix): + neuron_model, preprocessor = self._load_neuron_model_and_processor(model_arch, suffix) + inputs = self._prepare_inputs(preprocessor) + neuron_outputs = neuron_model(**inputs) + self.assertIn("logits", neuron_outputs) + self.assertIsInstance(neuron_outputs.logits, torch.Tensor) + + @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) + @requires_neuronx + def test_compare_to_transformers_dyn_bs(self, model_arch): + # Neuron 
model with dynamic batching + model_args = { + "test_name": model_arch + "_dyn_bs_true", + "model_arch": model_arch, + "dynamic_batch_size": True, + } + self._setup(model_args) + self._validate_outputs(model_arch, "_dyn_bs_true") + + gc.collect() + + @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) + def test_compare_to_transformers_non_dyn_bs(self, model_arch): + model_args = { + "test_name": model_arch + "_dyn_bs_false", + "model_arch": model_arch, + "dynamic_batch_size": False, + } + self._setup(model_args) + self._validate_outputs(model_arch, "_dyn_bs_false") + + gc.collect() + + def test_non_dyn_bs_neuron_model_on_false_batch_size(self): + model_arch = "wav2vec2" + model_args = { + "test_name": model_arch + "_dyn_bs_false", + "model_arch": model_arch, + "dynamic_batch_size": False, + } + self._setup(model_args) + neuron_model, preprocessor = self._load_neuron_model_and_processor(model_arch, "_dyn_bs_false") + inputs = self._prepare_inputs(preprocessor, batch_size=2) + + with self.assertRaises(Exception) as context: + _ = neuron_model(**inputs) + + self.assertIn("set `dynamic_batch_size=True` during the compilation", str(context.exception)) + + def test_pipeline_model(self): + model_arch = "wav2vec2" + model_args = {"test_name": model_arch + "_dyn_bs_false", "model_arch": model_arch} + self._setup(model_args) + + neuron_model, processor = self._load_neuron_model_and_processor(model_arch, "_dyn_bs_false") + + pipe = pipeline( + "audio-classification", + model=neuron_model, + feature_extractor=processor.feature_extractor, + tokenizer=processor.tokenizer, + ) + dataset = load_dataset( + "hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True + ) + dataset = dataset.sort("id") + _ = pipe(dataset[0]["audio"]["array"]) + + gc.collect() + + +@is_inferentia_test +class NeuronModelForAudioFrameClassificationIntegrationTest(NeuronModelTestMixin): + NEURON_MODEL_CLASS = NeuronModelForAudioFrameClassification + TASK = "audio-frame-classification" + STATIC_INPUTS_SHAPES = {"batch_size": 1, "audio_sequence_length": 100000} + if is_neuron_available(): + ATOL_FOR_VALIDATION = 1e-3 + SUPPORTED_ARCHITECTURES = [] + elif is_neuronx_available(): + ATOL_FOR_VALIDATION = 1e-3 + SUPPORTED_ARCHITECTURES = ["wav2vec2"] + else: + ATOL_FOR_VALIDATION = 1e-5 + SUPPORTED_ARCHITECTURES = [] + + def _load_neuron_model_and_processor(self, model_arch, suffix): + model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] + neuron_model = NeuronModelForAudioFrameClassification.from_pretrained( + self.neuron_model_dirs[model_arch + suffix] + ) + preprocessor = AutoFeatureExtractor.from_pretrained(model_id) + self.assertIsInstance(neuron_model.model, torch.jit._script.ScriptModule) + self.assertIsInstance(neuron_model.config, PretrainedConfig) + return neuron_model, preprocessor + + def _load_transformers_model(self, model_arch): + model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] + set_seed(SEED) + transformers_model = NeuronModelForAudioFrameClassification.from_pretrained(model_id) + return transformers_model + + def _prepare_inputs(self, processor, batch_size=1): + dataset = load_dataset( + "hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True + ) + dataset = dataset.sort("id") + sampling_rate = dataset.features["audio"].sampling_rate + inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, 
return_tensors="pt") + if batch_size > 1: + for name, tensor in inputs.items(): + inputs[name] = torch.cat(batch_size * [tensor]) + return inputs + + def _validate_outputs(self, model_arch, suffix): + neuron_model, preprocessor = self._load_neuron_model_and_processor(model_arch, suffix) + inputs = self._prepare_inputs(preprocessor) + neuron_outputs = neuron_model(**inputs) + self.assertIn("logits", neuron_outputs) + self.assertIsInstance(neuron_outputs.logits, torch.Tensor) + + @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) + @requires_neuronx + def test_compare_to_transformers_dyn_bs(self, model_arch): + # Neuron model with dynamic batching + model_args = { + "test_name": model_arch + "_dyn_bs_true", + "model_arch": model_arch, + "dynamic_batch_size": True, + } + self._setup(model_args) + self._validate_outputs(model_arch, "_dyn_bs_true") + + gc.collect() + + @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) + def test_compare_to_transformers_non_dyn_bs(self, model_arch): + model_args = { + "test_name": model_arch + "_dyn_bs_false", + "model_arch": model_arch, + "dynamic_batch_size": False, + } + self._setup(model_args) + self._validate_outputs(model_arch, "_dyn_bs_false") + + gc.collect() + + def test_non_dyn_bs_neuron_model_on_false_batch_size(self): + model_arch = "wav2vec2" + model_args = { + "test_name": model_arch + "_dyn_bs_false", + "model_arch": model_arch, + "dynamic_batch_size": False, + } + self._setup(model_args) + neuron_model, preprocessor = self._load_neuron_model_and_processor(model_arch, "_dyn_bs_false") + inputs = self._prepare_inputs(preprocessor, batch_size=2) + + with self.assertRaises(Exception) as context: + _ = neuron_model(**inputs) + + self.assertIn("set `dynamic_batch_size=True` during the compilation", str(context.exception)) + + +@is_inferentia_test +class NeuronModelForXVectorIntegrationTest(NeuronModelTestMixin): + NEURON_MODEL_CLASS = NeuronModelForXVector + TASK = "audio-xvector" + STATIC_INPUTS_SHAPES = {"batch_size": 1, "audio_sequence_length": 100000} + if is_neuron_available(): + ATOL_FOR_VALIDATION = 1e-3 + SUPPORTED_ARCHITECTURES = [] + elif is_neuronx_available(): + ATOL_FOR_VALIDATION = 1e-3 + SUPPORTED_ARCHITECTURES = ["wav2vec2"] + else: + ATOL_FOR_VALIDATION = 1e-5 + SUPPORTED_ARCHITECTURES = [] + + def _load_neuron_model_and_processor(self, model_arch, suffix): + model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] + neuron_model = NeuronModelForXVector.from_pretrained(self.neuron_model_dirs[model_arch + suffix]) + preprocessor = AutoFeatureExtractor.from_pretrained(model_id) + self.assertIsInstance(neuron_model.model, torch.jit._script.ScriptModule) + self.assertIsInstance(neuron_model.config, PretrainedConfig) + return neuron_model, preprocessor + + def _load_transformers_model(self, model_arch): + model_id = self.ARCH_MODEL_MAP[model_arch] if model_arch in self.ARCH_MODEL_MAP else MODEL_NAMES[model_arch] + set_seed(SEED) + transformers_model = NeuronModelForXVector.from_pretrained(model_id) + return transformers_model + + def _prepare_inputs(self, processor, batch_size=1): + dataset = load_dataset( + "hf-internal-testing/librispeech_asr_demo", "clean", split="validation", trust_remote_code=True + ) + dataset = dataset.sort("id") + sampling_rate = dataset.features["audio"].sampling_rate + inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt") + if batch_size > 1: + for name, tensor in inputs.items(): + 
inputs[name] = torch.cat(batch_size * [tensor]) + return inputs + + def _validate_outputs(self, model_arch, suffix): + neuron_model, preprocessor = self._load_neuron_model_and_processor(model_arch, suffix) + inputs = self._prepare_inputs(preprocessor) + neuron_outputs = neuron_model(**inputs) + self.assertIn("logits", neuron_outputs) + self.assertIsInstance(neuron_outputs.logits, torch.Tensor) + + @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) + @requires_neuronx + def test_compare_to_transformers_dyn_bs(self, model_arch): + # Neuron model with dynamic batching + model_args = { + "test_name": model_arch + "_dyn_bs_true", + "model_arch": model_arch, + "dynamic_batch_size": True, + } + self._setup(model_args) + self._validate_outputs(model_arch, "_dyn_bs_true") + + gc.collect() + + @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) + def test_compare_to_transformers_non_dyn_bs(self, model_arch): + model_args = { + "test_name": model_arch + "_dyn_bs_false", + "model_arch": model_arch, + "dynamic_batch_size": False, + } + self._setup(model_args) + self._validate_outputs(model_arch, "_dyn_bs_false") + + gc.collect() + + def test_non_dyn_bs_neuron_model_on_false_batch_size(self): + model_arch = "wav2vec2" + model_args = { + "test_name": model_arch + "_dyn_bs_false", + "model_arch": model_arch, + "dynamic_batch_size": False, + } + self._setup(model_args) + neuron_model, preprocessor = self._load_neuron_model_and_processor(model_arch, "_dyn_bs_false") + inputs = self._prepare_inputs(preprocessor, batch_size=2) + + with self.assertRaises(Exception) as context: + _ = neuron_model(**inputs) + + self.assertIn("set `dynamic_batch_size=True` during the compilation", str(context.exception)) diff --git a/tests/inference/test_stable_diffusion_pipeline.py b/tests/inference/test_stable_diffusion_pipeline.py index 3d73ed38d..406cdb68c 100644 --- a/tests/inference/test_stable_diffusion_pipeline.py +++ b/tests/inference/test_stable_diffusion_pipeline.py @@ -197,6 +197,7 @@ def test_compatibility_with_compel(self, model_arch): pipe = self.NEURON_MODEL_CLASS.from_pretrained( MODEL_NAMES[model_arch], export=True, + disable_neuron_cache=True, inline_weights_to_neff=True, output_hidden_states=True, **input_shapes,
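
For reference, a minimal end-to-end sketch of the audio support added above (not part of the diff itself): it assumes an Inferentia2/Trainium instance with `optimum-neuron` installed, uses the `facebook/wav2vec2-large-960h-lv60-self` checkpoint that the new pipeline entries register as default, and reuses the static shapes (`batch_size=1`, `audio_sequence_length=100000`) from the new tests; the `export=True` tracing path and the shape keyword names are assumed to follow the existing traced-model API.

```python
# Hypothetical usage sketch for the new NeuronModelForCTC class (illustration only, not taken from the diff).
import torch
from datasets import load_dataset
from transformers import AutoProcessor

from optimum.neuron import NeuronModelForCTC, pipeline

model_id = "facebook/wav2vec2-large-960h-lv60-self"

# Trace the model with the static shapes introduced by this PR
# (mirrors STATIC_INPUTS_SHAPES in the new tests).
model = NeuronModelForCTC.from_pretrained(
    model_id,
    export=True,
    batch_size=1,
    audio_sequence_length=100000,
)
processor = AutoProcessor.from_pretrained(model_id)

# Run inference on a LibriSpeech sample, as in the new docstring examples.
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation").sort("id")
sample = dataset[0]["audio"]
inputs = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt")
logits = model(**inputs).logits
transcription = processor.batch_decode(torch.argmax(logits, dim=-1))

# The same traced model also plugs into the newly registered ASR pipeline entry.
asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    feature_extractor=processor.feature_extractor,
    tokenizer=processor.tokenizer,
)
print(asr(sample["array"]))
```

The audio-classification, audio-frame-classification, and audio-xvector classes follow the same pattern, differing only in the head-specific outputs (`logits`, and additionally `embeddings` for `NeuronModelForXVector`).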