Add more audio models: ast, hubert, unispeech, unispeech-sat, wavlm (#651)

* wav2vec2 base support

* fix outputs for audio-xvector

* add CTC modeling

* some tests and modeling

* add xvector

* fix doc

* fix doc

* try fix tests

* disable auto triggered CIs for inf1

* add ast, hubert, wav2vec, wavlm

* tests for unispeech

* remove debug

* change comment
JingyaHuang authored Jul 14, 2024
1 parent 56cb8a5 commit 3d88fac
Showing 7 changed files with 188 additions and 8 deletions.
7 changes: 6 additions & 1 deletion docs/source/package_reference/supported_models.mdx
@@ -21,6 +21,7 @@ limitations under the License.
| Architecture | Task |
|---------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------|
| ALBERT | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| AST | feature-extraction, audio-classification |
| BERT | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| BLOOM | text-generation |
| Beit | feature-extraction, image-classification |
@@ -39,6 +40,7 @@ limitations under the License.
| ESM | feature-extraction, fill-mask, text-classification, token-classification |
| FlauBERT | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| GPT2 | text-generation |
| Hubert | feature-extraction, automatic-speech-recognition, audio-classification |
| Levit | feature-extraction, image-classification |
| Llama, Llama 2, Llama 3 | text-generation |
| Mistral | text-generation |
@@ -53,9 +55,12 @@ limitations under the License.
| RoFormer | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| Swin | feature-extraction, image-classification |
| T5 | text2text-generation |
| UniSpeech | feature-extraction, automatic-speech-recognition, audio-classification |
| UniSpeech-SAT | feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector |
| ViT | feature-extraction, image-classification |
| Wav2Vec2 | feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector |
| WavLM | feature-extraction, automatic-speech-recognition, audio-classification, audio-frame-classification, audio-xvector |
| XLM | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| ViT | feature-extraction, image-classification |
| XLM-RoBERTa | feature-extraction, fill-mask, multiple-choice, question-answering, text-classification, token-classification |
| Yolos | feature-extraction, object-detection |
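The newly listed audio architectures are exported and loaded the same way as the existing Wav2Vec2 support. A minimal sketch follows (not part of this commit), assuming the `NeuronModelForAudioClassification` class in `optimum.neuron` and an `audio_sequence_length` compile-time shape argument; both names are inferred from the pre-existing audio support rather than shown in this diff.

```python
# Hypothetical usage sketch: compile a Hubert audio-classification checkpoint for Neuron.
# The class name and the shape keyword arguments below are assumptions, not taken from this commit.
from optimum.neuron import NeuronModelForAudioClassification

model = NeuronModelForAudioClassification.from_pretrained(
    "superb/hubert-base-superb-ks",  # any Hubert checkpoint fine-tuned for audio classification
    export=True,                     # trace and compile on the fly
    batch_size=1,                    # static batch size baked into the compiled graph
    audio_sequence_length=16000,     # static raw-waveform length (1 s at 16 kHz)
)
```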

140 changes: 139 additions & 1 deletion optimum/exporters/neuron/model_configs.py
@@ -19,7 +19,12 @@

import torch

from ...neuron.utils import DummyBeamValuesGenerator, DummyControNetInputGenerator, DummyMaskedPosGenerator
from ...neuron.utils import (
    ASTDummyAudioInputGenerator,
    DummyBeamValuesGenerator,
    DummyControNetInputGenerator,
    DummyMaskedPosGenerator,
)
from ...utils import (
    DummyInputGenerator,
    DummySeq2SeqDecoderTextInputGenerator,
@@ -423,11 +428,144 @@ def inputs(self) -> List[str]:
    @property
    def outputs(self) -> List[str]:
        common_outputs = super().outputs
        if self.task == "feature-extraction":
            common_outputs = ["last_hidden_state", "extract_features"]
        if self.task == "audio-xvector":
            common_outputs.append("embeddings")
        return common_outputs


@register_in_tasks_manager(
    "audio-spectrogram-transformer",
    *[
        "feature-extraction",
        "audio-classification",
    ],
)
class ASTNeuronConfig(AudioNeuronConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        num_mel_bins="num_mel_bins", max_length="max_length", allow_new=True
    )
    DUMMY_INPUT_GENERATOR_CLASSES = (ASTDummyAudioInputGenerator,)

    @property
    def inputs(self) -> List[str]:
        return ["input_values"]


@register_in_tasks_manager(
    "hubert",
    *[
        "feature-extraction",
        "automatic-speech-recognition",
        "audio-classification",
    ],
)
class HubertNeuronConfig(Wav2Vec2NeuronConfig):
    @property
    def outputs(self) -> List[str]:
        common_outputs = super().outputs
        if self.task == "feature-extraction":
            common_outputs = ["last_hidden_state"]
        return common_outputs


# TODO: compilation failed due to a bug in xla: https://github.com/pytorch/xla/issues/6398.
# @register_in_tasks_manager(
#     "sew",
#     *[
#         "feature-extraction",
#         "automatic-speech-recognition",
#         "audio-classification",
#     ],
# )
# class SEWNeuronConfig(Wav2Vec2NeuronConfig):
#     pass


# TODO: compilation failed due to a bug in xla: https://github.com/pytorch/xla/issues/6398.
# @register_in_tasks_manager(
#     "sew-d",
#     *[
#         "feature-extraction",
#         "automatic-speech-recognition",
#         "audio-classification",
#     ],
# )
# class SEWDNeuronConfig(Wav2Vec2NeuronConfig):
#     pass


@register_in_tasks_manager(
    "unispeech",
    *[
        "feature-extraction",
        "automatic-speech-recognition",
        "audio-classification",
    ],
)
class UniSpeechNeuronConfig(Wav2Vec2NeuronConfig):
    pass


@register_in_tasks_manager(
    "unispeech-sat",
    *[
        "feature-extraction",
        "automatic-speech-recognition",
        "audio-classification",
        "audio-frame-classification",
        "audio-xvector",
    ],
)
class UniSpeechSATNeuronConfig(Wav2Vec2NeuronConfig):
    pass


# TODO: compilation failed due to a bug in xla: https://github.com/pytorch/xla/issues/6398.
# @register_in_tasks_manager(
#     "wav2vec2-bert",
#     *[
#         "feature-extraction",
#         "automatic-speech-recognition",
#         "audio-classification",
#         "audio-frame-classification",
#         "audio-xvector",
#     ],
# )
# class Wav2Vec2BertNeuronConfig(Wav2Vec2NeuronConfig):
#     pass


# TODO: compilation failed due to a bug in xla: https://github.com/pytorch/xla/issues/6398.
# @register_in_tasks_manager(
#     "wav2vec2-conformer",
#     *[
#         "feature-extraction",
#         "automatic-speech-recognition",
#         "audio-classification",
#         "audio-frame-classification",
#         "audio-xvector",
#     ],
# )
# class Wav2Vec2ConformerNeuronConfig(Wav2Vec2NeuronConfig):
#     pass


@register_in_tasks_manager(
    "wavlm",
    *[
        "feature-extraction",
        "automatic-speech-recognition",
        "audio-classification",
        "audio-frame-classification",
        "audio-xvector",
    ],
)
class WavLMNeuronConfig(Wav2Vec2NeuronConfig):
    pass
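Taken together, the registrations above make the new audio model types discoverable through the exporter's tasks manager. The sketch below assumes `TasksManager.get_supported_tasks_for_model_type` from `optimum.exporters.tasks` is the lookup these `@register_in_tasks_manager` decorators feed, and that importing the Neuron exporter package triggers the registrations; neither is shown in this diff.

```python
# Illustrative sanity check only: list the Neuron tasks registered for each new audio architecture.
import optimum.exporters.neuron  # noqa: F401  -- assumed to run the registrations above
from optimum.exporters.tasks import TasksManager

for model_type in ("audio-spectrogram-transformer", "hubert", "unispeech", "unispeech-sat", "wavlm"):
    tasks = TasksManager.get_supported_tasks_for_model_type(model_type, exporter="neuron")
    print(model_type, sorted(tasks))
```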


@register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers")
class UNetNeuronConfig(VisionNeuronConfig):
    ATOL_FOR_VALIDATION = 1e-3
14 changes: 12 additions & 2 deletions optimum/neuron/utils/__init__.py
@@ -41,7 +41,12 @@
"is_torch_xla_available",
"is_transformers_neuronx_available",
],
"input_generators": ["DummyBeamValuesGenerator", "DummyMaskedPosGenerator", "DummyControNetInputGenerator"],
"input_generators": [
"DummyBeamValuesGenerator",
"DummyMaskedPosGenerator",
"DummyControNetInputGenerator",
"ASTDummyAudioInputGenerator",
],
"misc": [
"DiffusersPretrainedConfig",
"check_if_weights_replacable",
@@ -93,7 +98,12 @@
        is_torch_xla_available,
        is_transformers_neuronx_available,
    )
    from .input_generators import DummyBeamValuesGenerator, DummyControNetInputGenerator, DummyMaskedPosGenerator
    from .input_generators import (
        ASTDummyAudioInputGenerator,
        DummyBeamValuesGenerator,
        DummyControNetInputGenerator,
        DummyMaskedPosGenerator,
    )
    from .misc import (
        DiffusersPretrainedConfig,
        check_if_weights_replacable,
10 changes: 10 additions & 0 deletions optimum/neuron/utils/input_generators.py
@@ -20,6 +20,7 @@

from ...utils import (
    DTYPE_MAPPER,
    DummyAudioInputGenerator,
    DummyInputGenerator,
    NormalizedTextConfig,
    NormalizedVisionConfig,
@@ -163,3 +164,12 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
                self.width // 2**num_cross_attn_blocks,
            )
            return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)


# copied from https://github.com/huggingface/optimum/blob/171020c775cec6ff77826c3f5f5e5c1498b23f81/optimum/exporters/onnx/model_configs.py#L1363C1-L1368C111
class ASTDummyAudioInputGenerator(DummyAudioInputGenerator):
    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        shape = [self.batch_size, self.normalized_config.max_length, self.normalized_config.num_mel_bins]
        if input_name == "input_values":
            return self.random_float_tensor(shape, min_value=-1, max_value=1, framework=framework, dtype=float_dtype)
        return super().generate(input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype)
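As an illustration of the shape contract above (not part of the commit): with the standard AST configuration, `max_length=1024` and `num_mel_bins=128`, so the dummy `input_values` used for tracing is a batch of 1024×128 log-mel frames with values drawn uniformly from [-1, 1].

```python
# Standalone illustration of what ASTDummyAudioInputGenerator.generate returns for
# "input_values", assuming batch_size=2 and the default AST config values.
import torch

batch_size, max_length, num_mel_bins = 2, 1024, 128
input_values = torch.empty(batch_size, max_length, num_mel_bins).uniform_(-1, 1)
print(input_values.shape)  # torch.Size([2, 1024, 128])
```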
9 changes: 9 additions & 0 deletions tests/exporters/exporters_utils.py
@@ -16,6 +16,7 @@

EXPORT_MODELS_TINY = {
    "albert": "hf-internal-testing/tiny-random-AlbertModel",
    "audio-spectrogram-transformer": "Ericwang/tiny-random-ast",
    "beit": "hf-internal-testing/tiny-random-BeitForImageClassification",
    "bert": "hf-internal-testing/tiny-random-BertModel",
    "camembert": "hf-internal-testing/tiny-random-camembert",
@@ -32,6 +33,7 @@
"electra": "hf-internal-testing/tiny-random-ElectraModel",
"esm": "hf-internal-testing/tiny-random-EsmModel",
"flaubert": "flaubert/flaubert_small_cased",
"hubert": "hf-internal-testing/tiny-random-HubertModel",
"levit": "hf-internal-testing/tiny-random-LevitModel",
"mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
"mobilenet-v2": "hf-internal-testing/tiny-random-MobileNetV2Model",
@@ -40,8 +42,15 @@
"phi": "bumblebee-testing/tiny-random-PhiModel",
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
"roformer": "hf-internal-testing/tiny-random-RoFormerModel",
# "sew": "hf-internal-testing/tiny-random-SEWModel", # blocked
# "sew-d": "hf-internal-testing/tiny-random-SEWDModel", # blocked
"swin": "hf-internal-testing/tiny-random-SwinModel",
"unispeech": "hf-internal-testing/tiny-random-unispeech",
"unispeech-sat": "hf-internal-testing/tiny-random-unispeech-sat",
"vit": "hf-internal-testing/tiny-random-vit",
"wav2vec2": "hf-internal-testing/tiny-random-Wav2Vec2Model",
# "wav2vec2-conformer": "hf-internal-testing/tiny-random-wav2vec2-conformer", # blocked
"wavlm": "hf-internal-testing/tiny-random-wavlm",
"xlm": "hf-internal-testing/tiny-random-XLMModel",
"xlm-roberta": "hf-internal-testing/tiny-xlm-roberta",
"yolos": "hf-internal-testing/tiny-random-YolosModel",
6 changes: 3 additions & 3 deletions tests/exporters/test_export.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import os
import random
import unittest
@@ -129,15 +128,16 @@ def _neuronx_export(
        if library_name == "sentence_transformers":
            model_class = TasksManager.get_model_class_for_task(task, framework="pt", library=library_name)
            model = model_class(model_name)
            reference_model = model_class(model_name)
            if "clip" in model[0].__class__.__name__.lower():
                config = model[0].model.config
            else:
                config = model[0].auto_model.config
        else:
            model_class = TasksManager.get_model_class_for_task(task, framework="pt")
            model_class = TasksManager.get_model_class_for_task(task, model_type=model_type, framework="pt")
            config = AutoConfig.from_pretrained(model_name)
            model = model_class.from_config(config)
            reference_model = copy.deepcopy(model)
            reference_model = model_class.from_config(config)

        mandatory_shapes = {
            name: DEFAULT_DUMMY_SHAPES.get(name) or EXTREA_DEFAULT_DUMMY_SHAPES.get(name)
10 changes: 9 additions & 1 deletion tests/inference/inference_utils.py
@@ -30,6 +30,7 @@

MODEL_NAMES = {
    "albert": "hf-internal-testing/tiny-random-AlbertModel",
    "audio-spectrogram-transformer": "Ericwang/tiny-random-ast",
    "beit": "hf-internal-testing/tiny-random-BeitForImageClassification",
    "bert": "hf-internal-testing/tiny-random-BertModel",
    "camembert": "hf-internal-testing/tiny-random-camembert",
@@ -45,6 +46,7 @@
"dpt": "hf-internal-testing/tiny-random-DPTModel",
"electra": "hf-internal-testing/tiny-random-ElectraModel",
"flaubert": "flaubert/flaubert_small_cased",
"hubert": "hf-internal-testing/tiny-random-HubertModel",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
"latent-consistency": "echarlaix/tiny-random-latent-consistency",
"levit": "hf-internal-testing/tiny-random-LevitModel",
@@ -55,12 +57,18 @@
"phi": "bumblebee-testing/tiny-random-PhiModel",
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
"roformer": "hf-internal-testing/tiny-random-RoFormerModel",
# "sew": "hf-internal-testing/tiny-random-SEWModel", # blocked
# "sew-d": "hf-internal-testing/tiny-random-SEWDModel", # blocked
"swin": "hf-internal-testing/tiny-random-SwinModel",
"vit": "hf-internal-testing/tiny-random-vit",
"stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch",
"stable-diffusion-ip2p": "asntr/tiny-stable-diffusion-pix2pix-torch",
"stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl",
"unispeech": "hf-internal-testing/tiny-random-unispeech",
"unispeech-sat": "hf-internal-testing/tiny-random-unispeech-sat",
"vit": "hf-internal-testing/tiny-random-vit",
"wav2vec2": "hf-internal-testing/tiny-random-Wav2Vec2Model",
# "wav2vec2-conformer": "hf-internal-testing/tiny-random-wav2vec2-conformer", # blocked
"wavlm": "hf-internal-testing/tiny-random-wavlm",
"xlm": "hf-internal-testing/tiny-random-XLMModel",
"xlm-roberta": "hf-internal-testing/tiny-xlm-roberta",
"yolos": "hf-internal-testing/tiny-random-YolosModel",
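The tiny checkpoints above exist only for tests, but they outline the intended inference path for the new audio tasks. A sketch follows, assuming a `NeuronModelForCTC` class exists (the commit message mentions CTC modeling, but neither the class name nor the `audio_sequence_length` argument appears in this diff):

```python
# Hypothetical end-to-end check with one of the tiny ASR checkpoints listed above.
import numpy as np
from transformers import AutoFeatureExtractor
from optimum.neuron import NeuronModelForCTC  # class name assumed, not shown in this diff

model_id = "hf-internal-testing/tiny-random-Wav2Vec2Model"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
model = NeuronModelForCTC.from_pretrained(
    model_id, export=True, batch_size=1, audio_sequence_length=16000
)

waveform = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
inputs = feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
logits = model(**inputs).logits
print(logits.shape)  # (batch, frames, vocab_size)
```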
