Enable SD XL ONNX export and ONNX Runtime inference (huggingface#1168)
* add stable diffusion XL export

* fix style

* fix test model name

* fix style

* remove clip with projection from test

* change model name

* fix style

* remove need to create a PretrainedConfig

* fix style

* fix dummy input generation

* add saving of the second tokenizer when exporting an SD XL model

* fix style

* add SD XL pipeline

* fix style

* add test

* add watermarker

* fix style

* add watermark

* add test

* set default height and width in the stable diffusion pipeline

* enable img2img task

* fix style

* allow having only the second tokenizer and text encoder

* add test

* fix cli export

* adapt test for batch size > 1
echarlaix committed Jul 17, 2023
1 parent c4750f6 commit a9ffe07
Showing 26 changed files with 1,640 additions and 192 deletions.
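
For context, a minimal sketch of the workflow this commit enables. The checkpoint id and the `ORTStableDiffusionXLPipeline` class come from the PR context and should be treated as indicative rather than authoritative:

# Export via the CLI (the SD XL task name at this commit is assumed):
#   optimum-cli export onnx --model stabilityai/stable-diffusion-xl-base-1.0 sdxl_onnx/
#
# Or export on the fly and run ONNX Runtime inference with the pipeline added by this PR:
from optimum.onnxruntime import ORTStableDiffusionXLPipeline

pipeline = ORTStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", export=True
)
image = pipeline(prompt="sailing ship in a storm, oil painting").images[0]
image.save("ship.png")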
76 changes: 40 additions & 36 deletions optimum/exporters/onnx/__main__.py
@@ -23,15 +23,7 @@
from transformers.utils import is_torch_available

from ...commands.export.onnx import parse_args_onnx
from ...utils import (
DEFAULT_DUMMY_SHAPES,
DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
DIFFUSION_MODEL_UNET_SUBFOLDER,
DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
ONNX_WEIGHTS_NAME,
logging,
)
from ...utils import DEFAULT_DUMMY_SHAPES, ONNX_WEIGHTS_NAME, logging
from ...utils.save_utils import maybe_save_preprocessors
from ..error_utils import AtolError, OutputMatchError, ShapeError
from ..tasks import TasksManager
@@ -71,8 +63,9 @@ def _get_submodels_and_onnx_configs(
custom_architecture: bool,
fn_get_submodels: Optional[Callable] = None,
):
is_stable_diffusion = "stable-diffusion" in task
if not custom_architecture:
if task == "stable-diffusion":
if is_stable_diffusion:
onnx_config = None
models_and_onnx_configs = get_stable_diffusion_models_for_export(model)
else:
@@ -104,7 +97,7 @@ def _get_submodels_and_onnx_configs(
if fn_get_submodels is not None:
submodels_for_export = fn_get_submodels(model)
else:
if task == "stable-diffusion":
if is_stable_diffusion:
submodels_for_export = _get_submodels_for_export_stable_diffusion(model)
elif (
model.config.is_encoder_decoder
@@ -312,10 +305,19 @@ def main_export(
)

custom_architecture = False
if task != "stable-diffusion" and model.config.model_type.replace(
"-", "_"
) not in TasksManager.get_supported_model_type_for_task(task, exporter="onnx"):
custom_architecture = True
is_stable_diffusion = "stable-diffusion" in task
model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-")

if not is_stable_diffusion:
if model_type in TasksManager._UNSUPPORTED_CLI_MODEL_TYPE:
raise ValueError(
f"{model_type} is not supported yet. Only {TasksManager._SUPPORTED_CLI_MODEL_TYPE} are supported. "
f"If you want to support {model_type} please propose a PR or open up an issue."
)
if model.config.model_type.replace("-", "_") not in TasksManager.get_supported_model_type_for_task(
task, exporter="onnx"
):
custom_architecture = True

# TODO: support onnx_config.py in the model repo
if custom_architecture and custom_onnx_configs is None:
@@ -330,9 +332,8 @@

if (
not custom_architecture
and task != "stable-diffusion"
and task + "-with-past"
in TasksManager.get_supported_tasks_for_model_type(model.config.model_type.replace("_", "-"), "onnx")
and not is_stable_diffusion
and task + "-with-past" in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx")
):
if original_task == "auto": # Make -with-past the default if --task was not explicitly specified
task = task + "-with-past"
@@ -367,7 +368,7 @@
fn_get_submodels=fn_get_submodels,
)

if task != "stable-diffusion":
if not is_stable_diffusion:
needs_pad_token_id = (
isinstance(onnx_config, OnnxConfigWithPast)
and getattr(model.config, "pad_token_id", None) is None
@@ -391,7 +392,7 @@

if opset < onnx_config.DEFAULT_ONNX_OPSET:
raise ValueError(
f"Opset {opset} is not sufficient to export {model.config.model_type}. "
f"Opset {opset} is not sufficient to export {model_type}. "
f"At least {onnx_config.DEFAULT_ONNX_OPSET} is required."
)
if atol is None:
@@ -415,28 +416,31 @@

onnx_files_subpaths = None
else:
onnx_files_subpaths = [
DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
DIFFUSION_MODEL_UNET_SUBFOLDER,
DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
]

# save the subcomponent configuration
for model_name, name_dir in zip(models_and_onnx_configs, onnx_files_subpaths):
for model_name in models_and_onnx_configs:
subcomponent = models_and_onnx_configs[model_name][0]
if hasattr(subcomponent, "save_config"):
subcomponent.save_config(output / name_dir)
subcomponent.save_config(output / model_name)
elif hasattr(subcomponent, "config") and hasattr(subcomponent.config, "save_pretrained"):
subcomponent.config.save_pretrained(output / name_dir)
subcomponent.config.save_pretrained(output / model_name)

onnx_files_subpaths = [os.path.join(path, ONNX_WEIGHTS_NAME) for path in onnx_files_subpaths]
onnx_files_subpaths = [os.path.join(name_dir, ONNX_WEIGHTS_NAME) for name_dir in models_and_onnx_configs]

# Saving the additional components needed to perform inference.
model.tokenizer.save_pretrained(output.joinpath("tokenizer"))
model.scheduler.save_pretrained(output.joinpath("scheduler"))
if model.feature_extractor is not None:
model.feature_extractor.save_pretrained(output.joinpath("feature_extractor"))

feature_extractor = getattr(model, "feature_extractor", None)
if feature_extractor is not None:
feature_extractor.save_pretrained(output.joinpath("feature_extractor"))

tokenizer = getattr(model, "tokenizer", None)
if tokenizer is not None:
tokenizer.save_pretrained(output.joinpath("tokenizer"))

tokenizer_2 = getattr(model, "tokenizer_2", None)
if tokenizer_2 is not None:
tokenizer_2.save_pretrained(output.joinpath("tokenizer_2"))

model.save_config(output)

_, onnx_outputs = export_models(
@@ -464,7 +468,7 @@ def main_export(

# Optionally post process the obtained ONNX file(s), for example to merge the decoder / decoder with past if any
# TODO: treating stable diffusion separately is quite ugly
if not no_post_process and task != "stable-diffusion":
if not no_post_process and not is_stable_diffusion:
try:
logger.info("Post-processing the exported models...")
models_and_onnx_configs, onnx_files_subpaths = onnx_config.post_process_exported_models(
@@ -475,7 +479,7 @@
f"The post-processing of the ONNX export failed. The export can still be performed by passing the option --no-post-process. Detailed error: {e}"
)

if task == "stable-diffusion":
if is_stable_diffusion:
use_subprocess = (
False # TODO: fix Can't pickle local object 'get_stable_diffusion_models_for_export.<locals>.<lambda>'
)
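With the changes above, each subcomponent's configuration and weights are saved under a subfolder named after its key in `models_and_onnx_configs` (replacing the hardcoded subfolder list), and the second tokenizer is saved when present. For an SD XL export the output directory should therefore look roughly as follows (layout inferred from this diff; `model.onnx` is the value of `ONNX_WEIGHTS_NAME`):

sdxl_onnx/
├── model_index.json            # written by model.save_config(output)
├── scheduler/
├── tokenizer/                  # saved only when pipeline.tokenizer is set
├── tokenizer_2/                # second tokenizer, new in this commit
├── text_encoder/model.onnx
├── text_encoder_2/model.onnx   # CLIP text model with projection, new in this commit
├── unet/model.onnx
├── vae_encoder/model.onnx
└── vae_decoder/model.onnx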
2 changes: 2 additions & 0 deletions optimum/exporters/onnx/convert.py
@@ -369,6 +369,8 @@ def _run_validation(
if isinstance(value, (list, tuple)):
value = config.flatten_output_collection_property(name, value)
onnx_inputs.update({tensor_name: pt_tensor.cpu().numpy() for tensor_name, pt_tensor in value.items()})
elif isinstance(value, dict):
onnx_inputs.update({tensor_name: pt_tensor.cpu().numpy() for tensor_name, pt_tensor in value.items()})
else:
onnx_inputs[name] = value.cpu().numpy()

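The added `elif` matters for SD XL validation: the UNet's reference inputs now contain a nested dict (`added_cond_kwargs`, see model_configs.py below) whose tensors must become flat, named ONNX Runtime inputs. A self-contained sketch of that flattening, with hypothetical tensor shapes:

import torch

reference_inputs = {
    "sample": torch.randn(2, 4, 64, 64),
    "added_cond_kwargs": {  # nested dict, as produced by UNetOnnxConfig.generate_dummy_inputs
        "text_embeds": torch.randn(2, 1280),  # pooled text embedding (shape is an assumption)
        "time_ids": torch.randn(2, 6),        # size/crop conditioning ids (shape is an assumption)
    },
}

onnx_inputs = {}
for name, value in reference_inputs.items():
    if isinstance(value, dict):
        # Flatten nested dicts into top-level ONNX inputs, mirroring the added branch above.
        onnx_inputs.update({k: v.cpu().numpy() for k, v in value.items()})
    else:
        onnx_inputs[name] = value.cpu().numpy()

print(sorted(onnx_inputs))  # ['sample', 'text_embeds', 'time_ids']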
51 changes: 48 additions & 3 deletions optimum/exporters/onnx/model_configs.py
@@ -658,14 +658,15 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
}


class CLIPTextOnnxConfig(TextEncoderOnnxConfig):
class CLIPTextWithProjectionOnnxConfig(TextEncoderOnnxConfig):
ATOL_FOR_VALIDATION = 1e-3
# The ONNX export of this architecture needs the Trilu operator support, available since opset 14
DEFAULT_ONNX_OPSET = 14

NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
vocab_size="vocab_size",
sequence_length="max_position_embeddings",
num_layers="num_hidden_layers",
allow_new=True,
)

@@ -677,13 +678,33 @@ def inputs(self) -> Dict[str, Dict[int, str]]:

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
return {
common_outputs = {
"text_embeds": {0: "batch_size", 1: "sequence_length"},
"last_hidden_state": {0: "batch_size", 1: "sequence_length"},
}
if self._normalized_config.output_hidden_states:
for i in range(self._normalized_config.num_layers + 1):
common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}

return common_outputs


class CLIPTextOnnxConfig(CLIPTextWithProjectionOnnxConfig):
@property
def outputs(self) -> Dict[str, Dict[int, str]]:
common_outputs = {
"last_hidden_state": {0: "batch_size", 1: "sequence_length"},
"pooler_output": {0: "batch_size"},
}
if self._normalized_config.output_hidden_states:
for i in range(self._normalized_config.num_layers + 1):
common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}

return common_outputs

def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)

if framework == "pt":
import torch

@@ -713,12 +734,19 @@ class UNetOnnxConfig(VisionOnnxConfig):

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
return {
common_inputs = {
"sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
"timestep": {0: "steps"},
"encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
}

# TODO: add text_image, image and image_embeds
if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
common_inputs["text_embeds"] = {0: "batch_size"}
common_inputs["time_ids"] = {0: "batch_size"}

return common_inputs

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
return {
@@ -734,8 +762,25 @@ def torch_to_onnx_output_map(self) -> Dict[str, str]:
def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0]

if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
dummy_inputs["added_cond_kwargs"] = {
"text_embeds": dummy_inputs.pop("text_embeds"),
"time_ids": dummy_inputs.pop("time_ids"),
}

return dummy_inputs

def ordered_inputs(self, model) -> Dict[str, Dict[int, str]]:
inputs = super().ordered_inputs(model=model)
# to fix the mismatch between the model forward signature and the expected inputs:
# a dictionary of additional embeddings `added_cond_kwargs` is expected depending on config.addition_embed_type
if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
inputs["text_embeds"] = self.inputs["text_embeds"]
inputs["time_ids"] = self.inputs["time_ids"]

return inputs


class VaeEncoderOnnxConfig(VisionOnnxConfig):
ATOL_FOR_VALIDATION = 1e-2
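Because the SD XL export switches the text encoders to `output_hidden_states=True` (see utils.py below), both CLIP configs now declare one `hidden_states.{i}` output per transformer layer plus one for the embedding output. A small sketch of the dynamic-axes mapping the new `outputs` property builds, using a hypothetical 2-layer config:

num_layers = 2  # hypothetical; in practice taken from the normalized config's num_hidden_layers

common_outputs = {
    "text_embeds": {0: "batch_size", 1: "sequence_length"},
    "last_hidden_state": {0: "batch_size", 1: "sequence_length"},
}
# One entry per layer, plus one for the embedding-layer output, as in the diff above.
for i in range(num_layers + 1):
    common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"}

print(list(common_outputs))
# ['text_embeds', 'last_hidden_state', 'hidden_states.0', 'hidden_states.1', 'hidden_states.2']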
38 changes: 32 additions & 6 deletions optimum/exporters/onnx/utils.py
@@ -100,14 +100,24 @@ def _get_submodels_for_export_stable_diffusion(
"""
Returns the components of a Stable Diffusion model.
"""
from diffusers import StableDiffusionXLPipeline

models_for_export = {}
if isinstance(pipeline, StableDiffusionXLPipeline):
projection_dim = pipeline.text_encoder_2.config.projection_dim
else:
projection_dim = pipeline.text_encoder.config.projection_dim

# Text encoder
models_for_export["text_encoder"] = pipeline.text_encoder
if pipeline.text_encoder is not None:
if isinstance(pipeline, StableDiffusionXLPipeline):
pipeline.text_encoder.config.output_hidden_states = True
models_for_export["text_encoder"] = pipeline.text_encoder

# U-NET
# PyTorch does not support the ONNX export of torch.nn.functional.scaled_dot_product_attention
pipeline.unet.set_attn_processor(AttnProcessor())
pipeline.unet.config.text_encoder_projection_dim = projection_dim
models_for_export["unet"] = pipeline.unet

# VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565
@@ -124,6 +134,11 @@
vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample)
models_for_export["vae_decoder"] = vae_decoder

text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
if text_encoder_2 is not None:
text_encoder_2.config.output_hidden_states = True
models_for_export["text_encoder_2"] = text_encoder_2

return models_for_export


@@ -249,11 +264,12 @@ def get_stable_diffusion_models_for_export(
models_for_export = _get_submodels_for_export_stable_diffusion(pipeline)

# Text encoder
text_encoder_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.text_encoder, exporter="onnx", task="feature-extraction"
)
text_encoder_onnx_config = text_encoder_config_constructor(pipeline.text_encoder.config)
models_for_export["text_encoder"] = (models_for_export["text_encoder"], text_encoder_onnx_config)
if "text_encoder" in models_for_export:
text_encoder_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.text_encoder, exporter="onnx", task="feature-extraction"
)
text_encoder_onnx_config = text_encoder_config_constructor(pipeline.text_encoder.config)
models_for_export["text_encoder"] = (models_for_export["text_encoder"], text_encoder_onnx_config)

# U-NET
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
@@ -278,6 +294,16 @@
vae_onnx_config = vae_config_constructor(vae_decoder.config)
models_for_export["vae_decoder"] = (vae_decoder, vae_onnx_config)

if "text_encoder_2" in models_for_export:
onnx_config_constructor = TasksManager.get_exporter_config_constructor(
model=pipeline.text_encoder_2,
exporter="onnx",
task="feature-extraction",
model_type="clip-text-with-projection",
)
onnx_config = onnx_config_constructor(pipeline.text_encoder_2.config)
models_for_export["text_encoder_2"] = (models_for_export["text_encoder_2"], onnx_config)

return models_for_export


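Putting the utils.py changes together, a minimal sketch of collecting the SD XL submodels with their ONNX configs at this commit (the import path and checkpoint id are assumptions):

from diffusers import StableDiffusionXLPipeline
from optimum.exporters.onnx.utils import get_stable_diffusion_models_for_export

pipeline = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
models_and_configs = get_stable_diffusion_models_for_export(pipeline)

# Each entry maps an export subfolder name to a (submodel, OnnxConfig) pair;
# `text_encoder_2` (CLIP with projection) is the SD XL-specific addition.
print(list(models_and_configs))
# ['text_encoder', 'unet', 'vae_encoder', 'vae_decoder', 'text_encoder_2']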
(Diff truncated: the remaining changed files are not shown.)
