From a9ffe076057aea68f25bc11d4288c5e7a3c28c3f Mon Sep 17 00:00:00 2001 From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com> Date: Mon, 17 Jul 2023 18:39:15 +0200 Subject: [PATCH] Enable SD XL ONNX export and ONNX Runtime inference (#1168) * add stable diffusion XL export * fix style * fix test model name * fix style * remove clip with projection from test * change model name * fix style * remove need create pretrainedconfig * fix style * fix dummy input generation * add saving second tokenzier when exporting a SD XL model * fix style * add SD XL pipeline * fix style * add test * add watermarker * fix style * add watermark * add test * set default height width stable diffusion pipeline * enable img2img task * fix style * enable to only have the second tokenizer and text encoder * add test * fix cli export * adapt test for batch size > 1 --- optimum/exporters/onnx/__main__.py | 76 +-- optimum/exporters/onnx/convert.py | 2 + optimum/exporters/onnx/model_configs.py | 51 +- optimum/exporters/onnx/utils.py | 38 +- optimum/exporters/tasks.py | 31 +- optimum/onnxruntime/__init__.py | 8 + optimum/onnxruntime/modeling_diffusion.py | 173 ++++-- optimum/onnxruntime/modeling_ort.py | 1 - .../diffusers/pipeline_stable_diffusion.py | 64 ++- .../pipeline_stable_diffusion_img2img.py | 35 +- .../pipeline_stable_diffusion_inpaint.py | 10 +- .../diffusers/pipeline_stable_diffusion_xl.py | 499 +++++++++++++++++ .../pipeline_stable_diffusion_xl_img2img.py | 506 ++++++++++++++++++ optimum/pipelines/diffusers/pipeline_utils.py | 49 ++ optimum/pipelines/diffusers/watermark.py | 27 + optimum/utils/__init__.py | 1 + optimum/utils/constant.py | 1 + optimum/utils/dummy_diffusers_objects.py | 22 + optimum/utils/import_utils.py | 2 +- optimum/utils/input_generators.py | 15 +- setup.py | 1 + tests/exporters/exporters_utils.py | 3 +- .../exporters/onnx/test_exporters_onnx_cli.py | 28 +- tests/exporters/onnx/test_onnx_export.py | 74 +-- .../test_stable_diffusion_pipeline.py | 114 +++- tests/onnxruntime/utils_onnxruntime_tests.py | 1 + 26 files changed, 1640 insertions(+), 192 deletions(-) create mode 100644 optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py create mode 100644 optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py create mode 100644 optimum/pipelines/diffusers/watermark.py diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index 7200711b53..696cb86823 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -23,15 +23,7 @@ from transformers.utils import is_torch_available from ...commands.export.onnx import parse_args_onnx -from ...utils import ( - DEFAULT_DUMMY_SHAPES, - DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER, - DIFFUSION_MODEL_UNET_SUBFOLDER, - DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER, - DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, - ONNX_WEIGHTS_NAME, - logging, -) +from ...utils import DEFAULT_DUMMY_SHAPES, ONNX_WEIGHTS_NAME, logging from ...utils.save_utils import maybe_save_preprocessors from ..error_utils import AtolError, OutputMatchError, ShapeError from ..tasks import TasksManager @@ -71,8 +63,9 @@ def _get_submodels_and_onnx_configs( custom_architecture: bool, fn_get_submodels: Optional[Callable] = None, ): + is_stable_diffusion = "stable-diffusion" in task if not custom_architecture: - if task == "stable-diffusion": + if is_stable_diffusion: onnx_config = None models_and_onnx_configs = get_stable_diffusion_models_for_export(model) else: @@ -104,7 +97,7 @@ def _get_submodels_and_onnx_configs( if fn_get_submodels is not None: submodels_for_export = fn_get_submodels(model) else: - if task == "stable-diffusion": + if is_stable_diffusion: submodels_for_export = _get_submodels_for_export_stable_diffusion(model) elif ( model.config.is_encoder_decoder @@ -312,10 +305,19 @@ def main_export( ) custom_architecture = False - if task != "stable-diffusion" and model.config.model_type.replace( - "-", "_" - ) not in TasksManager.get_supported_model_type_for_task(task, exporter="onnx"): - custom_architecture = True + is_stable_diffusion = "stable-diffusion" in task + model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-") + + if not is_stable_diffusion: + if model_type in TasksManager._UNSUPPORTED_CLI_MODEL_TYPE: + raise ValueError( + f"{model_type} is not supported yet. Only {TasksManager._SUPPORTED_CLI_MODEL_TYPE} are supported. " + f"If you want to support {model_type} please propose a PR or open up an issue." + ) + if model.config.model_type.replace("-", "_") not in TasksManager.get_supported_model_type_for_task( + task, exporter="onnx" + ): + custom_architecture = True # TODO: support onnx_config.py in the model repo if custom_architecture and custom_onnx_configs is None: @@ -330,9 +332,8 @@ def main_export( if ( not custom_architecture - and task != "stable-diffusion" - and task + "-with-past" - in TasksManager.get_supported_tasks_for_model_type(model.config.model_type.replace("_", "-"), "onnx") + and not is_stable_diffusion + and task + "-with-past" in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx") ): if original_task == "auto": # Make -with-past the default if --task was not explicitely specified task = task + "-with-past" @@ -367,7 +368,7 @@ def main_export( fn_get_submodels=fn_get_submodels, ) - if task != "stable-diffusion": + if not is_stable_diffusion: needs_pad_token_id = ( isinstance(onnx_config, OnnxConfigWithPast) and getattr(model.config, "pad_token_id", None) is None @@ -391,7 +392,7 @@ def main_export( if opset < onnx_config.DEFAULT_ONNX_OPSET: raise ValueError( - f"Opset {opset} is not sufficient to export {model.config.model_type}. " + f"Opset {opset} is not sufficient to export {model_type}. " f"At least {onnx_config.DEFAULT_ONNX_OPSET} is required." ) if atol is None: @@ -415,28 +416,31 @@ def main_export( onnx_files_subpaths = None else: - onnx_files_subpaths = [ - DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER, - DIFFUSION_MODEL_UNET_SUBFOLDER, - DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, - DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER, - ] - # save the subcomponent configuration - for model_name, name_dir in zip(models_and_onnx_configs, onnx_files_subpaths): + for model_name in models_and_onnx_configs: subcomponent = models_and_onnx_configs[model_name][0] if hasattr(subcomponent, "save_config"): - subcomponent.save_config(output / name_dir) + subcomponent.save_config(output / model_name) elif hasattr(subcomponent, "config") and hasattr(subcomponent.config, "save_pretrained"): - subcomponent.config.save_pretrained(output / name_dir) + subcomponent.config.save_pretrained(output / model_name) - onnx_files_subpaths = [os.path.join(path, ONNX_WEIGHTS_NAME) for path in onnx_files_subpaths] + onnx_files_subpaths = [os.path.join(name_dir, ONNX_WEIGHTS_NAME) for name_dir in models_and_onnx_configs] # Saving the additional components needed to perform inference. - model.tokenizer.save_pretrained(output.joinpath("tokenizer")) model.scheduler.save_pretrained(output.joinpath("scheduler")) - if model.feature_extractor is not None: - model.feature_extractor.save_pretrained(output.joinpath("feature_extractor")) + + feature_extractor = getattr(model, "feature_extractor", None) + if feature_extractor is not None: + feature_extractor.save_pretrained(output.joinpath("feature_extractor")) + + tokenizer = getattr(model, "tokenizer", None) + if tokenizer is not None: + tokenizer.save_pretrained(output.joinpath("tokenizer")) + + tokenizer_2 = getattr(model, "tokenizer_2", None) + if tokenizer_2 is not None: + tokenizer_2.save_pretrained(output.joinpath("tokenizer_2")) + model.save_config(output) _, onnx_outputs = export_models( @@ -464,7 +468,7 @@ def main_export( # Optionally post process the obtained ONNX file(s), for example to merge the decoder / decoder with past if any # TODO: treating stable diffusion separately is quite ugly - if not no_post_process and task != "stable-diffusion": + if not no_post_process and not is_stable_diffusion: try: logger.info("Post-processing the exported models...") models_and_onnx_configs, onnx_files_subpaths = onnx_config.post_process_exported_models( @@ -475,7 +479,7 @@ def main_export( f"The post-processing of the ONNX export failed. The export can still be performed by passing the option --no-post-process. Detailed error: {e}" ) - if task == "stable-diffusion": + if is_stable_diffusion: use_subprocess = ( False # TODO: fix Can't pickle local object 'get_stable_diffusion_models_for_export..' ) diff --git a/optimum/exporters/onnx/convert.py b/optimum/exporters/onnx/convert.py index 907749227f..cad2fdcb0f 100644 --- a/optimum/exporters/onnx/convert.py +++ b/optimum/exporters/onnx/convert.py @@ -369,6 +369,8 @@ def _run_validation( if isinstance(value, (list, tuple)): value = config.flatten_output_collection_property(name, value) onnx_inputs.update({tensor_name: pt_tensor.cpu().numpy() for tensor_name, pt_tensor in value.items()}) + elif isinstance(value, dict): + onnx_inputs.update({tensor_name: pt_tensor.cpu().numpy() for tensor_name, pt_tensor in value.items()}) else: onnx_inputs[name] = value.cpu().numpy() diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 20304e6d6c..e2a948b35d 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -658,7 +658,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]: } -class CLIPTextOnnxConfig(TextEncoderOnnxConfig): +class CLIPTextWithProjectionOnnxConfig(TextEncoderOnnxConfig): ATOL_FOR_VALIDATION = 1e-3 # The ONNX export of this architecture needs the Trilu operator support, available since opset 14 DEFAULT_ONNX_OPSET = 14 @@ -666,6 +666,7 @@ class CLIPTextOnnxConfig(TextEncoderOnnxConfig): NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( vocab_size="vocab_size", sequence_length="max_position_embeddings", + num_layers="num_hidden_layers", allow_new=True, ) @@ -677,13 +678,33 @@ def inputs(self) -> Dict[str, Dict[int, str]]: @property def outputs(self) -> Dict[str, Dict[int, str]]: - return { + common_outputs = { + "text_embeds": {0: "batch_size", 1: "sequence_length"}, + "last_hidden_state": {0: "batch_size", 1: "sequence_length"}, + } + if self._normalized_config.output_hidden_states: + for i in range(self._normalized_config.num_layers + 1): + common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"} + + return common_outputs + + +class CLIPTextOnnxConfig(CLIPTextWithProjectionOnnxConfig): + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + common_outputs = { "last_hidden_state": {0: "batch_size", 1: "sequence_length"}, "pooler_output": {0: "batch_size"}, } + if self._normalized_config.output_hidden_states: + for i in range(self._normalized_config.num_layers + 1): + common_outputs[f"hidden_states.{i}"] = {0: "batch_size", 1: "sequence_length"} + + return common_outputs def generate_dummy_inputs(self, framework: str = "pt", **kwargs): dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs) + if framework == "pt": import torch @@ -713,12 +734,19 @@ class UNetOnnxConfig(VisionOnnxConfig): @property def inputs(self) -> Dict[str, Dict[int, str]]: - return { + common_inputs = { "sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}, "timestep": {0: "steps"}, "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, } + # TODO : add text_image, image and image_embeds + if getattr(self._normalized_config, "addition_embed_type", None) == "text_time": + common_inputs["text_embeds"] = {0: "batch_size"} + common_inputs["time_ids"] = {0: "batch_size"} + + return common_inputs + @property def outputs(self) -> Dict[str, Dict[int, str]]: return { @@ -734,8 +762,25 @@ def torch_to_onnx_output_map(self) -> Dict[str, str]: def generate_dummy_inputs(self, framework: str = "pt", **kwargs): dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs) dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0] + + if getattr(self._normalized_config, "addition_embed_type", None) == "text_time": + dummy_inputs["added_cond_kwargs"] = { + "text_embeds": dummy_inputs.pop("text_embeds"), + "time_ids": dummy_inputs.pop("time_ids"), + } + return dummy_inputs + def ordered_inputs(self, model) -> Dict[str, Dict[int, str]]: + inputs = super().ordered_inputs(model=model) + # to fix mismatch between model forward signature and expected inputs + # a dictionnary of additional embeddings `added_cond_kwargs` is expected depending on config.addition_embed_type + if getattr(self._normalized_config, "addition_embed_type", None) == "text_time": + inputs["text_embeds"] = self.inputs["text_embeds"] + inputs["time_ids"] = self.inputs["time_ids"] + + return inputs + class VaeEncoderOnnxConfig(VisionOnnxConfig): ATOL_FOR_VALIDATION = 1e-2 diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 150b99db4f..c1bee9a4da 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -100,14 +100,24 @@ def _get_submodels_for_export_stable_diffusion( """ Returns the components of a Stable Diffusion model. """ + from diffusers import StableDiffusionXLPipeline + models_for_export = {} + if isinstance(pipeline, StableDiffusionXLPipeline): + projection_dim = pipeline.text_encoder_2.config.projection_dim + else: + projection_dim = pipeline.text_encoder.config.projection_dim # Text encoder - models_for_export["text_encoder"] = pipeline.text_encoder + if pipeline.text_encoder is not None: + if isinstance(pipeline, StableDiffusionXLPipeline): + pipeline.text_encoder.config.output_hidden_states = True + models_for_export["text_encoder"] = pipeline.text_encoder # U-NET # PyTorch does not support the ONNX export of torch.nn.functional.scaled_dot_product_attention pipeline.unet.set_attn_processor(AttnProcessor()) + pipeline.unet.config.text_encoder_projection_dim = projection_dim models_for_export["unet"] = pipeline.unet # VAE Encoder https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/models/vae.py#L565 @@ -124,6 +134,11 @@ def _get_submodels_for_export_stable_diffusion( vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample) models_for_export["vae_decoder"] = vae_decoder + text_encoder_2 = getattr(pipeline, "text_encoder_2", None) + if text_encoder_2 is not None: + text_encoder_2.config.output_hidden_states = True + models_for_export["text_encoder_2"] = text_encoder_2 + return models_for_export @@ -249,11 +264,12 @@ def get_stable_diffusion_models_for_export( models_for_export = _get_submodels_for_export_stable_diffusion(pipeline) # Text encoder - text_encoder_config_constructor = TasksManager.get_exporter_config_constructor( - model=pipeline.text_encoder, exporter="onnx", task="feature-extraction" - ) - text_encoder_onnx_config = text_encoder_config_constructor(pipeline.text_encoder.config) - models_for_export["text_encoder"] = (models_for_export["text_encoder"], text_encoder_onnx_config) + if "text_encoder" in models_for_export: + text_encoder_config_constructor = TasksManager.get_exporter_config_constructor( + model=pipeline.text_encoder, exporter="onnx", task="feature-extraction" + ) + text_encoder_onnx_config = text_encoder_config_constructor(pipeline.text_encoder.config) + models_for_export["text_encoder"] = (models_for_export["text_encoder"], text_encoder_onnx_config) # U-NET onnx_config_constructor = TasksManager.get_exporter_config_constructor( @@ -278,6 +294,16 @@ def get_stable_diffusion_models_for_export( vae_onnx_config = vae_config_constructor(vae_decoder.config) models_for_export["vae_decoder"] = (vae_decoder, vae_onnx_config) + if "text_encoder_2" in models_for_export: + onnx_config_constructor = TasksManager.get_exporter_config_constructor( + model=pipeline.text_encoder_2, + exporter="onnx", + task="feature-extraction", + model_type="clip-text-with-projection", + ) + onnx_config = onnx_config_constructor(pipeline.text_encoder_2.config) + models_for_export["text_encoder_2"] = (models_for_export["text_encoder_2"], onnx_config) + return models_for_export diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py index b676ddb9b1..2f3c432968 100644 --- a/optimum/exporters/tasks.py +++ b/optimum/exporters/tasks.py @@ -171,6 +171,7 @@ class TasksManager: "audio-xvector": "AutoModelForAudioXVector", "image-to-text": "AutoModelForVision2Seq", "stable-diffusion": "StableDiffusionPipeline", + "stable-diffusion-xl": "StableDiffusionXLPipeline", "zero-shot-image-classification": "AutoModelForZeroShotImageClassification", "zero-shot-object-detection": "AutoModelForZeroShotObjectDetection", } @@ -267,6 +268,7 @@ class TasksManager: "image-to-text": "transformers", "sentence-similarity": "transformers", "stable-diffusion": "diffusers", + "stable-diffusion-xl": "diffusers", "summarization": "transformers", "visual-question-answering": "transformers", "zero-shot-classification": "transformers", @@ -390,6 +392,10 @@ class TasksManager: "feature-extraction", onnx="CLIPTextOnnxConfig", ), + "clip-text-with-projection": supported_tasks_mapping( + "feature-extraction", + onnx="CLIPTextWithProjectionOnnxConfig", + ), "codegen": supported_tasks_mapping( "feature-extraction", "feature-extraction-with-past", @@ -931,7 +937,14 @@ class TasksManager: onnx="YolosOnnxConfig", ), } - _UNSUPPORTED_CLI_MODEL_TYPE = {"unet", "vae-encoder", "vae-decoder", "clip-text-model", "trocr"} + _UNSUPPORTED_CLI_MODEL_TYPE = { + "unet", + "vae-encoder", + "vae-decoder", + "clip-text-model", + "clip-text-with-projection", + "trocr", + } _SUPPORTED_CLI_MODEL_TYPE = set(_SUPPORTED_MODEL_TYPE.keys()) - _UNSUPPORTED_CLI_MODEL_TYPE @classmethod @@ -1006,7 +1019,7 @@ def get_supported_tasks_for_model_type( if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: raise KeyError( f"{model_type_and_model_name} is not supported yet. " - f"Only {TasksManager._SUPPORTED_CLI_MODEL_TYPE} are supported. " + f"Only {TasksManager._SUPPORTED_MODEL_TYPE} are supported. " f"If you want to support {model_type} please propose a PR or open up an issue." ) elif exporter not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]: @@ -1271,7 +1284,7 @@ def _infer_task_from_model_or_model_class( ( target_name.startswith("Auto"), target_name.startswith("TFAuto"), - target_name == "StableDiffusionPipeline", + "StableDiffusion" in target_name, ) ): if target_name == auto_cls_name: @@ -1314,8 +1327,10 @@ def _infer_task_from_model_name_or_path( model_info = huggingface_hub.model_info(model_name_or_path, revision=revision) if model_info.library_name == "diffusers": # TODO : getattr(model_info, "model_index") defining auto_model_class_name currently set to None - if "stable-diffusion" in model_info.tags: - inferred_task_name = "stable-diffusion" + for task in ("stable-diffusion-xl", "stable-diffusion"): + if task in model_info.tags: + inferred_task_name = task + break else: pipeline_tag = getattr(model_info, "pipeline_tag", None) # conversational is not a supported task per se, just an alias that may map to @@ -1476,7 +1491,11 @@ def get_model_from_task( elif device is None: device = torch.device("cpu") - if version.parse(torch.__version__) >= version.parse("2.0"): + # TODO : fix EulerDiscreteScheduler loading to enable for SD models + if ( + version.parse(torch.__version__) >= version.parse("2.0") + and TasksManager._TASKS_TO_LIBRARY[task.replace("-with-past", "")] != "diffusers" + ): with device: # Initialize directly in the requested device, to save allocation time. Especially useful for large # models to initialize on cuda device. diff --git a/optimum/onnxruntime/__init__.py b/optimum/onnxruntime/__init__.py index e5904185c2..62e32cfe71 100644 --- a/optimum/onnxruntime/__init__.py +++ b/optimum/onnxruntime/__init__.py @@ -71,12 +71,16 @@ "ORTStableDiffusionPipeline", "ORTStableDiffusionImg2ImgPipeline", "ORTStableDiffusionInpaintPipeline", + "ORTStableDiffusionXLPipeline", + "ORTStableDiffusionXLImg2ImgPipeline", ] else: _import_structure["modeling_diffusion"] = [ "ORTStableDiffusionPipeline", "ORTStableDiffusionImg2ImgPipeline", "ORTStableDiffusionInpaintPipeline", + "ORTStableDiffusionXLPipeline", + "ORTStableDiffusionXLImg2ImgPipeline", ] @@ -124,12 +128,16 @@ ORTStableDiffusionImg2ImgPipeline, ORTStableDiffusionInpaintPipeline, ORTStableDiffusionPipeline, + ORTStableDiffusionXLImg2ImgPipeline, + ORTStableDiffusionXLPipeline, ) else: from .modeling_diffusion import ( ORTStableDiffusionImg2ImgPipeline, ORTStableDiffusionInpaintPipeline, ORTStableDiffusionPipeline, + ORTStableDiffusionXLImg2ImgPipeline, + ORTStableDiffusionXLPipeline, ) else: import sys diff --git a/optimum/onnxruntime/modeling_diffusion.py b/optimum/onnxruntime/modeling_diffusion.py index 51e0a85a3f..3541ad9480 100644 --- a/optimum/onnxruntime/modeling_diffusion.py +++ b/optimum/onnxruntime/modeling_diffusion.py @@ -28,6 +28,7 @@ LMSDiscreteScheduler, PNDMScheduler, StableDiffusionPipeline, + StableDiffusionXLPipeline, ) from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.utils import CONFIG_NAME @@ -41,7 +42,10 @@ from ..pipelines.diffusers.pipeline_stable_diffusion import StableDiffusionPipelineMixin from ..pipelines.diffusers.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipelineMixin from ..pipelines.diffusers.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipelineMixin +from ..pipelines.diffusers.pipeline_stable_diffusion_xl import StableDiffusionXLPipelineMixin +from ..pipelines.diffusers.pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipelineMixin from ..utils import ( + DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER, DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER, DIFFUSION_MODEL_UNET_SUBFOLDER, DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER, @@ -77,6 +81,8 @@ def __init__( scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], feature_extractor: Optional[CLIPFeatureExtractor] = None, vae_encoder_session: Optional[ort.InferenceSession] = None, + text_encoder_2_session: Optional[ort.InferenceSession] = None, + tokenizer_2: Optional[CLIPTokenizer] = None, use_io_binding: Optional[bool] = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, ): @@ -114,11 +120,16 @@ def __init__( self._internal_dict = config self.vae_decoder = ORTModelVaeDecoder(vae_decoder_session, self) self.vae_decoder_model_path = Path(vae_decoder_session._model_path) - self.text_encoder = ORTModelTextEncoder(text_encoder_session, self) - self.text_encoder_model_path = Path(text_encoder_session._model_path) self.unet = ORTModelUnet(unet_session, self) self.unet_model_path = Path(unet_session._model_path) + if text_encoder_session is not None: + self.text_encoder_model_path = Path(text_encoder_session._model_path) + self.text_encoder = ORTModelTextEncoder(text_encoder_session, self) + else: + self.text_encoder_model_path = None + self.text_encoder = None + if vae_encoder_session is not None: self.vae_encoder_model_path = Path(vae_encoder_session._model_path) self.vae_encoder = ORTModelVaeEncoder(vae_encoder_session, self) @@ -126,7 +137,15 @@ def __init__( self.vae_encoder_model_path = None self.vae_encoder = None + if text_encoder_2_session is not None: + self.text_encoder_2_model_path = Path(text_encoder_2_session._model_path) + self.text_encoder_2 = ORTModelTextEncoder(text_encoder_2_session, self) + else: + self.text_encoder_2_model_path = None + self.text_encoder_2 = None + self.tokenizer = tokenizer + self.tokenizer_2 = tokenizer_2 self.scheduler = scheduler self.feature_extractor = feature_extractor self.safety_checker = None @@ -136,6 +155,7 @@ def __init__( DIFFUSION_MODEL_UNET_SUBFOLDER: self.unet, DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER: self.vae_decoder, DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER: self.vae_encoder, + DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER: self.text_encoder_2, } # Modify config to keep the resulting model compatible with diffusers pipelines @@ -156,6 +176,7 @@ def load_model( text_encoder_path: Union[str, Path], unet_path: Union[str, Path], vae_encoder_path: Optional[Union[str, Path]] = None, + text_encoder_2_path: Optional[Union[str, Path]] = None, provider: str = "CPUExecutionProvider", session_options: Optional[ort.SessionOptions] = None, provider_options: Optional[Dict] = None, @@ -173,6 +194,8 @@ def load_model( The path to the U-NET ONNX model. vae_encoder_path (`Union[str, Path]`, defaults to `None`): The path to the VAE encoder ONNX model. + text_encoder_2_path (`Union[str, Path]`, defaults to `None`): + The path to the second text decoder ONNX model. provider (`str`, defaults to `"CPUExecutionProvider"`): ONNX Runtime provider to use for loading the model. See https://onnxruntime.ai/docs/execution-providers/ for possible providers. @@ -182,16 +205,22 @@ def load_model( Provider option dictionary corresponding to the provider used. See available options for each provider: https://onnxruntime.ai/docs/api/c/group___global.html . Defaults to `None`. """ - vae_decoder_session = ORTModel.load_model(vae_decoder_path, provider, session_options, provider_options) - text_encoder_session = ORTModel.load_model(text_encoder_path, provider, session_options, provider_options) - unet_session = ORTModel.load_model(unet_path, provider, session_options, provider_options) + vae_decoder = ORTModel.load_model(vae_decoder_path, provider, session_options, provider_options) + unet = ORTModel.load_model(unet_path, provider, session_options, provider_options) - if vae_encoder_path is not None: - vae_encoder_session = ORTModel.load_model(vae_encoder_path, provider, session_options, provider_options) - else: - vae_encoder_session = None + sessions = { + "vae_encoder": vae_encoder_path, + "text_encoder": text_encoder_path, + "text_encoder_2": text_encoder_2_path, + } + + for key, value in sessions.items(): + if value is not None and value.is_file(): + sessions[key] = ORTModel.load_model(value, provider, session_options, provider_options) + else: + sessions[key] = None - return vae_decoder_session, text_encoder_session, unet_session, vae_encoder_session + return vae_decoder, sessions["text_encoder"], unet, sessions["vae_encoder"], sessions["text_encoder_2"] def _save_pretrained(self, save_directory: Union[str, Path]): save_directory = Path(save_directory) @@ -201,10 +230,13 @@ def _save_pretrained(self, save_directory: Union[str, Path]): self.unet_model_path: save_directory / DIFFUSION_MODEL_UNET_SUBFOLDER / ONNX_WEIGHTS_NAME, } - if self.vae_encoder_model_path is not None: - src_to_dst_path[self.vae_encoder_model_path] = ( - save_directory / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / ONNX_WEIGHTS_NAME - ) + sub_models_to_save = { + self.vae_encoder_model_path: DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER, + self.text_encoder_2_model_path: DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER, + } + for path, subfolder in sub_models_to_save.items(): + if path is not None: + src_to_dst_path[path] = save_directory / subfolder / ONNX_WEIGHTS_NAME # TODO: Modify _get_external_data_paths to give dictionnary src_paths = list(src_to_dst_path.keys()) @@ -219,10 +251,14 @@ def _save_pretrained(self, save_directory: Union[str, Path]): if config_path.is_file(): shutil.copyfile(config_path, dst_path.parent / self.sub_component_config_name) - self.tokenizer.save_pretrained(save_directory / "tokenizer") self.scheduler.save_pretrained(save_directory / "scheduler") + if self.feature_extractor is not None: self.feature_extractor.save_pretrained(save_directory / "feature_extractor") + if self.tokenizer is not None: + self.tokenizer.save_pretrained(save_directory / "tokenizer") + if self.tokenizer_2 is not None: + self.tokenizer_2.save_pretrained(save_directory / "tokenizer_2") @classmethod def _from_pretrained( @@ -236,6 +272,7 @@ def _from_pretrained( text_encoder_file_name: str = ONNX_WEIGHTS_NAME, unet_file_name: str = ONNX_WEIGHTS_NAME, vae_encoder_file_name: str = ONNX_WEIGHTS_NAME, + text_encoder_2_file_name: str = ONNX_WEIGHTS_NAME, local_files_only: bool = False, provider: str = "CPUExecutionProvider", session_options: Optional[ort.SessionOptions] = None, @@ -248,12 +285,10 @@ def _from_pretrained( raise ValueError("The provider `'TensorrtExecutionProvider'` is not supported") model_id = str(model_id) - sub_models_to_load, _, _ = cls.extract_init_dict(config) - sub_models_names = set(sub_models_to_load.keys()).intersection({"feature_extractor", "tokenizer", "scheduler"}) - sub_models = {} + patterns = set(config.keys()) + sub_models_to_load = patterns.intersection({"feature_extractor", "tokenizer", "tokenizer_2", "scheduler"}) if not os.path.isdir(model_id): - patterns = set(config.keys()) patterns.update({"vae_encoder", "vae_decoder"}) allow_patterns = {os.path.join(k, "*") for k in patterns if not k.startswith("_")} allow_patterns.update( @@ -262,6 +297,7 @@ def _from_pretrained( text_encoder_file_name, unet_file_name, vae_encoder_file_name, + text_encoder_2_file_name, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name, @@ -279,8 +315,9 @@ def _from_pretrained( ) new_model_save_dir = Path(model_id) - for name in sub_models_names: - library_name, library_classes = sub_models_to_load[name] + sub_models = {} + for name in sub_models_to_load: + library_name, library_classes = config[name] if library_classes is not None: library = importlib.import_module(library_name) class_obj = getattr(library, library_classes) @@ -291,18 +328,14 @@ def _from_pretrained( else: sub_models[name] = load_method(new_model_save_dir) - vae_encoder_path = new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name - - if not vae_encoder_path.is_file(): - logger.warning( - f"VAE encoder not found in {model_id} and will not be loaded for inference. This component is needed for some tasks." - ) - - inference_sessions = cls.load_model( + vae_decoder, text_encoder, unet, vae_encoder, text_encoder_2 = cls.load_model( vae_decoder_path=new_model_save_dir / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / vae_decoder_file_name, text_encoder_path=new_model_save_dir / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / text_encoder_file_name, unet_path=new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name, - vae_encoder_path=vae_encoder_path if vae_encoder_path.is_file() else None, + vae_encoder_path=new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name, + text_encoder_2_path=new_model_save_dir + / DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER + / text_encoder_2_file_name, provider=provider, session_options=session_options, provider_options=provider_options, @@ -317,12 +350,16 @@ def _from_pretrained( ) return cls( - *inference_sessions[:-1], + vae_decoder_session=vae_decoder, + text_encoder_session=text_encoder, + unet_session=unet, config=config, - tokenizer=sub_models["tokenizer"], - scheduler=sub_models["scheduler"], - feature_extractor=sub_models.pop("feature_extractor", None), - vae_encoder_session=inference_sessions[-1], + tokenizer=sub_models.get("tokenizer", None), + scheduler=sub_models.get("scheduler"), + feature_extractor=sub_models.get("feature_extractor", None), + tokenizer_2=sub_models.get("tokenizer_2", None), + vae_encoder_session=vae_encoder, + text_encoder_2_session=text_encoder_2, use_io_binding=use_io_binding, model_save_dir=model_save_dir, ) @@ -426,6 +463,7 @@ def __init__(self, session: ort.InferenceSession, parent_model: ORTModel): self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())} config_path = Path(session._model_path).parent / self.CONFIG_NAME self.config = self.parent_model._dict_from_json_file(config_path) if config_path.is_file() else {} + self.input_dtype = {inputs.name: _ORT_TO_NP_TYPE[inputs.type] for inputs in self.session.get_inputs()} @property def device(self): @@ -451,14 +489,26 @@ def forward(self, input_ids: np.ndarray): class ORTModelUnet(_ORTDiffusionModelPart): def __init__(self, session: ort.InferenceSession, parent_model: ORTModel): super().__init__(session, parent_model) - self.input_dtype = {inputs.name: _ORT_TO_NP_TYPE[inputs.type] for inputs in self.session.get_inputs()} - def forward(self, sample: np.ndarray, timestep: np.ndarray, encoder_hidden_states: np.ndarray): + def forward( + self, + sample: np.ndarray, + timestep: np.ndarray, + encoder_hidden_states: np.ndarray, + text_embeds: Optional[np.ndarray] = None, + time_ids: Optional[np.ndarray] = None, + ): onnx_inputs = { "sample": sample, "timestep": timestep, "encoder_hidden_states": encoder_hidden_states, } + + if text_embeds is not None: + onnx_inputs["text_embeds"] = text_embeds + if time_ids is not None: + onnx_inputs["time_ids"] = time_ids + outputs = self.session.run(None, onnx_inputs) return outputs @@ -494,3 +544,52 @@ def __call__(self, *args, **kwargs): class ORTStableDiffusionInpaintPipeline(ORTStableDiffusionPipelineBase, StableDiffusionInpaintPipelineMixin): def __call__(self, *args, **kwargs): return StableDiffusionInpaintPipelineMixin.__call__(self, *args, **kwargs) + + +class ORTStableDiffusionXLPipelineBase(ORTStableDiffusionPipelineBase): + auto_model_class = StableDiffusionXLPipeline + + def __init__( + self, + vae_decoder_session: ort.InferenceSession, + text_encoder_session: ort.InferenceSession, + unet_session: ort.InferenceSession, + config: Dict[str, Any], + tokenizer: CLIPTokenizer, + scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + feature_extractor: Optional[CLIPFeatureExtractor] = None, + vae_encoder_session: Optional[ort.InferenceSession] = None, + text_encoder_2_session: Optional[ort.InferenceSession] = None, + tokenizer_2: Optional[CLIPTokenizer] = None, + use_io_binding: Optional[bool] = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + ): + super().__init__( + vae_decoder_session=vae_decoder_session, + text_encoder_session=text_encoder_session, + unet_session=unet_session, + config=config, + tokenizer=tokenizer, + scheduler=scheduler, + feature_extractor=feature_extractor, + vae_encoder_session=vae_encoder_session, + text_encoder_2_session=text_encoder_2_session, + tokenizer_2=tokenizer_2, + use_io_binding=use_io_binding, + model_save_dir=model_save_dir, + ) + + # additional invisible-watermark dependency for SD XL + from ..pipelines.diffusers.watermark import StableDiffusionXLWatermarker + + self.watermark = StableDiffusionXLWatermarker() + + +class ORTStableDiffusionXLPipeline(ORTStableDiffusionXLPipelineBase, StableDiffusionXLPipelineMixin): + def __call__(self, *args, **kwargs): + return StableDiffusionXLPipelineMixin.__call__(self, *args, **kwargs) + + +class ORTStableDiffusionXLImg2ImgPipeline(ORTStableDiffusionXLPipelineBase, StableDiffusionXLImg2ImgPipelineMixin): + def __call__(self, *args, **kwargs): + return StableDiffusionXLImg2ImgPipelineMixin.__call__(self, *args, **kwargs) diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index 54c3143cc1..1784766c6a 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ b/optimum/onnxruntime/modeling_ort.py @@ -170,7 +170,6 @@ class ORTModel(OptimizedModel): @classproperty def export_feature(cls): logger.warning(f"{cls.__name__}.export_feature is deprecated, and will be removed in optimum 2.0.") - try: feature = TasksManager.infer_task_from_model(cls.auto_model_class) except ValueError: diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion.py index 5d46668ec1..c133f8c6d2 100644 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion.py +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion.py @@ -20,7 +20,7 @@ import torch from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput -from .pipeline_utils import DiffusionPipelineMixin +from .pipeline_utils import DiffusionPipelineMixin, rescale_noise_cfg logger = logging.getLogger(__name__) @@ -179,12 +179,31 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = generator.randn(*shape).astype(dtype) + elif latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * np.float64(self.scheduler.init_noise_sigma) + + return latents + # Adapted from https://github.com/huggingface/diffusers/blob/v0.17.1/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L264 def __call__( self, prompt: Optional[Union[str, List[str]]] = None, - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -198,6 +217,7 @@ def __call__( return_dict: bool = True, callback: Optional[Callable[[int, int, np.ndarray], None]] = None, callback_steps: int = 1, + guidance_rescale: float = 0.0, ): r""" Function invoked when calling the pipeline for generation. @@ -206,9 +226,9 @@ def __call__( prompt (`Optional[Union[str, List[str]]]`, defaults to None): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - height (`int`, defaults to 512): + height (`Optional[int]`, defaults to None): The height in pixels of the generated image. - width (`int`, defaults to 512): + width (`Optional[int]`, defaults to None): The width in pixels of the generated image. num_inference_steps (`int`, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -253,6 +273,11 @@ def __call__( callback_steps (`int`, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + guidance_rescale (`float`, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: @@ -261,6 +286,8 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + height = height or self.unet.config["sample_size"] * self.vae_scale_factor + width = width or self.unet.config["sample_size"] * self.vae_scale_factor # check inputs. Raise error if not correct self.check_inputs( @@ -292,25 +319,19 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) - num_unet_in_channels = self.unet.config.get("in_channels", 4) - # get the initial random noise unless the user supplied it - latents_dtype = prompt_embeds.dtype - latents_shape = ( - batch_size * num_images_per_prompt, - num_unet_in_channels, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - if latents is None: - latents = generator.randn(*latents_shape).astype(latents_dtype) - elif latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - # set timesteps self.scheduler.set_timesteps(num_inference_steps) timesteps = self.scheduler.timesteps - latents = latents * np.float64(self.scheduler.init_noise_sigma) + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + self.unet.config.get("in_channels", 4), + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -340,6 +361,9 @@ def __call__( if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 scheduler_output = self.scheduler.step( diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py index ca99ed0469..d2c23b2b04 100644 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion_img2img.py @@ -14,50 +14,21 @@ import inspect import logging -import warnings from typing import Callable, List, Optional, Union import numpy as np import PIL import torch from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput -from diffusers.utils import PIL_INTERPOLATION, deprecate +from diffusers.utils import deprecate from .pipeline_stable_diffusion import StableDiffusionPipelineMixin +from .pipeline_utils import preprocess logger = logging.getLogger(__name__) -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64 -def preprocess(image): - warnings.warn( - ( - "The preprocess method is deprecated and will be removed in a future version. Please" - " use VaeImageProcessor.preprocess instead" - ), - FutureWarning, - ) - if isinstance(image, torch.Tensor): - return image - elif isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - w, h = image[0].size - w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 - - image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - return image - - class StableDiffusionImg2ImgPipelineMixin(StableDiffusionPipelineMixin): # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionImg2ImgPipeline.check_inputs def check_inputs( @@ -207,7 +178,7 @@ def __call__( # set timesteps self.scheduler.set_timesteps(num_inference_steps) - image = preprocess(image).cpu().numpy() + image = preprocess(image) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py index 6a5c3accdc..07a808acab 100644 --- a/optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py @@ -105,8 +105,8 @@ def __call__( prompt: Union[str, List[str]], image: PIL.Image.Image, mask_image: PIL.Image.Image, - height: int = 512, - width: int = 512, + height: Optional[int] = None, + width: Optional[int] = None, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -132,9 +132,9 @@ def __call__( `Image`, or tensor representing an image batch which will be upscaled. mask_image (`PIL.Image.Image`): `Image`, or tensor representing a masked image batch which will be upscaled. - height (`int`, defaults to 512): + height (`Optional[int]`, defaults to None): The height in pixels of the generated image. - width (`int`, defaults to 512): + width (`Optional[int]`, defaults to None): The width in pixels of the generated image. num_inference_steps (`int`, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -187,6 +187,8 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + height = height or self.unet.config["sample_size"] * self.vae_scale_factor + width = width or self.unet.config["sample_size"] * self.vae_scale_factor # check inputs. Raise error if not correct self.check_inputs( diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py new file mode 100644 index 0000000000..4c8c015fed --- /dev/null +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl.py @@ -0,0 +1,499 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import logging +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput + +from .pipeline_utils import DiffusionPipelineMixin, rescale_noise_cfg + + +logger = logging.getLogger(__name__) + + +class StableDiffusionXLPipelineMixin(DiffusionPipelineMixin): + # Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[Union[str, list]], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + pooled_prompt_embeds: Optional[np.ndarray] = None, + negative_pooled_prompt_embeds: Optional[np.ndarray] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`Union[str, List[str]]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`Optional[Union[str, list]]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + # get prompt text embeddings + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="np").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not np.array_equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + input_ids=text_input_ids.astype(text_encoder.input_dtype.get("input_ids", np.int32)) + ) + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds[-2] + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = np.concatenate(prompt_embeds_list, axis=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config["force_zeros_for_empty_prompt"] + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = np.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = np.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + negative_prompt_embeds = text_encoder( + input_ids=uncond_input.input_ids.astype(text_encoder.input_dtype.get("input_ids", np.int32)) + ) + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds[-2] + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + negative_prompt_embeds_list.append(negative_prompt_embeds) + negative_prompt_embeds = np.concatenate(negative_prompt_embeds, axis=-1) + + pooled_prompt_embeds = np.repeat(pooled_prompt_embeds, num_images_per_prompt, axis=0) + negative_pooled_prompt_embeds = np.repeat(negative_pooled_prompt_embeds, num_images_per_prompt, axis=0) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs + def check_inputs( + self, + prompt: Union[str, List[str]], + height: Optional[int], + width: Optional[int], + callback_steps: int, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + pooled_prompt_embeds: Optional[np.ndarray] = None, + negative_pooled_prompt_embeds: Optional[np.ndarray] = None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = generator.randn(*shape).astype(dtype) + elif latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * np.float64(self.scheduler.init_noise_sigma) + + return latents + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + extra_step_kwargs = {} + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_eta: + extra_step_kwargs["eta"] = eta + + return extra_step_kwargs + + # Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__ + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[np.random.RandomState] = None, + latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + pooled_prompt_embeds: Optional[np.ndarray] = None, + negative_pooled_prompt_embeds: Optional[np.ndarray] = None, + output_type: str = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`Optional[Union[str, List[str]]]`, defaults to None): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`Optional[int]`, defaults to None): + The height in pixels of the generated image. + width (`Optional[int]`, defaults to None): + The width in pixels of the generated image. + num_inference_steps (`int`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to 5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`Optional[Union[str, list]]`): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, defaults to 1): + The number of images to generate per prompt. + eta (`float`, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`Optional[np.random.RandomState]`, defaults to `None`):: + A np.random.RandomState to make generation deterministic. + latents (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + callback (Optional[Callable], defaults to `None`): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + guidance_rescale (`float`, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 0. Default height and width to unet + height = height or self.unet.config["sample_size"] * self.vae_scale_factor + width = width or self.unet.config["sample_size"] * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + self.unet.config.get("in_channels", 4), + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = (original_size + crops_coords_top_left + target_size,) + add_time_ids = np.array(add_time_ids, dtype=prompt_embeds.dtype) + + if do_classifier_free_guidance: + prompt_embeds = np.concatenate((negative_prompt_embeds, prompt_embeds), axis=0) + add_text_embeds = np.concatenate((negative_pooled_prompt_embeds, add_text_embeds), axis=0) + add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0) + add_time_ids = np.repeat(add_time_ids, batch_size * num_images_per_prompt, axis=0) + + # Adapted from diffusers to extend it for other runtimes than ORT + timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet( + sample=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + text_embeds=add_text_embeds, + time_ids=add_time_ids, + ) + noise_pred = noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + image = latents + else: + latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215) + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + image = self.watermark.apply_watermark(image) + + # TODO: add image_processor + image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1)) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py b/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py new file mode 100644 index 0000000000..7be02dc5cb --- /dev/null +++ b/optimum/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py @@ -0,0 +1,506 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import logging +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput + +from .pipeline_utils import DiffusionPipelineMixin, preprocess, rescale_noise_cfg + + +logger = logging.getLogger(__name__) + + +class StableDiffusionXLImg2ImgPipelineMixin(DiffusionPipelineMixin): + # Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def _encode_prompt( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int, + do_classifier_free_guidance: bool, + negative_prompt: Optional[Union[str, list]], + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + pooled_prompt_embeds: Optional[np.ndarray] = None, + negative_pooled_prompt_embeds: Optional[np.ndarray] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`Union[str, List[str]]`): + prompt to be encoded + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`Optional[Union[str, list]]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + """ + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + # get prompt text embeddings + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="np").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not np.array_equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + input_ids=text_input_ids.astype(text_encoder.input_dtype.get("input_ids", np.int32)) + ) + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds[-2] + prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = np.concatenate(prompt_embeds_list, axis=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config["force_zeros_for_empty_prompt"] + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = np.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = np.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="np", + ) + + negative_prompt_embeds = text_encoder( + input_ids=uncond_input.input_ids.astype(text_encoder.input_dtype.get("input_ids", np.int32)) + ) + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds[-2] + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + negative_prompt_embeds_list.append(negative_prompt_embeds) + negative_prompt_embeds = np.concatenate(negative_prompt_embeds, axis=-1) + + pooled_prompt_embeds = np.repeat(pooled_prompt_embeds, num_images_per_prompt, axis=0) + negative_pooled_prompt_embeds = np.repeat(negative_pooled_prompt_embeds, num_images_per_prompt, axis=0) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.check_inputs + def check_inputs( + self, + prompt: Union[str, List[str]], + strength: float, + callback_steps: int, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :].numpy() + + return timesteps, num_inference_steps - t_start + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, generator=None): + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + else: + init_latents = self.vae_encoder(sample=image)[0] * self.vae_decoder.config.get("scaling_factor", 0.18215) + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = np.concatenate([init_latents] * additional_image_per_prompt, axis=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = np.concatenate([init_latents], axis=0) + + # add noise to latents using the timesteps + noise = generator.randn(*init_latents.shape).astype(dtype) + init_latents = self.scheduler.add_noise( + torch.from_numpy(init_latents), torch.from_numpy(noise), torch.from_numpy(timestep) + ) + return init_latents.numpy() + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + ): + if self.config.get("requires_aesthetics_score"): + add_time_ids = (original_size + crops_coords_top_left + (aesthetic_score,),) + add_neg_time_ids = (original_size + crops_coords_top_left + (negative_aesthetic_score,),) + else: + add_time_ids = (original_size + crops_coords_top_left + target_size,) + add_neg_time_ids = (original_size + crops_coords_top_left + target_size,) + + add_time_ids = np.array(add_time_ids, dtype=dtype) + add_neg_time_ids = np.array(add_neg_time_ids, dtype=dtype) + + return add_time_ids, add_neg_time_ids + + # Adapted from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.__call__ + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Union[np.ndarray, PIL.Image.Image] = None, + strength: float = 0.3, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[np.random.RandomState] = None, + latents: Optional[np.ndarray] = None, + prompt_embeds: Optional[np.ndarray] = None, + negative_prompt_embeds: Optional[np.ndarray] = None, + pooled_prompt_embeds: Optional[np.ndarray] = None, + negative_pooled_prompt_embeds: Optional[np.ndarray] = None, + output_type: str = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, np.ndarray], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + aesthetic_score: float = 6.0, + negative_aesthetic_score: float = 2.5, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`Optional[Union[str, List[str]]]`, defaults to None): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`Union[np.ndarray, PIL.Image.Image]`): + `Image`, or tensor representing an image batch which will be upscaled. + strength (`float`, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to 5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`Optional[Union[str, list]]`): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, defaults to 1): + The number of images to generate per prompt. + eta (`float`, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`Optional[np.random.RandomState]`, defaults to `None`):: + A np.random.RandomState to make generation deterministic. + latents (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`Optional[np.ndarray]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] instead of a + plain tuple. + callback (Optional[Callable], defaults to `None`): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + guidance_rescale (`float`, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Check inputs. Raise error if not correct + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 1. Define call parameters + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if generator is None: + generator = np.random + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 2. Encode input prompt + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self._encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + ) + + # 3. Preprocess image + image = preprocess(image) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps) + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength) + latent_timestep = np.repeat(timesteps[:1], batch_size * num_images_per_prompt, axis=0) + timestep_dtype = self.unet.input_dtype.get("timestep", np.float32) + + latents_dtype = prompt_embeds.dtype + image = image.astype(latents_dtype) + + # 5. Prepare latent variables + latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, latents_dtype, generator + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = {} + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_eta: + extra_step_kwargs["eta"] = eta + + height, width = latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 8. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids, add_neg_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype=prompt_embeds.dtype, + ) + + if do_classifier_free_guidance: + prompt_embeds = np.concatenate((negative_prompt_embeds, prompt_embeds), axis=0) + add_text_embeds = np.concatenate((negative_pooled_prompt_embeds, add_text_embeds), axis=0) + add_time_ids = np.concatenate((add_time_ids, add_time_ids), axis=0) + add_time_ids = np.repeat(add_time_ids, batch_size * num_images_per_prompt, axis=0) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + for i, t in enumerate(self.progress_bar(timesteps)): + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = self.unet( + sample=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + text_embeds=add_text_embeds, + time_ids=add_time_ids, + ) + noise_pred = noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = self.scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + image = latents + else: + latents = latents / self.vae_decoder.config.get("scaling_factor", 0.18215) + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [self.vae_decoder(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + image = self.watermark.apply_watermark(image) + + # TODO: add image_processor + image = np.clip(image / 2 + 0.5, 0, 1).transpose((0, 2, 3, 1)) + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/optimum/pipelines/diffusers/pipeline_utils.py b/optimum/pipelines/diffusers/pipeline_utils.py index 7092003875..27cc684cb3 100644 --- a/optimum/pipelines/diffusers/pipeline_utils.py +++ b/optimum/pipelines/diffusers/pipeline_utils.py @@ -13,7 +13,13 @@ # limitations under the License. +import warnings + +import numpy as np +import PIL +import torch from diffusers import ConfigMixin +from diffusers.utils import PIL_INTERPOLATION from PIL import Image from tqdm.auto import tqdm @@ -51,3 +57,46 @@ def progress_bar(self, iterable=None, total=None): return tqdm(total=total, **self._progress_bar_config) else: raise ValueError("Either `total` or `iterable` has to be defined.") + + +# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64 +def preprocess(image): + warnings.warn( + ( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead" + ), + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image.cpu().numpy() + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0).cpu().numpy() + return image + + +# Adapted from https://github.com/huggingface/diffusers/blob/v0.18.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L58 +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = np.std(noise_pred_text, axis=tuple(range(1, noise_pred_text.ndim)), keepdims=True) + std_cfg = np.std(noise_cfg, axis=tuple(range(1, noise_cfg.ndim)), keepdims=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg diff --git a/optimum/pipelines/diffusers/watermark.py b/optimum/pipelines/diffusers/watermark.py new file mode 100644 index 0000000000..e07b4829c6 --- /dev/null +++ b/optimum/pipelines/diffusers/watermark.py @@ -0,0 +1,27 @@ +import numpy as np +from imwatermark import WatermarkEncoder + + +WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 +WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] + + +# Adapted from https://github.com/huggingface/diffusers/blob/v0.18.1/src/diffusers/pipelines/stable_diffusion_xl/watermark.py#L12 +class StableDiffusionXLWatermarker: + def __init__(self): + self.watermark = WATERMARK_BITS + self.encoder = WatermarkEncoder() + self.encoder.set_watermark("bits", self.watermark) + + def apply_watermark(self, images: np.array): + # can't encode images that are smaller than 256 + if images.shape[-1] < 256: + return images + + images = (255 * (images / 2 + 0.5)).transpose((0, 2, 3, 1)) + + images = np.array([self.encoder.encode(image, "dwtDct") for image in images]).transpose((0, 3, 1, 2)) + + np.clip(2 * (images / 255 - 0.5), -1.0, 1.0, out=images) + + return images diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py index 3042721938..df0db3f39a 100644 --- a/optimum/utils/__init__.py +++ b/optimum/utils/__init__.py @@ -15,6 +15,7 @@ from .constant import ( CONFIG_NAME, + DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER, DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER, DIFFUSION_MODEL_UNET_SUBFOLDER, DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER, diff --git a/optimum/utils/constant.py b/optimum/utils/constant.py index 2750d1190d..4497b5246d 100644 --- a/optimum/utils/constant.py +++ b/optimum/utils/constant.py @@ -18,4 +18,5 @@ DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER = "text_encoder" DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER = "vae_decoder" DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER = "vae_encoder" +DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER = "text_encoder_2" ONNX_WEIGHTS_NAME = "model.onnx" diff --git a/optimum/utils/dummy_diffusers_objects.py b/optimum/utils/dummy_diffusers_objects.py index a6171d5317..f85a0987d4 100644 --- a/optimum/utils/dummy_diffusers_objects.py +++ b/optimum/utils/dummy_diffusers_objects.py @@ -46,3 +46,25 @@ def __init__(self, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["diffusers"]) + + +class ORTStableDiffusionXLPipeline(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) + + +class ORTStableDiffusionXLImg2ImgPipeline(metaclass=DummyObject): + _backends = ["diffusers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["diffusers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["diffusers"]) diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py index 9d78eccd82..5e6049bd41 100644 --- a/optimum/utils/import_utils.py +++ b/optimum/utils/import_utils.py @@ -34,7 +34,7 @@ TORCH_MINIMUM_VERSION = packaging.version.parse("1.11.0") TRANSFORMERS_MINIMUM_VERSION = packaging.version.parse("4.25.0") -DIFFUSERS_MINIMUM_VERSION = packaging.version.parse("0.17.0") +DIFFUSERS_MINIMUM_VERSION = packaging.version.parse("0.18.0") # This is the minimal required version to support some ONNX Runtime features diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 30c79052e6..d88f21fd2b 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -605,7 +605,11 @@ class DummyTimestepInputGenerator(DummyInputGenerator): Generates dummy time step inputs. """ - SUPPORTED_INPUT_NAMES = ("timestep",) + SUPPORTED_INPUT_NAMES = ( + "timestep", + "text_embeds", + "time_ids", + ) def __init__( self, @@ -617,7 +621,7 @@ def __init__( ): self.task = task self.vocab_size = normalized_config.vocab_size - + self.text_encoder_projection_dim = normalized_config.text_encoder_projection_dim if random_batch_size_range: low, high = random_batch_size_range self.batch_size = random.randint(low, high) @@ -626,7 +630,12 @@ def __init__( def generate(self, input_name: str, framework: str = "pt"): shape = [self.batch_size] - return self.random_int_tensor(shape, max_value=self.vocab_size, framework=framework) + + if input_name == "timestep": + return self.random_int_tensor(shape, max_value=self.vocab_size, framework=framework) + + shape.append(self.text_encoder_projection_dim if input_name == "text_embeds" else 6) + return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework) class DummyLabelsGenerator(DummyInputGenerator): diff --git a/setup.py b/setup.py index b310fff0ae..7da8e334da 100644 --- a/setup.py +++ b/setup.py @@ -34,6 +34,7 @@ "diffusers>=0.17.0", "torchaudio", "einops", + "invisible-watermark", ] QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241,<=0.0.259"] diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index c28613c793..423875ca28 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -237,5 +237,6 @@ } PYTORCH_STABLE_DIFFUSION_MODEL = { - ("hf-internal-testing/tiny-stable-diffusion-torch"), + "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", + "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", } diff --git a/tests/exporters/onnx/test_exporters_onnx_cli.py b/tests/exporters/onnx/test_exporters_onnx_cli.py index 39342cb4d5..a92a5d1881 100644 --- a/tests/exporters/onnx/test_exporters_onnx_cli.py +++ b/tests/exporters/onnx/test_exporters_onnx_cli.py @@ -27,12 +27,13 @@ from optimum.exporters.error_utils import MinimumVersionError from optimum.exporters.onnx.__main__ import main_export from optimum.onnxruntime import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME +from optimum.utils.testing_utils import require_diffusers if is_torch_available(): from optimum.exporters.tasks import TasksManager -from ..exporters_utils import PYTORCH_EXPORT_MODELS_TINY +from ..exporters_utils import PYTORCH_EXPORT_MODELS_TINY, PYTORCH_STABLE_DIFFUSION_MODEL def _get_models_to_test(export_models_dict: Dict): @@ -134,6 +135,31 @@ def test_all_models_tested(self): if len(missing_models_set) > 0: self.fail(f"Not testing all models. Missing models: {missing_models_set}") + @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items()) + @require_torch + @require_vision + @require_diffusers + def test_exporters_cli_pytorch_cpu_stable_diffusion(self, model_type: str, model_name: str): + self._onnx_export(model_name, model_type) + + @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items()) + @require_torch_gpu + @require_vision + @require_diffusers + @slow + @pytest.mark.run_slow + def test_exporters_cli_pytorch_gpu_stable_diffusion(self, model_type: str, model_name: str): + self._onnx_export(model_name, model_type, device="cuda") + + @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items()) + @require_torch_gpu + @require_vision + @require_diffusers + @slow + @pytest.mark.run_slow + def test_exporters_cli_fp16_stable_diffusion(self, model_type: str, model_name: str): + self._onnx_export(model_name, model_type, device="cuda", fp16=True) + @parameterized.expand(_get_models_to_test(PYTORCH_EXPORT_MODELS_TINY)) @require_torch @require_vision diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py index 15d4bec7c6..c97c9ff58c 100644 --- a/tests/exporters/onnx/test_onnx_export.py +++ b/tests/exporters/onnx/test_onnx_export.py @@ -23,7 +23,7 @@ import onnx import pytest from parameterized import parameterized -from transformers import AutoConfig, is_tf_available, is_torch_available, set_seed +from transformers import AutoConfig, is_tf_available, is_torch_available from transformers.testing_utils import require_onnx, require_tf, require_torch, require_torch_gpu, require_vision, slow from optimum.exporters.error_utils import AtolError @@ -40,7 +40,7 @@ from optimum.exporters.onnx.base import ConfigBehavior from optimum.exporters.onnx.config import TextDecoderOnnxConfig from optimum.exporters.onnx.model_configs import WhisperOnnxConfig -from optimum.utils import DummyPastKeyValuesGenerator, NormalizedTextConfig, is_diffusers_available +from optimum.utils import ONNX_WEIGHTS_NAME, DummyPastKeyValuesGenerator, NormalizedTextConfig from optimum.utils.testing_utils import grid_parameters, require_diffusers from ..exporters_utils import ( @@ -54,9 +54,6 @@ if is_torch_available() or is_tf_available(): from optimum.exporters.tasks import TasksManager -if is_diffusers_available(): - from diffusers import StableDiffusionPipeline - SEED = 42 @@ -314,6 +311,30 @@ def _onnx_export( gc.collect() + def _onnx_export_sd(self, model_type: str, model_name: str, device="cpu"): + pipeline = TasksManager.get_model_from_task(model_type, model_name, device=device) + models_and_onnx_configs = get_stable_diffusion_models_for_export(pipeline) + output_names = [os.path.join(name_dir, ONNX_WEIGHTS_NAME) for name_dir in models_and_onnx_configs] + model, _ = models_and_onnx_configs["vae_encoder"] + model.forward = lambda sample: {"latent_sample": model.encode(x=sample)["latent_dist"].parameters} + + with TemporaryDirectory() as tmpdirname: + _, onnx_outputs = export_models( + models_and_onnx_configs=models_and_onnx_configs, + opset=14, + output_dir=Path(tmpdirname), + output_names=output_names, + device=device, + ) + validate_models_outputs( + models_and_onnx_configs=models_and_onnx_configs, + onnx_named_outputs=onnx_outputs, + output_dir=Path(tmpdirname), + atol=1e-3, + onnx_files_subpaths=output_names, + use_subprocess=False, + ) + def test_all_models_tested(self): # make sure we test all models missing_models_set = TasksManager._SUPPORTED_CLI_MODEL_TYPE - set(PYTORCH_EXPORT_MODELS_TINY.keys()) @@ -383,40 +404,23 @@ def test_tensorflow_export(self, test_name, name, model_name, task, onnx_config_ self._onnx_export(test_name, name, model_name, task, onnx_config_class_constructor, monolith=monolith) - @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL) + @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items()) @require_torch @require_vision @require_diffusers - def test_pytorch_export_for_stable_diffusion_models(self, model_name): - set_seed(SEED) - - pipeline = StableDiffusionPipeline.from_pretrained(model_name) - output_names = [ - "text_encoder/model.onnx", - "unet/model.onnx", - "vae_encoder/model.onnx", - "vae_decoder/model.onnx", - ] - models_and_onnx_configs = get_stable_diffusion_models_for_export(pipeline) - model, _ = models_and_onnx_configs["vae_encoder"] - model.forward = lambda sample: {"latent_sample": model.encode(x=sample)["latent_dist"].parameters} + def test_pytorch_export_for_stable_diffusion_models(self, model_type, model_name): + self._onnx_export_sd(model_type, model_name) - with TemporaryDirectory() as tmpdirname: - _, onnx_outputs = export_models( - models_and_onnx_configs=models_and_onnx_configs, - opset=14, - output_dir=Path(tmpdirname), - output_names=output_names, - device="cpu", # TODO: Add GPU test - ) - validate_models_outputs( - models_and_onnx_configs=models_and_onnx_configs, - onnx_named_outputs=onnx_outputs, - output_dir=Path(tmpdirname), - atol=1e-3, - onnx_files_subpaths=output_names, - use_subprocess=False, - ) + @parameterized.expand(PYTORCH_STABLE_DIFFUSION_MODEL.items()) + @require_torch + @require_vision + @require_diffusers + @require_torch_gpu + @slow + @pytest.mark.run_slow + @pytest.mark.gpu_test + def test_pytorch_export_for_stable_diffusion_models_cuda(self, model_type, model_name): + self._onnx_export_sd(model_type, model_name, device="cuda") class CustomWhisperOnnxConfig(WhisperOnnxConfig): diff --git a/tests/onnxruntime/test_stable_diffusion_pipeline.py b/tests/onnxruntime/test_stable_diffusion_pipeline.py index aba1df44c5..e7b3bc5ec6 100644 --- a/tests/onnxruntime/test_stable_diffusion_pipeline.py +++ b/tests/onnxruntime/test_stable_diffusion_pipeline.py @@ -22,6 +22,7 @@ from diffusers import ( OnnxStableDiffusionImg2ImgPipeline, StableDiffusionPipeline, + StableDiffusionXLPipeline, ) from diffusers.utils import floats_tensor, load_image from parameterized import parameterized @@ -36,6 +37,8 @@ ORTModelVaeEncoder, ORTStableDiffusionImg2ImgPipeline, ORTStableDiffusionInpaintPipeline, + ORTStableDiffusionXLImg2ImgPipeline, + ORTStableDiffusionXLPipeline, ) from optimum.utils import logging from optimum.utils.testing_utils import grid_parameters, require_diffusers @@ -179,20 +182,24 @@ def test_compare_to_diffusers(self, model_arch: str): pipeline = StableDiffusionPipeline.from_pretrained(MODEL_NAMES[model_arch]) pipeline.safety_checker = None - num_images_per_prompt, height, width = 1, 64, 64 - latents_shape = ( - num_images_per_prompt, + batch_size, num_images_per_prompt, height, width = 1, 2, 64, 64 + + latents = ort_pipeline.prepare_latents( + batch_size * num_images_per_prompt, ort_pipeline.unet.config["in_channels"], - height // ort_pipeline.vae_scale_factor, - width // ort_pipeline.vae_scale_factor, + height, + width, + dtype=np.float32, + generator=np.random.RandomState(0), ) - latents = np.random.randn(*latents_shape).astype(np.float32) + kwargs = { "prompt": "sailing ship in storm by Leonardo da Vinci", "num_inference_steps": 1, "num_images_per_prompt": num_images_per_prompt, "height": height, "width": width, + "guidance_rescale": 0.1, } for output_type in ["latent", "np"]: @@ -222,6 +229,71 @@ def test_image_reproducibility(self, model_arch: str): self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) +class ORTStableDiffusionXLPipelineTest(ORTModelTestMixin): + SUPPORTED_ARCHITECTURES = [ + "stable-diffusion-xl", + ] + ORTMODEL_CLASS = ORTStableDiffusionXLPipeline + TASK = "stable-diffusion-xl" + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_compare_to_diffusers(self, model_arch: str): + ort_pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) + self.assertIsInstance(ort_pipeline.text_encoder, ORTModelTextEncoder) + self.assertIsInstance(ort_pipeline.text_encoder_2, ORTModelTextEncoder) + self.assertIsInstance(ort_pipeline.vae_decoder, ORTModelVaeDecoder) + self.assertIsInstance(ort_pipeline.vae_encoder, ORTModelVaeEncoder) + self.assertIsInstance(ort_pipeline.unet, ORTModelUnet) + self.assertIsInstance(ort_pipeline.config, Dict) + + pipeline = StableDiffusionXLPipeline.from_pretrained(MODEL_NAMES[model_arch]) + batch_size, num_images_per_prompt, height, width = 2, 2, 64, 64 + latents = ort_pipeline.prepare_latents( + batch_size * num_images_per_prompt, + ort_pipeline.unet.config["in_channels"], + height, + width, + dtype=np.float32, + generator=np.random.RandomState(0), + ) + + kwargs = { + "prompt": ["sailing ship in storm by Leonardo da Vinci"] * batch_size, + "num_inference_steps": 1, + "num_images_per_prompt": num_images_per_prompt, + "height": height, + "width": width, + "guidance_rescale": 0.1, + } + + for output_type in ["latent", "np"]: + ort_outputs = ort_pipeline(latents=latents, output_type=output_type, **kwargs).images + self.assertIsInstance(ort_outputs, np.ndarray) + with torch.no_grad(): + outputs = pipeline(latents=torch.from_numpy(latents), output_type=output_type, **kwargs).images + + # Compare model outputs + self.assertTrue(np.allclose(ort_outputs, outputs, atol=1e-4)) + # Compare model devices + self.assertEqual(pipeline.device, ort_pipeline.device) + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_image_reproducibility(self, model_arch: str): + pipeline = self.ORTMODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True) + inputs = _generate_inputs() + height = 64 + width = 64 + np.random.seed(0) + ort_outputs_1 = pipeline(**inputs, height=height, width=width) + np.random.seed(0) + ort_outputs_2 = pipeline(**inputs, height=height, width=width) + ort_outputs_3 = pipeline(**inputs, height=height, width=width) + self.assertTrue(np.array_equal(ort_outputs_1.images[0], ort_outputs_2.images[0])) + self.assertFalse(np.array_equal(ort_outputs_1.images[0], ort_outputs_3.images[0])) + + class ORTStableDiffusionInpaintPipelineTest(ORTStableDiffusionPipelineBase): SUPPORTED_ARCHITECTURES = [ "stable-diffusion", @@ -262,3 +334,33 @@ def generate_inputs(self, height=128, width=128): ).resize((64, 64)) return inputs + + +class ORTStableDiffusionXLImg2ImgPipelineTest(ORTModelTestMixin): + SUPPORTED_ARCHITECTURES = [ + "stable-diffusion-xl", + ] + ORTMODEL_CLASS = ORTStableDiffusionXLImg2ImgPipeline + TASK = "stable-diffusion-xl" + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + @require_diffusers + def test_inference(self, model_arch: str): + model_args = {"test_name": model_arch, "model_arch": model_arch} + self._setup(model_args) + pipeline = self.ORTMODEL_CLASS.from_pretrained(self.onnx_model_dirs[model_arch]) + inputs = self.generate_inputs() + output = pipeline(**inputs, generator=np.random.RandomState(0)).images[0, -3:, -3:, -1] + expected_slice = np.array([0.6515, 0.5405, 0.4858, 0.5632, 0.5174, 0.5681, 0.4948, 0.4253, 0.5080]) + + self.assertTrue(np.allclose(output.flatten(), expected_slice, atol=1e-1)) + + def generate_inputs(self, height=128, width=128): + inputs = _generate_inputs() + inputs["image"] = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/in_paint/overture-creations-5sI6fQgYIuo.png" + ).resize((height, width)) + + inputs["strength"] = 0.75 + return inputs diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py index 88a43d5590..f83acd91e6 100644 --- a/tests/onnxruntime/utils_onnxruntime_tests.py +++ b/tests/onnxruntime/utils_onnxruntime_tests.py @@ -79,6 +79,7 @@ "segformer": "hf-internal-testing/tiny-random-SegformerModel", "squeezebert": "hf-internal-testing/tiny-random-SqueezeBertModel", "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", + "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", "swin": "hf-internal-testing/tiny-random-SwinModel", "t5": "hf-internal-testing/tiny-random-t5", "vit": "hf-internal-testing/tiny-random-vit",