From d8574157ec11ab06af3e7fcd02fe460076246129 Mon Sep 17 00:00:00 2001 From: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com> Date: Fri, 30 Aug 2024 14:02:15 +0200 Subject: [PATCH] Add ControlNet support for SDXL (#675) * placeholders * export support poc * wrapup export * finish modeling * pipeline done * add compiler args for controlnet export * add test * update setup * fix * update setup * add doc * remove changes for debug * correct comment * placeholder * add sdxl specific inputs * export done * export fixed * doxstring * stage * pipeline done * doc * fix style / remove test * doc * Update optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py * Update optimum/exporters/neuron/convert.py * apply suggestions --- .../inference_tutorials/stable_diffusion.mdx | 51 ++ docs/source/package_reference/modeling.mdx | 5 + .../getting-started.ipynb | 5 +- .../stable-diffusion-txt2img.ipynb | 11 +- .../stable-diffusion-xl-txt2img.ipynb | 11 +- notebooks/text-classification/notebook.ipynb | 7 +- .../CodeLlama-7B-Compilation.ipynb | 29 +- .../text-generation/llama2-13b-chatbot.ipynb | 9 +- .../llama2-7b-fine-tuning.ipynb | 22 +- optimum/exporters/neuron/__main__.py | 34 +- optimum/exporters/neuron/convert.py | 114 ++- optimum/exporters/neuron/model_configs.py | 9 +- optimum/exporters/neuron/model_wrappers.py | 7 + optimum/exporters/neuron/utils.py | 29 +- optimum/neuron/__init__.py | 2 + optimum/neuron/modeling_diffusion.py | 16 +- .../diffusers/pipeline_controlnet.py | 2 +- .../diffusers/pipeline_controlnet_sd_xl.py | 747 +++++++++++++++++- 18 files changed, 987 insertions(+), 123 deletions(-) diff --git a/docs/source/inference_tutorials/stable_diffusion.mdx b/docs/source/inference_tutorials/stable_diffusion.mdx index 408924cd1..df9530875 100644 --- a/docs/source/inference_tutorials/stable_diffusion.mdx +++ b/docs/source/inference_tutorials/stable_diffusion.mdx @@ -635,4 +635,55 @@ compare.save("compare.png") /> + +## ControlNet with Stable Diffusion XL + +### Compile + +```bash +optimum-cli export neuron -m stabilityai/stable-diffusion-xl-base-1.0 --task stable-diffusion-xl --batch_size 1 --height 1024 --width 1024 --controlnet_ids diffusers/controlnet-canny-sdxl-1.0-small --num_images_per_prompt 1 sdxl_neuron_controlnet/ +``` + +### Text-to-Image + +```python +import cv2 +import numpy as np +from diffusers.utils import load_image +from PIL import Image +from optimum.neuron import NeuronStableDiffusionXLControlNetPipeline + +# Inputs +prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting" +negative_prompt = "low quality, bad quality, sketches" + +image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png" +) +image = np.array(image) +image = cv2.Canny(image, 100, 200) +image = image[:, :, None] +image = np.concatenate([image, image, image], axis=2) +image = Image.fromarray(image) + +controlnet_conditioning_scale = 0.5 # recommended for good generalization + +pipe = NeuronStableDiffusionXLControlNetPipeline.from_pretrained("sdxl_neuron_controlnet") + +images = pipe( + prompt, + negative_prompt=negative_prompt, + image=image, + controlnet_conditioning_scale=controlnet_conditioning_scale, +).images +images[0].save("hug_lab.png") +``` + +stable diffusion xl generated image with controlnet. + Are there any other stable diffusion features that you want us to support in 🤗`Optimum-neuron`? 
Please file an issue to [`Optimum-neuron` Github repo](https://github.com/huggingface/optimum-neuron) or discuss with us on [HuggingFace’s community forum](https://discuss.huggingface.co/c/optimum/), cheers 🤗 ! diff --git a/docs/source/package_reference/modeling.mdx b/docs/source/package_reference/modeling.mdx index f0258b772..121275fb2 100644 --- a/docs/source/package_reference/modeling.mdx +++ b/docs/source/package_reference/modeling.mdx @@ -139,3 +139,8 @@ The following Neuron model classes are available for stable diffusion tasks. ### NeuronStableDiffusionXLInpaintPipeline [[autodoc]] modeling_diffusion.NeuronStableDiffusionXLInpaintPipeline - __call__ + +### NeuronStableDiffusionXLControlNetPipeline + +[[autodoc]] modeling_diffusion.NeuronStableDiffusionXLControlNetPipeline + - __call__ diff --git a/notebooks/sentence-transformers/getting-started.ipynb b/notebooks/sentence-transformers/getting-started.ipynb index b72071dea..148022fe8 100644 --- a/notebooks/sentence-transformers/getting-started.ipynb +++ b/notebooks/sentence-transformers/getting-started.ipynb @@ -46,6 +46,7 @@ "source": [ "from optimum.neuron import NeuronModelForSentenceTransformers\n", "\n", + "\n", "# Sentence Transformers model from HuggingFace\n", "model_id = \"BAAI/bge-small-en-v1.5\"\n", "input_shapes = {\"batch_size\": 1, \"sequence_length\": 384} # mandatory shapes\n", @@ -88,9 +89,11 @@ "metadata": {}, "outputs": [], "source": [ - "from optimum.neuron import NeuronModelForSentenceTransformers\n", "from transformers import AutoTokenizer\n", "\n", + "from optimum.neuron import NeuronModelForSentenceTransformers\n", + "\n", + "\n", "model_id_or_path = \"bge_emb_inf2/\"\n", "tokenizer_id = \"BAAI/bge-small-en-v1.5\"\n", "\n", diff --git a/notebooks/stable-diffusion/stable-diffusion-txt2img.ipynb b/notebooks/stable-diffusion/stable-diffusion-txt2img.ipynb index dc26906f8..775fd254c 100644 --- a/notebooks/stable-diffusion/stable-diffusion-txt2img.ipynb +++ b/notebooks/stable-diffusion/stable-diffusion-txt2img.ipynb @@ -55,6 +55,7 @@ "source": [ "from optimum.neuron import NeuronStableDiffusionPipeline\n", "\n", + "\n", "model_id = \"stabilityai/stable-diffusion-2-1\"\n", "num_image_per_prompt = 1\n", "input_shapes = {\"batch_size\": 1, \"height\": 768, \"width\": 768, \"num_image_per_prompt\": num_image_per_prompt}\n", @@ -374,6 +375,8 @@ "outputs": [], "source": [ "from diffusers import DPMSolverMultistepScheduler\n", + "\n", + "\n", "stable_diffusion.scheduler = DPMSolverMultistepScheduler.from_config(stable_diffusion.scheduler.config)" ] }, @@ -384,11 +387,11 @@ "metadata": {}, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\n", - "from matplotlib import image as mpimg\n", "import time\n", - "import copy\n", - "import numpy as np " + "\n", + "import numpy as np\n", + "from matplotlib import image as mpimg\n", + "from matplotlib import pyplot as plt" ] }, { diff --git a/notebooks/stable-diffusion/stable-diffusion-xl-txt2img.ipynb b/notebooks/stable-diffusion/stable-diffusion-xl-txt2img.ipynb index 189890438..c8fafda7c 100644 --- a/notebooks/stable-diffusion/stable-diffusion-xl-txt2img.ipynb +++ b/notebooks/stable-diffusion/stable-diffusion-xl-txt2img.ipynb @@ -56,6 +56,7 @@ "source": [ "from optimum.neuron import NeuronStableDiffusionXLPipeline\n", "\n", + "\n", "model_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\n", "num_image_per_prompt = 1\n", "input_shapes = {\"batch_size\": 1, \"height\": 1024, \"width\": 1024, \"num_image_per_prompt\": num_image_per_prompt}\n", @@ -423,6 +424,8 @@ 
"outputs": [], "source": [ "from diffusers import DPMSolverMultistepScheduler\n", + "\n", + "\n", "stable_diffusion_xl.scheduler = DPMSolverMultistepScheduler.from_config(stable_diffusion_xl.scheduler.config)" ] }, @@ -433,11 +436,11 @@ "metadata": {}, "outputs": [], "source": [ - "from matplotlib import pyplot as plt\n", - "from matplotlib import image as mpimg\n", "import time\n", - "import copy\n", - "import numpy as np " + "\n", + "import numpy as np\n", + "from matplotlib import image as mpimg\n", + "from matplotlib import pyplot as plt" ] }, { diff --git a/notebooks/text-classification/notebook.ipynb b/notebooks/text-classification/notebook.ipynb index 7b0343d09..b03ac1502 100644 --- a/notebooks/text-classification/notebook.ipynb +++ b/notebooks/text-classification/notebook.ipynb @@ -85,6 +85,7 @@ "source": [ "from datasets import load_dataset\n", "\n", + "\n", "# Dataset id from huggingface.co/dataset\n", "dataset_id = \"philschmid/emotion\"\n", "\n", @@ -116,6 +117,7 @@ "source": [ "from random import randrange\n", "\n", + "\n", "random_id = randrange(len(raw_dataset['train']))\n", "raw_dataset['train'][random_id]\n", "# {'text': 'i feel isolated and alone in my trade', 'label': 0}" @@ -139,8 +141,11 @@ "metadata": {}, "outputs": [], "source": [ - "from transformers import AutoTokenizer\n", "import os\n", + "\n", + "from transformers import AutoTokenizer\n", + "\n", + "\n", "# Model id to load the tokenizer\n", "model_id = \"bert-base-uncased\"\n", "save_dataset_path = \"lm_dataset\"\n", diff --git a/notebooks/text-generation/CodeLlama-7B-Compilation.ipynb b/notebooks/text-generation/CodeLlama-7B-Compilation.ipynb index 1a50dbafc..f3ebf98fc 100644 --- a/notebooks/text-generation/CodeLlama-7B-Compilation.ipynb +++ b/notebooks/text-generation/CodeLlama-7B-Compilation.ipynb @@ -96,6 +96,7 @@ "source": [ "from optimum.neuron import pipeline\n", "\n", + "\n", "p = pipeline('text-generation', 'aws-neuron/CodeLlama-7b-hf-neuron-8xlarge')\n", "p(\"import socket\\n\\ndef ping_exponential_backoff(host: str):\",\n", " do_sample=True,\n", @@ -188,10 +189,12 @@ "outputs": [], "source": [ "from optimum.neuron import NeuronModelForCausalLM\n", + "\n", + "\n", "#num_cores should be changed based on the instance. 
inf2.24xlarge has 6 neuron processors (they have two cores each) so 12 total\n", "compiler_args = {\"num_cores\": 2, \"auto_cast_type\": 'fp16'}\n", "input_shapes = {\"batch_size\": 1, \"sequence_length\": 2048}\n", - "model = NeuronModelForCausalLM.from_pretrained(\"codellama/CodeLlama-7b-hf\", export=True, **compiler_args, **input_shapes) " + "model = NeuronModelForCausalLM.from_pretrained(\"codellama/CodeLlama-7b-hf\", export=True, **compiler_args, **input_shapes)" ] }, { @@ -211,8 +214,7 @@ "metadata": {}, "outputs": [], "source": [ - "model.save_pretrained(\"CodeLlama-7b-hf-neuron-8xlarge\")\n", - " " + "model.save_pretrained(\"CodeLlama-7b-hf-neuron-8xlarge\")\n" ] }, { @@ -251,10 +253,21 @@ "outputs": [], "source": [ "from huggingface_hub.hf_api import HfFolder\n", - "HfFolder.save_token('MY_HUGGINGFACE_TOKEN_HERE')\n", "\n", - "from huggingface_hub import login\n", - "from huggingface_hub import HfApi\n", + "\n", + "HfFolder.save_token('MY_HUGGINGFACE_TOKEN_HERE')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bdbc2537", + "metadata": {}, + "outputs": [], + "source": [ + "from huggingface_hub import HfApi, login\n", + "\n", + "\n", "api = HfApi()\n", "login()\n", "\n", @@ -264,9 +277,7 @@ " repo_type=\"model\",\n", " multi_commits=True,\n", " multi_commits_verbose=True,\n", - ")\n", - "\n", - "\n" + ")" ] } ], diff --git a/notebooks/text-generation/llama2-13b-chatbot.ipynb b/notebooks/text-generation/llama2-13b-chatbot.ipynb index 59ece3802..788ee756f 100644 --- a/notebooks/text-generation/llama2-13b-chatbot.ipynb +++ b/notebooks/text-generation/llama2-13b-chatbot.ipynb @@ -61,7 +61,6 @@ "outputs": [], "source": [ "# Special widgets are required for a nicer display\n", - "import sys\n", "!{sys.executable} -m pip install ipywidgets" ] }, @@ -103,6 +102,7 @@ "source": [ "from optimum.neuron import NeuronModelForCausalLM\n", "\n", + "\n", "compiler_args = {\"num_cores\": 24, \"auto_cast_type\": 'fp16'}\n", "input_shapes = {\"batch_size\": 1, \"sequence_length\": 2048}\n", "model = NeuronModelForCausalLM.from_pretrained(\n", @@ -153,6 +153,7 @@ "source": [ "from huggingface_hub import notebook_login\n", "\n", + "\n", "notebook_login(new_session=False)" ] }, @@ -175,6 +176,7 @@ "source": [ "from huggingface_hub import whoami\n", "\n", + "\n", "org = whoami()['name']\n", "\n", "repo_id = f\"{org}/llama-2-13b-chat-neuron\"\n", @@ -238,6 +240,7 @@ "source": [ "from optimum.neuron import NeuronModelForCausalLM\n", "\n", + "\n", "try:\n", " model\n", "except NameError:\n", @@ -262,6 +265,7 @@ "source": [ "from transformers import AutoTokenizer\n", "\n", + "\n", "tokenizer = AutoTokenizer.from_pretrained(\"NousResearch/Llama-2-13b-chat-hf\")" ] }, @@ -320,13 +324,10 @@ "source": [ "def format_chat_prompt(message, history, max_tokens):\n", " \"\"\" Convert a history of messages to a chat prompt\n", - " \n", - " \n", " Args:\n", " message(str): the new user message.\n", " history (List[str]): the list of user messages and assistant responses.\n", " max_tokens (int): the maximum number of input tokens accepted by the model.\n", - " \n", " Returns:\n", " a `str` prompt.\n", " \"\"\"\n", diff --git a/notebooks/text-generation/llama2-7b-fine-tuning.ipynb b/notebooks/text-generation/llama2-7b-fine-tuning.ipynb index f86eef356..c8db71270 100644 --- a/notebooks/text-generation/llama2-7b-fine-tuning.ipynb +++ b/notebooks/text-generation/llama2-7b-fine-tuning.ipynb @@ -154,9 +154,11 @@ } ], "source": [ - "from datasets import load_dataset\n", "from random import randrange\n", 
"\n", + "from datasets import load_dataset\n", + "\n", + "\n", "# Load dataset from the hub\n", "dataset = load_dataset(\"databricks/databricks-dolly-15k\", split=\"train\")\n", "\n", @@ -215,6 +217,7 @@ "source": [ "from random import randrange\n", "\n", + "\n", "print(format_dolly(dataset[randrange(len(dataset))]))" ] }, @@ -233,6 +236,7 @@ "source": [ "from transformers import AutoTokenizer\n", "\n", + "\n", "# Hugging Face model id\n", "model_id = \"philschmid/Llama-2-7b-hf\" # ungated\n", "# model_id = \"meta-llama/Llama-2-7b-hf\" # gated\n", @@ -257,10 +261,12 @@ "metadata": {}, "outputs": [], "source": [ - "from random import randint\n", "# add utils method to path for loading dataset\n", "import sys\n", - "sys.path.append(\"./scripts/utils\") # make sure you change this to the correct path \n", + "from random import randint\n", + "\n", + "\n", + "sys.path.append(\"./scripts/utils\") # make sure you change this to the correct path\n", "from pack_dataset import pack_dataset\n", "\n", "\n", @@ -337,7 +343,7 @@ "metadata": {}, "outputs": [], "source": [ - "# precompilation command \n", + "# precompilation command\n", "!MALLOC_ARENA_MAX=64 neuron_parallel_compile torchrun --nproc_per_node=32 scripts/run_clm.py \\\n", " --model_id {model_id} \\\n", " --dataset_path {dataset_path} \\\n", @@ -455,9 +461,11 @@ "metadata": {}, "outputs": [], "source": [ - "from optimum.neuron import NeuronModelForCausalLM\n", "from transformers import AutoTokenizer\n", "\n", + "from optimum.neuron import NeuronModelForCausalLM\n", + "\n", + "\n", "compiler_args = {\"num_cores\": 2, \"auto_cast_type\": 'fp16'}\n", "input_shapes = {\"batch_size\": 1, \"sequence_length\": 2048}\n", "\n", @@ -502,13 +510,13 @@ "def format_dolly_infernece(sample):\n", " instruction = f\"### Instruction\\n{sample['instruction']}\"\n", " context = f\"### Context\\n{sample['context']}\" if \"context\" in sample else None\n", - " response = f\"### Answer\\n\"\n", + " response = \"### Answer\\n\"\n", " # join all the parts together\n", " prompt = \"\\n\\n\".join([i for i in [instruction, context, response] if i is not None])\n", " return prompt\n", "\n", "\n", - "def generate(sample): \n", + "def generate(sample):\n", " prompt = format_dolly_infernece(sample)\n", " inputs = tokenizer(prompt, return_tensors=\"pt\")\n", " outputs = model.generate(**inputs,\n", diff --git a/optimum/exporters/neuron/__main__.py b/optimum/exporters/neuron/__main__.py index 74b4d1cf1..2fa28e68a 100644 --- a/optimum/exporters/neuron/__main__.py +++ b/optimum/exporters/neuron/__main__.py @@ -52,7 +52,6 @@ check_mandatory_input_shapes, get_encoder_decoder_models_for_export, get_stable_diffusion_models_for_export, - load_controlnets, replace_stable_diffusion_submodels, ) @@ -76,7 +75,7 @@ from transformers import PreTrainedModel if is_diffusers_available(): - from diffusers import ControlNetModel, DiffusionPipeline, ModelMixin, StableDiffusionPipeline + from diffusers import DiffusionPipeline, ModelMixin, StableDiffusionPipeline logger = logging.get_logger() @@ -207,7 +206,7 @@ def normalize_stable_diffusion_input_shapes( def infer_stable_diffusion_shapes_from_diffusers( input_shapes: Dict[str, Dict[str, int]], model: Union["StableDiffusionPipeline", "StableDiffusionXLPipeline"], - controlnets: Optional[List["ControlNetModel"]] = None, + has_controlnets: bool, ): if model.tokenizer is not None: sequence_length = model.tokenizer.model_max_length @@ -242,7 +241,10 @@ def infer_stable_diffusion_shapes_from_diffusers( ) # ControlNet - if controlnets: + if 
has_controlnets: + encoder_hidden_size = model.text_encoder.config.hidden_size + if hasattr(model, "text_encoder_2"): + encoder_hidden_size += model.text_encoder_2.config.hidden_size input_shapes["controlnet"] = { "batch_size": input_shapes["unet"]["batch_size"], "sequence_length": sequence_length, @@ -250,7 +252,7 @@ def infer_stable_diffusion_shapes_from_diffusers( "height": scaled_height, "width": scaled_width, "vae_scale_factor": vae_scale_factor, - "encoder_hidden_size": model.text_encoder.config.hidden_size, + "encoder_hidden_size": encoder_hidden_size, } return input_shapes @@ -272,7 +274,7 @@ def get_submodels_and_neuron_configs( lora_weight_names: Optional[Union[str, List[str]]] = None, lora_adapter_names: Optional[Union[str, List[str]]] = None, lora_scales: Optional[Union[float, List[float]]] = None, - controlnets: Optional[List["ControlNetModel"]] = None, + controlnet_ids: Optional[Union[str, List[str]]] = None, ): is_stable_diffusion = "stable-diffusion" in task is_encoder_decoder = ( @@ -295,7 +297,7 @@ def get_submodels_and_neuron_configs( lora_weight_names=lora_weight_names, lora_adapter_names=lora_adapter_names, lora_scales=lora_scales, - controlnets=controlnets, + controlnet_ids=controlnet_ids, ) elif is_encoder_decoder: optional_outputs = {"output_attentions": output_attentions, "output_hidden_states": output_hidden_states} @@ -356,7 +358,7 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion( lora_weight_names: Optional[Union[str, List[str]]] = None, lora_adapter_names: Optional[Union[str, List[str]]] = None, lora_scales: Optional[Union[float, List[float]]] = None, - controlnets: Optional[List["ControlNetModel"]] = None, + controlnet_ids: Optional[Union[str, List[str]]] = None, ): check_compiler_compatibility_for_stable_diffusion() model = replace_stable_diffusion_submodels(model, submodels) @@ -367,7 +369,7 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion( input_shapes = infer_stable_diffusion_shapes_from_diffusers( input_shapes=input_shapes, model=model, - controlnets=controlnets, + has_controlnets=controlnet_ids is not None, ) # Saving the model config and preprocessor as this is needed sometimes. 
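Note on the hunk above: for SDXL the ControlNet consumes the concatenation of both text encoders' hidden states, which is why `encoder_hidden_size` becomes the sum of the two `hidden_size` values when `text_encoder_2` is present. A minimal, self-contained sketch of the shape inference (the helper name and the example hidden sizes 768/1280 are illustrative assumptions, not part of the patch):

```python
# Illustrative sketch of the ControlNet static-shape inference added above.
from typing import Dict, Optional


def infer_controlnet_shapes(
    unet_shapes: Dict[str, int],
    sequence_length: int,
    text_encoder_hidden_size: int,
    text_encoder_2_hidden_size: Optional[int] = None,
    vae_scale_factor: int = 8,
) -> Dict[str, int]:
    # SDXL: sum of both text encoders' hidden sizes; SD 1.x/2.x: only the first one.
    encoder_hidden_size = text_encoder_hidden_size + (text_encoder_2_hidden_size or 0)
    return {
        "batch_size": unet_shapes["batch_size"],
        "sequence_length": sequence_length,
        "num_channels": 4,  # latent channels, matching the unet sample
        "height": unet_shapes["height"],  # already divided by the VAE scale factor upstream
        "width": unet_shapes["width"],
        "vae_scale_factor": vae_scale_factor,
        "encoder_hidden_size": encoder_hidden_size,
    }


# SDXL example: CLIP ViT-L (768) + OpenCLIP ViT-bigG (1280) -> encoder_hidden_size of 2048.
print(infer_controlnet_shapes({"batch_size": 1, "height": 128, "width": 128}, 77, 768, 1280))
```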
@@ -396,7 +398,7 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion( lora_weight_names=lora_weight_names, lora_adapter_names=lora_adapter_names, lora_scales=lora_scales, - controlnets=controlnets, + controlnet_ids=controlnet_ids, controlnet_input_shapes=input_shapes.get("controlnet", None), ) output_model_names = { @@ -414,13 +416,14 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion( ) # ControlNet models - if controlnets: - for idx in range(len(controlnets)): + if controlnet_ids: + if isinstance(controlnet_ids, str): + controlnet_ids = [controlnet_ids] + for idx in range(len(controlnet_ids)): controlnet_name = DIFFUSION_MODEL_CONTROLNET_NAME + "_" + str(idx) output_model_names[controlnet_name] = os.path.join(controlnet_name, NEURON_FILE_NAME) del model - del controlnets return models_and_neuron_configs, output_model_names @@ -475,7 +478,7 @@ def load_models_and_neuron_configs( lora_weight_names: Optional[Union[str, List[str]]], lora_adapter_names: Optional[Union[str, List[str]]], lora_scales: Optional[Union[float, List[float]]], - controlnet_ids: Optional[Union[str, List[str]]], + controlnet_ids: Optional[Union[str, List[str]]] = None, output_attentions: bool = False, output_hidden_states: bool = False, library_name: Optional[str] = None, @@ -500,7 +503,6 @@ def load_models_and_neuron_configs( } if model is None: model = TasksManager.get_model_from_task(**model_kwargs) - controlnets = load_controlnets(controlnet_ids) models_and_neuron_configs, output_model_names = get_submodels_and_neuron_configs( model=model, @@ -518,7 +520,7 @@ def load_models_and_neuron_configs( lora_weight_names=lora_weight_names, lora_adapter_names=lora_adapter_names, lora_scales=lora_scales, - controlnets=controlnets, + controlnet_ids=controlnet_ids, ) return models_and_neuron_configs, output_model_names diff --git a/optimum/exporters/neuron/convert.py b/optimum/exporters/neuron/convert.py index 810bbbedf..51607f82e 100644 --- a/optimum/exporters/neuron/convert.py +++ b/optimum/exporters/neuron/convert.py @@ -194,6 +194,10 @@ def validate_model_outputs( ref_inputs = tuple(ref_inputs.values()) ref_outputs = reference_model(*ref_inputs) neuron_inputs = tuple(inputs.values()) + elif "controlnet" in getattr(config._config, "_class_name", "").lower(): + reference_model = config.patch_model_for_export(reference_model, ref_inputs) + neuron_inputs = ref_inputs = tuple(ref_inputs.values()) + ref_outputs = reference_model(*ref_inputs) else: ref_outputs = reference_model(**ref_inputs) neuron_inputs = tuple(config.flatten_inputs(inputs).values()) @@ -351,66 +355,58 @@ def export_models( output_path = output_dir / output_file_name output_path.parent.mkdir(parents=True, exist_ok=True) - try: - # TODO: Remove after the weights/neff separation compilation of sdxl is patched by a neuron sdk release: https://github.com/aws-neuron/aws-neuron-sdk/issues/859 - if not inline_weights_to_neff and getattr(sub_neuron_config, "is_sdxl", False): - logger.warning( - "The compilation of SDXL's unet with the weights/neff separation is broken since the Neuron sdk 2.18 release. `inline_weights_to_neff` will be set to True and the caching will be disabled. If you still want to separate the neff and weights, please downgrade your Neuron setup to the 2.17.1 release." 
- ) - inline_weights_to_neff = True - - start_time = time.time() - neuron_inputs, neuron_outputs = export( - model=submodel, - config=sub_neuron_config, - output=output_path, - compiler_workdir=compiler_workdir, - inline_weights_to_neff=inline_weights_to_neff, - optlevel=optlevel, - **compiler_kwargs, - ) - compilation_time = time.time() - start_time - total_compilation_time += compilation_time - logger.info(f"[Compilation Time] {np.round(compilation_time, 2)} seconds.") - all_inputs[model_name] = neuron_inputs - all_outputs[model_name] = neuron_outputs - # Add neuron specific configs to model components' original config - if hasattr(submodel, "config"): - model_config = submodel.config - elif configs and (model_name in configs.keys()): - model_config = configs[model_name] - else: - raise AttributeError("Cannot find model's configuration, please pass it with `configs`.") - - if is_diffusers_available() and isinstance(model_config, FrozenDict): - model_config = OrderedDict(model_config) - model_config = DiffusersPretrainedConfig.from_dict(model_config) - - model_config = store_compilation_config( - config=model_config, - input_shapes=sub_neuron_config.input_shapes, - compiler_kwargs=compiler_kwargs, - input_names=neuron_inputs, - output_names=neuron_outputs, - dynamic_batch_size=sub_neuron_config.dynamic_batch_size, - compiler_type=NEURON_COMPILER_TYPE, - compiler_version=NEURON_COMPILER_VERSION, - inline_weights_to_neff=inline_weights_to_neff, - optlevel=optlevel, - model_type=getattr(sub_neuron_config, "MODEL_TYPE", None), - task=getattr(sub_neuron_config, "task", None), - output_attentions=getattr(sub_neuron_config, "output_attentions", False), - output_hidden_states=getattr(sub_neuron_config, "output_hidden_states", False), - ) - model_config.save_pretrained(output_path.parent) - compile_configs[model_name] = model_config - except Exception as e: - failed_models.append((i, model_name)) - output_path.parent.rmdir() - logger.error( - f"An error occured when trying to trace {model_name} with the error message: {e}.\n" - f"The export is failed and {model_name} neuron model won't be stored." + # TODO: Remove after the weights/neff separation compilation of sdxl is patched by a neuron sdk release: https://github.com/aws-neuron/aws-neuron-sdk/issues/859 + if not inline_weights_to_neff and getattr(sub_neuron_config, "is_sdxl", False): + logger.warning( + "The compilation of SDXL's unet with the weights/neff separation is broken since the Neuron SDK 2.18 release. `inline_weights_to_neff` will be set to True and the caching will be disabled. If you still want to separate the neff and weights, please downgrade your Neuron setup to the 2.17.1 release." 
) + inline_weights_to_neff = True + + start_time = time.time() + neuron_inputs, neuron_outputs = export( + model=submodel, + config=sub_neuron_config, + output=output_path, + compiler_workdir=compiler_workdir, + inline_weights_to_neff=inline_weights_to_neff, + optlevel=optlevel, + **compiler_kwargs, + ) + compilation_time = time.time() - start_time + total_compilation_time += compilation_time + logger.info(f"[Compilation Time] {np.round(compilation_time, 2)} seconds.") + all_inputs[model_name] = neuron_inputs + all_outputs[model_name] = neuron_outputs + # Add neuron specific configs to model components' original config + if hasattr(submodel, "config"): + model_config = submodel.config + elif configs and (model_name in configs.keys()): + model_config = configs[model_name] + else: + raise AttributeError("Cannot find model's configuration, please pass it with `configs`.") + + if is_diffusers_available() and isinstance(model_config, FrozenDict): + model_config = OrderedDict(model_config) + model_config = DiffusersPretrainedConfig.from_dict(model_config) + + model_config = store_compilation_config( + config=model_config, + input_shapes=sub_neuron_config.input_shapes, + compiler_kwargs=compiler_kwargs, + input_names=neuron_inputs, + output_names=neuron_outputs, + dynamic_batch_size=sub_neuron_config.dynamic_batch_size, + compiler_type=NEURON_COMPILER_TYPE, + compiler_version=NEURON_COMPILER_VERSION, + inline_weights_to_neff=inline_weights_to_neff, + optlevel=optlevel, + model_type=getattr(sub_neuron_config, "MODEL_TYPE", None), + task=getattr(sub_neuron_config, "task", None), + output_attentions=getattr(sub_neuron_config, "output_attentions", False), + output_hidden_states=getattr(sub_neuron_config, "output_hidden_states", False), + ) + model_config.save_pretrained(output_path.parent) + compile_configs[model_name] = model_config logger.info(f"[Total compilation Time] {np.round(total_compilation_time, 2)} seconds.") diff --git a/optimum/exporters/neuron/model_configs.py b/optimum/exporters/neuron/model_configs.py index 689de5331..1f59c9031 100644 --- a/optimum/exporters/neuron/model_configs.py +++ b/optimum/exporters/neuron/model_configs.py @@ -681,12 +681,19 @@ class ControlNetNeuronConfig(VisionNeuronConfig): DUMMY_INPUT_GENERATOR_CLASSES = ( DummyVisionInputGenerator, - DummyControNetInputGenerator, + DummyControNetInputGenerator, # Instead of `encoder_hidden_states` generated by `DummySeq2SeqDecoderTextInputGenerator` + DummyTimestepInputGenerator, + DummySeq2SeqDecoderTextInputGenerator, ) @property def inputs(self) -> List[str]: common_inputs = ["sample", "timestep", "encoder_hidden_states", "controlnet_cond", "conditioning_scale"] + + if getattr(self._normalized_config, "addition_embed_type", None) == "text_time": + common_inputs.append("text_embeds") + common_inputs.append("time_ids") + return common_inputs @property diff --git a/optimum/exporters/neuron/model_wrappers.py b/optimum/exporters/neuron/model_wrappers.py index 8a47c779e..9c83168ce 100644 --- a/optimum/exporters/neuron/model_wrappers.py +++ b/optimum/exporters/neuron/model_wrappers.py @@ -94,12 +94,19 @@ def forward(self, *inputs): controlnet_cond = ordered_inputs.pop("controlnet_cond", None) conditioning_scale = ordered_inputs.pop("conditioning_scale", None) + # Additional conditions for the Stable Diffusion XL UNet. 
+ added_cond_kwargs = { + "text_embeds": ordered_inputs.pop("text_embeds", None), + "time_ids": ordered_inputs.pop("time_ids", None), + } + out_tuple = self.model( sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states, controlnet_cond=controlnet_cond, conditioning_scale=conditioning_scale, + added_cond_kwargs=added_cond_kwargs, guess_mode=False, # TODO: support guess mode of ControlNet return_dict=False, **ordered_inputs, diff --git a/optimum/exporters/neuron/utils.py b/optimum/exporters/neuron/utils.py index 0d9f863bf..312d1a91c 100644 --- a/optimum/exporters/neuron/utils.py +++ b/optimum/exporters/neuron/utils.py @@ -119,7 +119,7 @@ def get_stable_diffusion_models_for_export( lora_weight_names: Optional[List[str]] = None, lora_adapter_names: Optional[List[str]] = None, lora_scales: Optional[List[float]] = None, - controlnets: Optional[List["ControlNetModel"]] = None, + controlnet_ids: Optional[Union[str, List[str]]] = None, controlnet_input_shapes: Optional[Dict[str, int]] = None, ) -> Dict[str, Tuple[Union["PreTrainedModel", "ModelMixin"], "NeuronDefaultConfig"]]: """ @@ -153,8 +153,8 @@ def get_stable_diffusion_models_for_export( List of adapter names to be used for referencing the loaded adapter models. lora_scales (`Optional[List[float]]`, defaults to `None`): List of scaling factors for lora adapters. - controlnets (`Optional[List["ControlNetModel"]]]`, defaults to `None`): - One or multiple ControlNets providing additional conditioning to the `unet` during the denoising process. If you set multiple + controlnet_ids (`Optional[Union[str, List[str]]]`, defaults to `None`): + Model ID of one or multiple ControlNets providing additional conditioning to the `unet` during the denoising process. If you set multiple ControlNets as a list, the outputs from each ControlNet are added together to create one combined additional conditioning. controlnet_input_shapes (`Optional[Dict[str, int]]`, defaults to `None`): Static shapes used for compiling ControlNets. 
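Note on the wrapper change above: SDXL ControlNets expect the pooled text embeddings and the micro-conditioning time ids through `added_cond_kwargs`. A minimal sketch of the forwarding logic (the free-standing `controlnet_forward` helper and its flat argument list are illustrative assumptions; the real traced wrapper reorders its inputs according to the Neuron export config):

```python
# Sketch of the SDXL added-condition plumbing performed by the ControlNet wrapper above.
from typing import Optional

import torch


def controlnet_forward(
    model,
    sample: torch.Tensor,
    timestep: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    controlnet_cond: torch.Tensor,
    conditioning_scale: torch.Tensor,
    text_embeds: Optional[torch.Tensor] = None,
    time_ids: Optional[torch.Tensor] = None,
):
    # SDXL ControlNets (addition_embed_type == "text_time") read these extra conditions;
    # SD 1.x/2.x ControlNets simply never look them up.
    added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}
    return model(
        sample=sample,
        timestep=timestep,
        encoder_hidden_states=encoder_hidden_states,
        controlnet_cond=controlnet_cond,
        conditioning_scale=conditioning_scale,
        added_cond_kwargs=added_cond_kwargs,
        guess_mode=False,  # guess mode is not traced yet (see the TODO above)
        return_dict=False,
    )
```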
@@ -170,6 +170,7 @@ def get_stable_diffusion_models_for_export( lora_weight_names=lora_weight_names, lora_adapter_names=lora_adapter_names, lora_scales=lora_scales, + controlnet_ids=controlnet_ids, ) library_name = "diffusers" @@ -227,7 +228,7 @@ def get_stable_diffusion_models_for_export( if task == "stable-diffusion-xl": unet_neuron_config.is_sdxl = True - unet_neuron_config.with_controlnet = True if controlnets else False + unet_neuron_config.with_controlnet = True if controlnet_ids else False models_for_export[DIFFUSION_MODEL_UNET_NAME] = (unet, unet_neuron_config) @@ -266,8 +267,12 @@ def get_stable_diffusion_models_for_export( models_for_export[DIFFUSION_MODEL_VAE_DECODER_NAME] = (vae_decoder, vae_decoder_neuron_config) # ControlNet - if controlnets: - for idx, controlnet in enumerate(controlnets): + if controlnet_ids: + if isinstance(controlnet_ids, str): + controlnet_ids = [controlnet_ids] + for idx in range(len(controlnet_ids)): + controlnet_name = DIFFUSION_MODEL_CONTROLNET_NAME + "_" + str(idx) + controlnet = models_for_export[controlnet_name] controlnet_config_constructor = TasksManager.get_exporter_config_constructor( model=controlnet, exporter="neuron", @@ -281,7 +286,7 @@ def get_stable_diffusion_models_for_export( dynamic_batch_size=dynamic_batch_size, **controlnet_input_shapes, ) - models_for_export[DIFFUSION_MODEL_CONTROLNET_NAME + "_" + str(idx)] = ( + models_for_export[controlnet_name] = ( controlnet, controlnet_neuron_config, ) @@ -351,6 +356,7 @@ def get_submodels_for_export_stable_diffusion( lora_weight_names: Optional[Union[str, List[str]]] = None, lora_adapter_names: Optional[Union[str, List[str]]] = None, lora_scales: Optional[List[float]] = None, + controlnet_ids: Optional[Union[str, List[str]]] = None, ) -> Dict[str, Union["PreTrainedModel", "ModelMixin"]]: """ Returns the components of a Stable Diffusion model. 
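The hunks above route `controlnet_ids` from the export entry point down to the loaded submodels, so the same compilation can also be triggered from Python instead of the `optimum-cli` command shown in the documentation. A hedged sketch, assuming the new SDXL pipeline accepts the same `export=True` / `controlnet_ids` keywords as the existing `NeuronStableDiffusionControlNetPipeline`:

```python
# Programmatic counterpart of `optimum-cli export neuron ... --controlnet_ids ...` (sketch).
from optimum.neuron import NeuronStableDiffusionXLControlNetPipeline

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
controlnet_id = "diffusers/controlnet-canny-sdxl-1.0-small"
input_shapes = {"batch_size": 1, "height": 1024, "width": 1024, "num_images_per_prompt": 1}

pipe = NeuronStableDiffusionXLControlNetPipeline.from_pretrained(
    model_id,
    controlnet_ids=controlnet_id,  # a single id or a list of ids
    export=True,
    **input_shapes,
)
pipe.save_pretrained("sdxl_neuron_controlnet/")
```

Passing a list of ids compiles one Neuron ControlNet per entry, saved under `controlnet_0`, `controlnet_1`, ..., matching the `DIFFUSION_MODEL_CONTROLNET_NAME + "_" + str(idx)` naming used in the patch.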
@@ -418,6 +424,15 @@ def get_submodels_for_export_stable_diffusion( vae_decoder.forward = lambda latent_sample: vae_decoder.decode(z=latent_sample) models_for_export.append((DIFFUSION_MODEL_VAE_DECODER_NAME, vae_decoder)) + # ControlNets + controlnets = load_controlnets(controlnet_ids) + if controlnets: + for idx, controlnet in enumerate(controlnets): + controlnet.config.text_encoder_projection_dim = pipeline.unet.config.text_encoder_projection_dim + controlnet.config.requires_aesthetics_score = pipeline.unet.config.requires_aesthetics_score + controlnet.config.time_cond_proj_dim = pipeline.unet.config.time_cond_proj_dim + models_for_export.append((DIFFUSION_MODEL_CONTROLNET_NAME + "_" + str(idx), controlnet)) + return OrderedDict(models_for_export) diff --git a/optimum/neuron/__init__.py b/optimum/neuron/__init__.py index 9f989a6c2..2b8d7b81b 100644 --- a/optimum/neuron/__init__.py +++ b/optimum/neuron/__init__.py @@ -58,6 +58,7 @@ "NeuronStableDiffusionXLImg2ImgPipeline", "NeuronStableDiffusionXLInpaintPipeline", "NeuronStableDiffusionControlNetPipeline", + "NeuronStableDiffusionXLControlNetPipeline", ], "modeling_decoder": ["NeuronDecoderModel"], "modeling_seq2seq": ["NeuronModelForSeq2SeqLM"], @@ -100,6 +101,7 @@ NeuronStableDiffusionInstructPix2PixPipeline, NeuronStableDiffusionPipeline, NeuronStableDiffusionPipelineBase, + NeuronStableDiffusionXLControlNetPipeline, NeuronStableDiffusionXLImg2ImgPipeline, NeuronStableDiffusionXLInpaintPipeline, NeuronStableDiffusionXLPipeline, diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py index 4e3e193c1..30a310bf8 100644 --- a/optimum/neuron/modeling_diffusion.py +++ b/optimum/neuron/modeling_diffusion.py @@ -1058,15 +1058,15 @@ def forward( inputs = (sample, timestep, encoder_hidden_states) if timestep_cond is not None: inputs = inputs + (timestep_cond,) - if added_cond_kwargs is not None: - text_embeds = added_cond_kwargs.pop("text_embeds", None) - time_ids = added_cond_kwargs.pop("time_ids", None) - inputs = inputs + (text_embeds, time_ids) if mid_block_additional_residual is not None: inputs = inputs + (mid_block_additional_residual,) if down_block_additional_residuals is not None: for idx in range(len(down_block_additional_residuals)): inputs = inputs + (down_block_additional_residuals[idx],) + if added_cond_kwargs: + text_embeds = added_cond_kwargs.pop("text_embeds", None) + time_ids = added_cond_kwargs.pop("time_ids", None) + inputs = inputs + (text_embeds, time_ids) outputs = self.model(*inputs) return outputs @@ -1139,9 +1139,15 @@ def forward( controlnet_cond: torch.Tensor, conditioning_scale: float = 1.0, guess_mode: bool = False, + added_cond_kwargs: Optional[Dict] = None, return_dict: bool = True, ) -> Union["ControlNetOutput", Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: + timestep = timestep.expand((sample.shape[0],)).to(torch.long) inputs = (sample, timestep, encoder_hidden_states, controlnet_cond, conditioning_scale) + if added_cond_kwargs: + text_embeds = added_cond_kwargs.pop("text_embeds", None) + time_ids = added_cond_kwargs.pop("time_ids", None) + inputs += (text_embeds, time_ids) outputs = self.model(*inputs) if guess_mode: @@ -1320,7 +1326,7 @@ class NeuronStableDiffusionXLInpaintPipeline( class NeuronStableDiffusionXLControlNetPipeline( - NeuronStableDiffusionPipelineBase, NeuronStableDiffusionXLControlNetPipelineMixin + NeuronStableDiffusionXLPipelineBase, NeuronStableDiffusionXLControlNetPipelineMixin ): __call__ = NeuronStableDiffusionXLControlNetPipelineMixin.__call__ diff 
--git a/optimum/neuron/pipelines/diffusers/pipeline_controlnet.py b/optimum/neuron/pipelines/diffusers/pipeline_controlnet.py index 641123635..690872c83 100644 --- a/optimum/neuron/pipelines/diffusers/pipeline_controlnet.py +++ b/optimum/neuron/pipelines/diffusers/pipeline_controlnet.py @@ -238,7 +238,7 @@ def __call__( it will be overriden by the static batch size of neuron (except for dynamic batching). eta (`float`, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies - to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + to the [`diffusers.schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`, defaults to `None`): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. diff --git a/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py b/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py index 69e80292f..5555add8e 100644 --- a/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py +++ b/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py @@ -14,9 +14,748 @@ # limitations under the License. """Override some diffusers API for NeuronStableDiffusionXLControlNetPipelineMixin""" +import copy +import logging +from typing import Any, Callable, Dict, List, Optional, Tuple, Union -class NeuronStableDiffusionXLControlNetPipelineMixin: - def __call__(self): - raise NotImplementedError( - "`NeuronStableDiffusionXLControlNetPipelineMixin` is not yet supported but will come soon." +import torch +from diffusers import StableDiffusionXLControlNetPipeline +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import PipelineImageInput +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput + +from .pipeline_utils import StableDiffusionXLPipelineMixin + + +logger = logging.getLogger(__name__) + + +class NeuronStableDiffusionXLControlNetPipelineMixin( + StableDiffusionXLPipelineMixin, StableDiffusionXLControlNetPipeline +): + # Adapted from https://github.com/huggingface/diffusers/blob/v0.29.2/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py#L625 + # Replace class types with Neuron ones + def check_inputs( + self, + prompt, + prompt_2, + image, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." 
+ ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Check `image` + if self.controlnet.__class__.__name__ == "NeuronControlNetModel": + self.check_image(image, prompt, prompt_embeds) + elif self.controlnet.__class__.__name__ == "NeuronMultiControlNetModel": + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are not supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + else: + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + raise ValueError( + f"{self.controlnet.__class__.__name__} is not a supported class for ControlNet. The class must be either `NeuronControlNetModel` or `NeuronMultiControlNetModel`." 
+ ) + + # Check `controlnet_conditioning_scale` + if self.controlnet.__class__.__name__ == "NeuronControlNetModel": + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif self.controlnet.__class__.__name__ == "NeuronMultiControlNetModel": + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are not supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + raise ValueError( + f"{self.controlnet.__class__.__name__} is not a supported class for ControlNet. The class must be either `NeuronControlNetModel` or `NeuronMultiControlNetModel`." + ) + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if self.controlnet.__class__.__name__ == "NeuronMultiControlNetModel": + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
+ ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Adapted from https://github.com/huggingface/diffusers/blob/v0.30.0/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py#L899 + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Adapted from https://github.com/huggingface/diffusers/blob/1f81fbe274e67c843283e69eb8f00bb56f75ffc4/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py#L1001 + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + sigmas: List[float] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`Optional[Union[str, List[str]]]`, defaults to `None`): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`Optional["PipelineImageInput"]`, defaults to `None`): + The ControlNet input condition to provide guidance to the `unet` for generation. 
If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + num_inference_steps (`int`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`Optional[List[int]]`, defaults to `None`): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`Optional[List[int]]`, defaults to `None`): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + denoising_end (`Optional[float]`, defaults to `None`): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`Optional[Union[str, List[str]]]`, defaults to `None`): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, defaults to 1): + The number of images to generate per prompt. + eta (`float`, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`diffusers.schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`, defaults to `None`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`Optional[torch.Tensor]`, defaults to `None`): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`Optional[torch.Tensor]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`Optional[torch.Tensor]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`Optional[torch.Tensor]`, defaults to `None`): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`Optional[torch.Tensor]`, defaults to `None`): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + ip_adapter_image: (`Optional[PipelineImageInput]`, defaults to `None`): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`Optional[List[torch.Tensor]]`, defaults to `None`): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`Optional[str]`, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`Union[float, List[float]]`, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`Union[float, List[float]]`, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`Union[float, List[float]]`, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`Optional[Tuple[int, int]]`, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. 
Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int, int]`, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Optional[Tuple[int, int]]`, defaults to `None`): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_original_size (`Optional[Tuple[int, int]]`, defaults to `None`): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int, int]`, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Optional[Tuple[int, int]]`, defaults to `None`): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`Optional[int]`, defaults to `None`): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Optional[Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]]`, defaults to `None`): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List[str]`, defaults to `["latents"]`): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. 
+
+        Examples:
+
+        Returns:
+            [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            If `return_dict` is `True`, [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+            otherwise a `tuple` is returned containing the output images.
+        """
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        controlnet = self.controlnet
+
+        # align format for control guidance
+        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+            mult = len(controlnet.nets) if controlnet.__class__.__name__ == "NeuronMultiControlNetModel" else 1
+            control_guidance_start, control_guidance_end = (
+                mult * [control_guidance_start],
+                mult * [control_guidance_end],
+            )
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            image=image,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            ip_adapter_image=ip_adapter_image,
+            ip_adapter_image_embeds=ip_adapter_image_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            controlnet_conditioning_scale=controlnet_conditioning_scale,
+            control_guidance_start=control_guidance_start,
+            control_guidance_end=control_guidance_end,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+        self._denoising_end = denoising_end
+
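+        # The numbered steps below follow the upstream diffusers `StableDiffusionXLControlNetPipeline.__call__`,
+        # adapted to the static input shapes of the compiled Neuron models.
+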
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        if isinstance(controlnet_conditioning_scale, float):
+            controlnet_conditioning_scale = torch.tensor([controlnet_conditioning_scale])
+            if controlnet.__class__.__name__ == "NeuronMultiControlNetModel":
+                controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+        global_pool_conditions = (
+            controlnet.config.global_pool_conditions
+            if controlnet.__class__.__name__ == "NeuronControlNetModel"
+            else controlnet.nets[0].config.global_pool_conditions
+        )
+        guess_mode = guess_mode or global_pool_conditions
+        # TODO: Remove after the guess mode of ControlNet is supported
+        if guess_mode:
+            logger.info("Disabling the guess mode as this is not supported yet.")
+            guess_mode = False
+
+        # 3.1 Encode input prompt
+        text_encoder_lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )
+        lora_scale = None
+        do_classifier_free_guidance = guidance_scale > 1.0 and (
+            self.dynamic_batch_size or self.data_parallel_mode == "unet"
+        )
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            lora_scale=lora_scale,
+            clip_skip=self.clip_skip,
+        )
+
+        # 3.2 Encode ip_adapter_image
+        # TODO: support ip adapter
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            logger.info(
+                "IP adapter is not supported yet, `ip_adapter_image` and `ip_adapter_image_embeds` will be ignored."
+            )
+
+        # 4. Prepare image
+        height = self.vae_encoder.config.neuron["static_height"]
+        width = self.vae_encoder.config.neuron["static_width"]
+        if controlnet.__class__.__name__ == "NeuronControlNetModel":
+            image = self.prepare_image(
+                image=image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=None,
+                dtype=None,
+                do_classifier_free_guidance=do_classifier_free_guidance,
+                guess_mode=guess_mode,
+            )
+            height, width = image.shape[-2:]
+        elif controlnet.__class__.__name__ == "NeuronMultiControlNetModel":
+            images = []
+
+            for image_ in image:
+                image_ = self.prepare_image(
+                    image=image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=None,
+                    dtype=None,
+                    do_classifier_free_guidance=do_classifier_free_guidance,
+                    guess_mode=guess_mode,
+                )
+
+                images.append(image_)
+
+            image = images
+            height, width = image[0].shape[-2:]
+        else:
+            assert False
+
+        # 5. Prepare timesteps
+        timesteps, num_inference_steps = retrieve_timesteps(
+            scheduler=self.scheduler,
+            num_inference_steps=num_inference_steps,
+            device=None,
+            timesteps=timesteps,
+            sigmas=sigmas,
+        )
+        self._num_timesteps = len(timesteps)
+
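+        # NOTE: `height` and `width` come from the `static_height`/`static_width` entries of the Neuron config, i.e.
+        # they are fixed at export time, so the conditioning images prepared above and the latents prepared below
+        # match the shapes the models were compiled with.
+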
+        # 6. Prepare latent variables
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            generator,
+            latents,
+        )
+
+        # 6.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=None, dtype=latents.dtype)
+
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7.1 Create tensor stating which controlnets to keep
+        controlnet_keep = []
+        for i in range(len(timesteps)):
+            keeps = [
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                for s, e in zip(control_guidance_start, control_guidance_end)
+            ]
+            controlnet_keep.append(keeps[0] if controlnet.__class__.__name__ == "NeuronControlNetModel" else keeps)
+
+        # 7.2 Prepare added time ids & embeddings
+        if isinstance(image, list):
+            original_size = original_size or image[0].shape[-2:]
+        else:
+            original_size = original_size or image.shape[-2:]
+        target_size = target_size or (height, width)
+
+        add_text_embeds = pooled_prompt_embeds
+        if self.text_encoder_2 is None:
+            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+        else:
+            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+        add_time_ids = self._get_add_time_ids(
+            original_size,
+            crops_coords_top_left,
+            target_size,
+            dtype=prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+
+        if negative_original_size is not None and negative_target_size is not None:
+            negative_add_time_ids = self._get_add_time_ids(
+                negative_original_size,
+                negative_crops_coords_top_left,
+                negative_target_size,
+                dtype=prompt_embeds.dtype,
+                text_encoder_projection_dim=text_encoder_projection_dim,
+            )
+        else:
+            negative_add_time_ids = add_time_ids
+
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
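+        # NOTE: when classifier-free guidance is enabled, the batch is doubled (negative + positive halves); with
+        # `data_parallel_mode == "unet"`, the timestep and conditioning scale are duplicated inside the loop as well
+        # (see "Duplicate inputs for ddp" below), matching the doubled batch.
+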
+        # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+        # 8.1 Apply denoising_end
+        if (
+            self.denoising_end is not None
+            and isinstance(self.denoising_end, float)
+            and self.denoising_end > 0
+            and self.denoising_end < 1
+        ):
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            timesteps = timesteps[:num_inference_steps]
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # Relevant thread:
+                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+                # controlnet(s) inference
+                if guess_mode and do_classifier_free_guidance:
+                    # Infer ControlNet only for the conditional batch.
+                    control_model_input = latents
+                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                    controlnet_added_cond_kwargs = {
+                        "text_embeds": add_text_embeds.chunk(2)[1],
+                        "time_ids": add_time_ids.chunk(2)[1],
+                    }
+                else:
+                    control_model_input = latent_model_input
+                    controlnet_prompt_embeds = prompt_embeds
+                    controlnet_added_cond_kwargs = copy.deepcopy(added_cond_kwargs)
+
+                if isinstance(controlnet_keep[i], list):
+                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+                else:
+                    controlnet_cond_scale = controlnet_conditioning_scale
+                    if isinstance(controlnet_cond_scale, list):
+                        controlnet_cond_scale = controlnet_cond_scale[0]
+                    cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+                # Duplicate inputs for ddp
+                t = torch.tensor([t] * 2) if self.data_parallel_mode == "unet" else t
+                cond_scale = (
+                    torch.tensor([cond_scale]).repeat(2)
+                    if self.data_parallel_mode == "unet"
+                    else torch.tensor(cond_scale)
+                )
+
+                down_block_res_samples, mid_block_res_sample = self.controlnet(
+                    control_model_input,
+                    t,
+                    encoder_hidden_states=controlnet_prompt_embeds,
+                    controlnet_cond=image,
+                    conditioning_scale=cond_scale,
+                    guess_mode=guess_mode,
+                    added_cond_kwargs=controlnet_added_cond_kwargs,
+                    return_dict=False,
+                )
+
+                if guess_mode and do_classifier_free_guidance:
+                    # ControlNet was inferred only for the conditional batch.
+                    # To apply the output of ControlNet to both the unconditional and conditional batches,
+                    # add 0 to the unconditional batch to keep it unchanged.
+                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+                    logger.info(
+                        "IP adapter is not supported yet, `ip_adapter_image` and `ip_adapter_image_embeds` will be ignored."
+                    )
+
+                # predict the noise residual
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
+                    down_block_additional_residuals=down_block_res_samples,
+                    mid_block_additional_residual=mid_block_res_sample,
+                    added_cond_kwargs=added_cond_kwargs,
+                )[0]
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+                    negative_pooled_prompt_embeds = callback_outputs.pop(
+                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                    )
+                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+                    negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+        if not output_type == "latent":
+            # unscale/denormalize the latents
+            # denormalize with the mean and std if available and not None
+            has_latents_mean = (
+                hasattr(self.vae_decoder.config, "latents_mean") and self.vae_decoder.config.latents_mean is not None
+            )
+            has_latents_std = (
+                hasattr(self.vae_decoder.config, "latents_std") and self.vae_decoder.config.latents_std is not None
+            )
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae_decoder.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae_decoder.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents = (
+                    latents * latents_std / getattr(self.vae_decoder.config, "scaling_factor", 0.18215) + latents_mean
+                )
+            else:
+                latents = latents / getattr(self.vae_decoder.config, "scaling_factor", 0.18215)
+
+            image = self.vae_decoder(latents)[0]
+
+        else:
+            image = latents
+
+        if not output_type == "latent":
+            # apply watermark if available
+            if self.watermark is not None:
+                image = self.watermark.apply_watermark(image)
+
+            image = self.image_processor.postprocess(image, output_type=output_type)
+
+        if not return_dict:
+            return (image,)
+
+        return StableDiffusionXLPipelineOutput(images=image)