test: main_export with model object and custom ops
hann-wang committed Nov 22, 2024
1 parent c10f763 commit b4aa01a
Showing 2 changed files with 153 additions and 10 deletions.
12 changes: 6 additions & 6 deletions optimum/exporters/onnx/__main__.py
@@ -58,8 +58,8 @@


 def main_export(
-    model_name_or_path_or_obj: Union[str, "PreTrainedModel",
-                                     "TFPreTrainedModel", "DiffusionPipeline"],
+    model_name_or_path: Union[str, "PreTrainedModel", "TFPreTrainedModel",
+                              "DiffusionPipeline"],
     output: Union[str, Path],
     task: str = "auto",
     opset: Optional[int] = None,
@@ -103,7 +103,7 @@ def main_export(
     Args:
         > Required parameters
-        model_name_or_path_or_obj (`Union[str, "PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]`):
+        model_name_or_path (`Union[str, "PreTrainedModel", "TFPreTrainedModel", "DiffusionPipeline"]`):
             Model ID on huggingface.co or path on disk to the model repository to export. Example: `model_name_or_path="BAAI/bge-m3"` or `model_name_or_path="/path/to/model_folder"`.
             It is also possible to pass a model object to skip getting models from the export task.
         output (`Union[str, Path]`):
@@ -252,11 +252,11 @@ def main_export(
             "Please use one of the following tasks instead: `text-to-image`, `image-to-image`, `inpainting`."
         )

-    if isinstance(model_name_or_path_or_obj, str):
+    if isinstance(model_name_or_path, str):
         model = None
-        model_name_or_path = model_name_or_path_or_obj
+        model_name_or_path = model_name_or_path
     else:
-        model = model_name_or_path_or_obj
+        model = model_name_or_path
         model_name_or_path = model.config._name_or_path

     if providers is None:
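With this change, `main_export` keeps the original `model_name_or_path` keyword but accepts either a model ID / local path or an already-instantiated model object. A minimal usage sketch of the two call styles (the output directory name is illustrative):

    from transformers import AutoModel
    from optimum.exporters.onnx import main_export

    # Pass a Hub ID or local folder: the model is loaded inside main_export.
    main_export(model_name_or_path="BAAI/bge-m3", output="onnx_out")

    # Pass a model object: loading is skipped, and the original checkpoint
    # name is recovered from model.config._name_or_path.
    model = AutoModel.from_pretrained("BAAI/bge-m3")
    main_export(model_name_or_path=model, output="onnx_out")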
151 changes: 147 additions & 4 deletions tests/exporters/onnx/test_exporters_onnx_cli.py
@@ -15,14 +15,16 @@
 import os
 import subprocess
 import unittest
+from unittest import mock
+import itertools
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Dict, Optional
+from typing import Dict, Optional, Union, TYPE_CHECKING, List, Callable

 import onnx
 import pytest
 from parameterized import parameterized
-from transformers import AutoModelForSequenceClassification, AutoTokenizer, is_torch_available
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, is_torch_available, AutoModel
 from transformers.testing_utils import require_torch, require_torch_gpu, require_vision, slow

 from optimum.exporters.error_utils import MinimumVersionError
Expand All @@ -33,12 +35,16 @@
ONNX_DECODER_WITH_PAST_NAME,
ONNX_ENCODER_NAME,
)
from optimum.utils.import_utils import is_onnxruntime_available
from optimum.utils.testing_utils import grid_parameters, require_diffusers, require_sentence_transformers, require_timm


if is_torch_available():
from optimum.exporters.tasks import TasksManager

if is_onnxruntime_available():
import onnxruntime as ort

from ..exporters_utils import (
NO_DYNAMIC_AXES_EXPORT_SHAPES_TRANSFORMERS,
PYTORCH_DIFFUSION_MODEL,
@@ -49,6 +49,15 @@
     PYTORCH_TRANSFORMERS_MODEL_NO_DYNAMIC_AXES,
 )

+if TYPE_CHECKING:
+    from optimum.utils.import_utils import is_diffusers_available
+    from transformers import is_tf_available
+    from transformers.modeling_utils import PreTrainedModel
+    if is_diffusers_available():
+        from diffusers import DiffusionPipeline
+    if is_tf_available():
+        from transformers.modeling_tf_utils import TFPreTrainedModel
+

 def _get_models_to_test(export_models_dict: Dict, library_name: str):
     models_to_test = []
@@ -174,7 +189,8 @@ class OnnxCLIExportTestCase(unittest.TestCase):

     def _onnx_export(
         self,
-        model_name: str,
+        model_name: Union[str, "PreTrainedModel", "TFPreTrainedModel",
+                          "DiffusionPipeline"],
         task: str,
         monolith: bool = False,
         no_post_process: bool = False,
@@ -184,6 +200,11 @@ def _onnx_export(
         variant: str = "default",
         no_dynamic_axes: bool = False,
         model_kwargs: Optional[Dict] = None,
+        do_validation: bool = True,
+        disable_dynamic_axes_fix: bool = False,
+        custom_export_fn: Optional[Callable[..., None]] = None,
+        providers: Optional[List[str]] = None,
+        session_options: Optional["ort.SessionOptions"] = None,
     ):
         # We need to set this to some value to be able to test the outputs values for batch size > 1.
         if task == "text-classification":
@@ -206,13 +227,19 @@ def _onnx_export(
                 no_dynamic_axes=no_dynamic_axes,
                 pad_token_id=pad_token_id,
                 model_kwargs=model_kwargs,
+                do_validation=do_validation,
+                disable_dynamic_axes_fix=disable_dynamic_axes_fix,
+                custom_export_fn=custom_export_fn,
+                providers=providers,
+                session_options=session_options,
             )
         except MinimumVersionError as e:
             pytest.skip(f"Skipping due to minimum version requirements not met. Full error: {e}")

     def _onnx_export_no_dynamic_axes(
         self,
-        model_name: str,
+        model_name: Union[str, "PreTrainedModel", "TFPreTrainedModel",
+                          "DiffusionPipeline"],
         task: str,
         input_shape: dict,
         input_shape_for_validation: tuple,
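The `custom_export_fn` hook is exercised in the test below with `mock.MagicMock(side_effect=pytorch_export)`, which suggests `main_export` invokes it in place of `torch.onnx.export`. Under that assumption, a hand-written replacement might be sketched as follows (the function name and logging are hypothetical):

    from torch.onnx import export as pytorch_export

    def verbose_export(*args, **kwargs):
        # Hypothetical drop-in for torch.onnx.export: record the requested
        # opset, then delegate to the stock exporter, mirroring what the
        # MagicMock side_effect does in the test below.
        print("exporting with opset", kwargs.get("opset_version"))
        pytorch_export(*args, **kwargs)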
@@ -739,3 +766,119 @@ def test_complex_synonyms(self):
             model.save_pretrained(tmpdir_in)

             main_export(model_name_or_path=tmpdir_in, output=tmpdir_out, task="text-classification")
+
+    @parameterized.expand(itertools.product([False, True], ["cuda", "cpu"]))
+    @require_vision
+    @require_torch_gpu
+    @slow
+    @pytest.mark.run_slow
+    def test_customized_export(
+        self,
+        use_custom_op: bool,
+        device: str,
+    ):
+        import torch.version
+        from torch import nn
+        from torch.autograd import Function
+        from torch.onnx import export as pytorch_export, symbolic_helper
+
+        class CustomActivationFunc(Function):
+            @staticmethod
+            def forward(ctx, input_tensor):
+                return input_tensor
+
+            @staticmethod
+            def backward(ctx, grad_outputs: torch.Tensor):
+                return grad_outputs
+
+            @staticmethod
+            @symbolic_helper.parse_args("v")
+            def symbolic(g, input_tensor):
+                ret = g.op('CustomDomain::CustomActivation', input_tensor)
+                ret.setType(input_tensor.type())
+                return ret
+
+        class CustomActivation(nn.Module):
+            def forward(self, input_tensor):
+                return CustomActivationFunc.apply(input_tensor)
+
+        def replace_activation(model: nn.Module):
+            if hasattr(model, "intermediate_act_fn"):
+                setattr(model, "intermediate_act_fn", CustomActivation())
+            for child in model.children():
+                replace_activation(child)
+
+        test_name, model_type, model_name, task, variant, monolith, no_post_process = _get_models_to_test(
+            {
+                "beit":
+                "hf-internal-testing/tiny-random-BeitForImageClassification"
+            },
+            library_name="transformers")[0]
+
+        if device.startswith("cuda"):
+            if torch.version.hip:
+                providers = ["ROCMExecutionProvider"]
+            else:
+                providers = ["CUDAExecutionProvider"]
+        else:
+            providers = ["CPUExecutionProvider"]
+
+        if is_onnxruntime_available():
+            do_validation = True
+            so = ort.SessionOptions()
+        else:
+            do_validation = False
+            so = None
+
+        custom_export_mock = mock.MagicMock(side_effect=pytorch_export)
+
+        model = AutoModel.from_pretrained(model_name).to(device)
+        TasksManager.standardize_model_attributes(model)
+
+        if use_custom_op:
+            replace_activation(model)
+            if do_validation:
+                with pytest.raises(Exception):
+                    # this one will fail because no custom ops are registered in onnxruntime
+                    self._onnx_export(
+                        model,
+                        task,
+                        monolith,
+                        no_post_process,
+                        variant=variant,
+                        device=model.device,
+                        disable_dynamic_axes_fix=not do_validation,
+                        do_validation=do_validation,
+                        custom_export_fn=custom_export_mock,
+                        providers=providers,
+                        session_options=so,
+                    )
+            self._onnx_export(
+                model,
+                task,
+                monolith,
+                no_post_process,
+                variant=variant,
+                device=model.device,
+                disable_dynamic_axes_fix=True,
+                do_validation=False,
+                custom_export_fn=custom_export_mock,
+                providers=providers,
+                session_options=so,
+            )
+        else:
+            self._onnx_export(
+                model,
+                task,
+                monolith,
+                no_post_process,
+                variant=variant,
+                device=model.device,
+                disable_dynamic_axes_fix=not do_validation,
+                do_validation=do_validation,
+                custom_export_fn=custom_export_mock,
+                providers=providers,
+                session_options=so,
+            )
+
+        custom_export_mock.assert_called()
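In the custom-op branch of the test above, the first `_onnx_export` call is expected to raise because ONNX Runtime has no kernel registered for `CustomDomain::CustomActivation`, so validation is skipped for the real export. Validation could only succeed after registering a compiled custom-op library on the session options; a sketch, assuming such a library exists at a hypothetical path:

    import onnxruntime as ort

    so = ort.SessionOptions()
    # Hypothetical shared library implementing CustomDomain::CustomActivation;
    # building it (e.g. against the ONNX Runtime custom-op C API) is out of
    # scope for this test, which is why it disables validation instead.
    so.register_custom_ops_library("/path/to/libcustom_activation.so")

    session = ort.InferenceSession("model.onnx", sess_options=so,
                                   providers=["CPUExecutionProvider"])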
