From 2c524dfa392cef756f3625688f45a00299c424db Mon Sep 17 00:00:00 2001 From: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com> Date: Fri, 28 Jun 2024 13:54:07 +0200 Subject: [PATCH] Add Stable Diffusion ControlNet support (#622) * placeholders * export support poc * wrapup export * finish modeling * pipeline done * add compiler args for controlnet export * add test * update setup * fix * update setup * add doc * remove changes for debug * correct comment * address comments * Update docs/source/tutorials/stable_diffusion.mdx Co-authored-by: Michael Benayoun * Update docs/source/tutorials/stable_diffusion.mdx Co-authored-by: Michael Benayoun * Update optimum/commands/export/neuronx.py Co-authored-by: Michael Benayoun * apply suggestions * limit diffusers version, 0.29 incompatible * fix ruff check --------- Co-authored-by: David Corvoysier Co-authored-by: Michael Benayoun --- .github/workflows/check_code_quality.yml | 2 +- docs/source/package_reference/modeling.mdx | 5 + docs/source/tutorials/stable_diffusion.mdx | 85 +++ optimum/commands/export/neuronx.py | 7 + optimum/exporters/neuron/__main__.py | 45 +- optimum/exporters/neuron/base.py | 29 + optimum/exporters/neuron/convert.py | 78 ++- optimum/exporters/neuron/model_configs.py | 69 ++- optimum/exporters/neuron/model_wrappers.py | 51 ++ optimum/exporters/neuron/utils.py | 46 +- optimum/neuron/__init__.py | 4 + optimum/neuron/modeling_diffusion.py | 353 +++++++++-- optimum/neuron/pipelines/__init__.py | 4 + .../neuron/pipelines/diffusers/__init__.py | 2 + .../diffusers/pipeline_controlnet.py | 580 ++++++++++++++++++ .../diffusers/pipeline_controlnet_sd_xl.py | 22 + optimum/neuron/utils/__init__.py | 6 +- optimum/neuron/utils/constant.py | 1 + optimum/neuron/utils/input_generators.py | 90 +++ optimum/neuron/utils/misc.py | 1 - setup.py | 5 +- tests/cli/test_export_cli.py | 34 + .../test_stable_diffusion_pipeline.py | 45 +- 23 files changed, 1461 insertions(+), 103 deletions(-) create mode 100644 optimum/neuron/pipelines/diffusers/pipeline_controlnet.py create mode 100644 optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py diff --git a/.github/workflows/check_code_quality.yml b/.github/workflows/check_code_quality.yml index de139d476..b3300baf2 100644 --- a/.github/workflows/check_code_quality.yml +++ b/.github/workflows/check_code_quality.yml @@ -52,4 +52,4 @@ jobs: - name: Check style with ruff run: | source venv/bin/activate - ruff . + ruff check . diff --git a/docs/source/package_reference/modeling.mdx b/docs/source/package_reference/modeling.mdx index 119c1fdd9..819c68082 100644 --- a/docs/source/package_reference/modeling.mdx +++ b/docs/source/package_reference/modeling.mdx @@ -106,6 +106,11 @@ The following Neuron model classes are available for stable diffusion tasks. [[autodoc]] modeling_diffusion.NeuronLatentConsistencyModelPipeline - __call__ +### NeuronStableDiffusionControlNetPipeline + +[[autodoc]] modeling_diffusion.NeuronStableDiffusionControlNetPipeline + - __call__ + ### NeuronStableDiffusionXLPipeline [[autodoc]] modeling_diffusion.NeuronStableDiffusionXLPipeline diff --git a/docs/source/tutorials/stable_diffusion.mdx b/docs/source/tutorials/stable_diffusion.mdx index b67841e29..648fde3ff 100644 --- a/docs/source/tutorials/stable_diffusion.mdx +++ b/docs/source/tutorials/stable_diffusion.mdx @@ -516,4 +516,89 @@ image = pipe(prompt, num_inference_steps=4, guidance_scale=0).images[0] alt="stable diffusion generated image with LoRA adapter." 
/> + +## ControlNet + +ControlNet conditions the stable diffusion model on an additional input image. In Optimum Neuron, we support the compilation of one or multiple ControlNet(s) along with the stable diffusion checkpoint. Then you can use the compiled artifacts to generate styled images. + +### Compile ControlNet + +We can compile one or multiple ControlNets either via the Optimum CLI or programmatically via the `NeuronStableDiffusionControlNetPipeline` class by passing the `controlnet_ids` argument. + +* Export via the Optimum CLI + +```bash +optimum-cli export neuron -m runwayml/stable-diffusion-v1-5 --task stable-diffusion --batch_size 1 --height 512 --width 512 --controlnet_ids lllyasviel/sd-controlnet-canny --num_images_per_prompt 1 sd_neuron_controlnet/ +``` + +* Export via Python API + +```python +from optimum.neuron import NeuronStableDiffusionControlNetPipeline + +model_id = "runwayml/stable-diffusion-v1-5" +controlnet_id = "lllyasviel/sd-controlnet-canny" + +# [Neuron] pipeline +input_shapes = {"batch_size": 1, "height": 512, "width": 512, "num_images_per_prompt": 1} +compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"} +pipe = NeuronStableDiffusionControlNetPipeline.from_pretrained( + model_id, + controlnet_ids=controlnet_id, + export=True, + **input_shapes, + **compiler_args, +) +pipe.save_pretrained("sd_neuron_controlnet") +``` + +### Text-to-Image + +For text-to-image, we can specify an additional conditioning input. + +Here is an example with a canny image: a white outline of the original image on a black background. The ControlNet will use the canny image as a control to guide the model to generate an image with the same outline. + +```python +import cv2 +import numpy as np +from diffusers import UniPCMultistepScheduler +from diffusers.utils import load_image, make_image_grid +from PIL import Image + +from optimum.neuron import NeuronStableDiffusionControlNetPipeline + + +# prepare canny image +original_image = load_image( + "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" +) + +image = np.array(original_image) + +low_threshold = 100 +high_threshold = 200 + +image = cv2.Canny(image, low_threshold, high_threshold) +image = image[:, :, None] +image = np.concatenate([image, image, image], axis=2) +canny_image = Image.fromarray(image) + +# load pre-compiled neuron model +pipe = NeuronStableDiffusionControlNetPipeline.from_pretrained("sd_neuron_controlnet") +pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + +# inference +output = pipe("the mona lisa", image=canny_image).images[0] +compare = make_image_grid([original_image, canny_image, output], rows=1, cols=3) +compare.save("compare.png") +``` + +stable diffusion 1.5 generated image with controlnet. + + Are there any other stable diffusion features that you want us to support in 🤗`Optimum-neuron`? Please file an issue to [`Optimum-neuron` Github repo](https://github.com/huggingface/optimum-neuron) or discuss with us on [HuggingFace’s community forum](https://discuss.huggingface.co/c/optimum/), cheers 🤗 !
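The documentation above compiles a single ControlNet, while the export API added by this patch also accepts a list of ids (the CLI flag uses `nargs="*"` and each checkpoint is exported as its own `controlnet_{i}` Neuron model). Below is a minimal sketch of compiling two ControlNets together; the second checkpoint id (`lllyasviel/sd-controlnet-openpose`) and the inference call in the trailing comments are illustrative assumptions, not part of the patch.

```python
# A minimal sketch, assuming the second ControlNet checkpoint below is available
# and follows the same API as the canny ControlNet used in the tutorial.
from optimum.neuron import NeuronStableDiffusionControlNetPipeline

model_id = "runwayml/stable-diffusion-v1-5"
controlnet_ids = [
    "lllyasviel/sd-controlnet-canny",     # used in the tutorial above
    "lllyasviel/sd-controlnet-openpose",  # assumed second ControlNet, for illustration
]

input_shapes = {"batch_size": 1, "height": 512, "width": 512, "num_images_per_prompt": 1}
compiler_args = {"auto_cast": "matmul", "auto_cast_type": "bf16"}

# Passing a list compiles one Neuron ControlNet per id (saved as controlnet_0, controlnet_1, ...).
pipe = NeuronStableDiffusionControlNetPipeline.from_pretrained(
    model_id,
    controlnet_ids=controlnet_ids,
    export=True,
    **input_shapes,
    **compiler_args,
)
pipe.save_pretrained("sd_neuron_multi_controlnet")

# At inference, one conditioning image (and optionally one scale) per ControlNet is expected:
# images = [canny_image, openpose_image]
# image = pipe("a prompt", image=images, controlnet_conditioning_scale=[0.5, 0.5]).images[0]
```

At inference time the residuals produced by the individual ControlNets are summed by `NeuronMultiControlNetModel` into a single additional conditioning for the UNet.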
diff --git a/optimum/commands/export/neuronx.py b/optimum/commands/export/neuronx.py index cfd1564dd..fd717ce1f 100644 --- a/optimum/commands/export/neuronx.py +++ b/optimum/commands/export/neuronx.py @@ -172,6 +172,13 @@ def parse_args_neuronx(parser: "ArgumentParser"): type=float, help="List of scaling factors for the lora adapters.", ) + optional_group.add_argument( + "--controlnet_ids", + default=None, + nargs="*", + type=str, + help="List of model ids (eg. `thibaud/controlnet-openpose-sdxl-1.0`) of ControlNet models.", + ) optional_group.add_argument( "--output_attentions", action="store_true", diff --git a/optimum/exporters/neuron/__main__.py b/optimum/exporters/neuron/__main__.py index 833485000..74b4d1cf1 100644 --- a/optimum/exporters/neuron/__main__.py +++ b/optimum/exporters/neuron/__main__.py @@ -26,6 +26,7 @@ from ...neuron.utils import ( DECODER_NAME, + DIFFUSION_MODEL_CONTROLNET_NAME, DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, DIFFUSION_MODEL_TEXT_ENCODER_NAME, DIFFUSION_MODEL_UNET_NAME, @@ -51,6 +52,7 @@ check_mandatory_input_shapes, get_encoder_decoder_models_for_export, get_stable_diffusion_models_for_export, + load_controlnets, replace_stable_diffusion_submodels, ) @@ -74,7 +76,7 @@ from transformers import PreTrainedModel if is_diffusers_available(): - from diffusers import DiffusionPipeline, ModelMixin, StableDiffusionPipeline + from diffusers import ControlNetModel, DiffusionPipeline, ModelMixin, StableDiffusionPipeline logger = logging.get_logger() @@ -205,6 +207,7 @@ def normalize_stable_diffusion_input_shapes( def infer_stable_diffusion_shapes_from_diffusers( input_shapes: Dict[str, Dict[str, int]], model: Union["StableDiffusionPipeline", "StableDiffusionXLPipeline"], + controlnets: Optional[List["ControlNetModel"]] = None, ): if model.tokenizer is not None: sequence_length = model.tokenizer.model_max_length @@ -232,11 +235,24 @@ def infer_stable_diffusion_shapes_from_diffusers( "width": scaled_width, } ) + input_shapes["unet"]["vae_scale_factor"] = vae_scale_factor input_shapes["vae_encoder"].update({"num_channels": vae_encoder_num_channels, "height": height, "width": width}) input_shapes["vae_decoder"].update( {"num_channels": vae_decoder_num_channels, "height": scaled_height, "width": scaled_width} ) + # ControlNet + if controlnets: + input_shapes["controlnet"] = { + "batch_size": input_shapes["unet"]["batch_size"], + "sequence_length": sequence_length, + "num_channels": unet_num_channels, + "height": scaled_height, + "width": scaled_width, + "vae_scale_factor": vae_scale_factor, + "encoder_hidden_size": model.text_encoder.config.hidden_size, + } + return input_shapes @@ -256,6 +272,7 @@ def get_submodels_and_neuron_configs( lora_weight_names: Optional[Union[str, List[str]]] = None, lora_adapter_names: Optional[Union[str, List[str]]] = None, lora_scales: Optional[Union[float, List[float]]] = None, + controlnets: Optional[List["ControlNetModel"]] = None, ): is_stable_diffusion = "stable-diffusion" in task is_encoder_decoder = ( @@ -278,6 +295,7 @@ def get_submodels_and_neuron_configs( lora_weight_names=lora_weight_names, lora_adapter_names=lora_adapter_names, lora_scales=lora_scales, + controlnets=controlnets, ) elif is_encoder_decoder: optional_outputs = {"output_attentions": output_attentions, "output_hidden_states": output_hidden_states} @@ -338,6 +356,7 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion( lora_weight_names: Optional[Union[str, List[str]]] = None, lora_adapter_names: Optional[Union[str, List[str]]] = None, lora_scales: 
Optional[Union[float, List[float]]] = None, + controlnets: Optional[List["ControlNetModel"]] = None, ): check_compiler_compatibility_for_stable_diffusion() model = replace_stable_diffusion_submodels(model, submodels) @@ -345,7 +364,11 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion( raise RuntimeError( "Stable diffusion export is not supported by neuron-cc on inf1, please use neuronx-cc on either inf2/trn1 instead." ) - input_shapes = infer_stable_diffusion_shapes_from_diffusers(input_shapes, model) + input_shapes = infer_stable_diffusion_shapes_from_diffusers( + input_shapes=input_shapes, + model=model, + controlnets=controlnets, + ) # Saving the model config and preprocessor as this is needed sometimes. model.scheduler.save_pretrained(output.joinpath("scheduler")) @@ -373,6 +396,8 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion( lora_weight_names=lora_weight_names, lora_adapter_names=lora_adapter_names, lora_scales=lora_scales, + controlnets=controlnets, + controlnet_input_shapes=input_shapes.get("controlnet", None), ) output_model_names = { DIFFUSION_MODEL_UNET_NAME: os.path.join(DIFFUSION_MODEL_UNET_NAME, NEURON_FILE_NAME), @@ -387,7 +412,15 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion( output_model_names[DIFFUSION_MODEL_TEXT_ENCODER_2_NAME] = os.path.join( DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, NEURON_FILE_NAME ) + + # ControlNet models + if controlnets: + for idx in range(len(controlnets)): + controlnet_name = DIFFUSION_MODEL_CONTROLNET_NAME + "_" + str(idx) + output_model_names[controlnet_name] = os.path.join(controlnet_name, NEURON_FILE_NAME) + del model + del controlnets return models_and_neuron_configs, output_model_names @@ -442,6 +475,7 @@ def load_models_and_neuron_configs( lora_weight_names: Optional[Union[str, List[str]]], lora_adapter_names: Optional[Union[str, List[str]]], lora_scales: Optional[Union[float, List[float]]], + controlnet_ids: Optional[Union[str, List[str]]], output_attentions: bool = False, output_hidden_states: bool = False, library_name: Optional[str] = None, @@ -466,6 +500,7 @@ def load_models_and_neuron_configs( } if model is None: model = TasksManager.get_model_from_task(**model_kwargs) + controlnets = load_controlnets(controlnet_ids) models_and_neuron_configs, output_model_names = get_submodels_and_neuron_configs( model=model, @@ -483,6 +518,7 @@ def load_models_and_neuron_configs( lora_weight_names=lora_weight_names, lora_adapter_names=lora_adapter_names, lora_scales=lora_scales, + controlnets=controlnets, ) return models_and_neuron_configs, output_model_names @@ -516,6 +552,7 @@ def main_export( lora_weight_names: Optional[Union[str, List[str]]] = None, lora_adapter_names: Optional[Union[str, List[str]]] = None, lora_scales: Optional[Union[float, List[float]]] = None, + controlnet_ids: Optional[Union[str, List[str]]] = None, **input_shapes, ): output = Path(output) @@ -545,6 +582,7 @@ def main_export( lora_weight_names=lora_weight_names, lora_adapter_names=lora_adapter_names, lora_scales=lora_scales, + controlnet_ids=controlnet_ids, **input_shapes, ) @@ -565,7 +603,7 @@ def main_export( is_stable_diffusion = "stable-diffusion" in task if is_stable_diffusion: # Do not validate vae encoder due to the sampling randomness - del neuron_outputs[-2] # -2 is the index of `vae_encoder` + neuron_outputs.pop("vae_encoder") models_and_neuron_configs.pop("vae_encoder", None) output_model_names.pop("vae_encoder", None) @@ -687,6 +725,7 @@ def main(): lora_weight_names=getattr(args, "lora_weight_names", None), 
lora_adapter_names=getattr(args, "lora_adapter_names", None), lora_scales=getattr(args, "lora_scales", None), + controlnet_ids=getattr(args, "controlnet_ids", None), **optional_outputs, **input_shapes, ) diff --git a/optimum/exporters/neuron/base.py b/optimum/exporters/neuron/base.py index 6d5451537..a5cd59a9d 100644 --- a/optimum/exporters/neuron/base.py +++ b/optimum/exporters/neuron/base.py @@ -15,6 +15,7 @@ """Neuron configuration base classes.""" import importlib +import re from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union @@ -162,6 +163,8 @@ def __init__( point_batch_size: Optional[int] = None, nb_points_per_image: Optional[int] = None, num_beams: Optional[int] = None, + vae_scale_factor: Optional[int] = None, + encoder_hidden_size: Optional[int] = None, output_attentions: bool = False, output_hidden_states: bool = False, # TODO: add custom dtype after optimum 1.13 release @@ -197,6 +200,8 @@ def __init__( "num_beams": num_beams, "image_size": image_size or getattr(self._config, "image_size", None), "patch_size": patch_size or getattr(self._config, "patch_size", None), + "vae_scale_factor": vae_scale_factor, + "encoder_hidden_size": encoder_hidden_size, } input_shapes = {} for name, value in axes_values.items(): @@ -331,6 +336,30 @@ def flatten_inputs(cls, inputs: Dict[str, Any]) -> Dict[str, Any]: flatten[name] = value return flatten + @classmethod + def unflatten_inputs(cls, inputs: Dict[str, Any]) -> Dict[str, Any]: + """ + Re-construct inputs that have been flatten for tracing. + """ + unflatten = {} + to_group = {} + for name, value in inputs.items(): + name_with_idx = re.findall(r"(.*?)_(\d+)", name) + if len(name_with_idx) > 0: + if name_with_idx[0][0] in to_group: + to_group[name_with_idx[0][0]].append((int(name_with_idx[0][1]), value)) + else: + to_group[name_with_idx[0][0]] = [(int(name_with_idx[0][1]), value)] + else: + unflatten[name] = value + + if to_group: + for name, values in to_group.items(): + ordered = sorted(values, key=lambda x: x[0]) + unflatten[name] = tuple([item[1] for item in ordered]) + + return unflatten + def patch_model_for_export( self, model: "PreTrainedModel", diff --git a/optimum/exporters/neuron/convert.py b/optimum/exporters/neuron/convert.py index c02e14578..810bbbedf 100644 --- a/optimum/exporters/neuron/convert.py +++ b/optimum/exporters/neuron/convert.py @@ -75,7 +75,7 @@ def validate_models_outputs( models_and_neuron_configs: Dict[ str, Tuple[Union["PreTrainedModel", "ModelMixin", torch.nn.Module], "NeuronDefaultConfig"] ], - neuron_named_outputs: List[List[str]], + neuron_named_outputs: Dict[str, List[str]], output_dir: Path, atol: Optional[float] = None, neuron_files_subpaths: Optional[Dict[str, str]] = None, @@ -127,7 +127,7 @@ def validate_models_outputs( config=sub_neuron_config, reference_model=ref_submodel, neuron_model_path=neuron_model_path, - neuron_named_outputs=neuron_named_outputs[i], + neuron_named_outputs=neuron_named_outputs[model_name], atol=atol, ) except Exception as e: @@ -179,23 +179,24 @@ def validate_model_outputs( # Reference outputs with torch.no_grad(): reference_model.eval() - ref_inputs = config.generate_dummy_inputs(return_tuple=False, **input_shapes) + inputs = config.generate_dummy_inputs(return_tuple=False, **input_shapes) + ref_inputs = config.unflatten_inputs(inputs) if hasattr(reference_model, "config") and getattr(reference_model.config, "is_encoder_decoder", False): reference_model = config.patch_model_for_export(reference_model, device="cpu", 
**input_shapes) if "SentenceTransformer" in reference_model.__class__.__name__: reference_model = config.patch_model_for_export(reference_model, ref_inputs) ref_outputs = reference_model(**ref_inputs) - neuron_inputs = tuple(config.flatten_inputs(ref_inputs).values()) + neuron_inputs = tuple(config.flatten_inputs(inputs).values()) elif "AutoencoderKL" in getattr(config._config, "_class_name", "") or getattr( reference_model.config, "is_encoder_decoder", False ): # VAE components for stable diffusion or Encoder-Decoder models ref_inputs = tuple(ref_inputs.values()) ref_outputs = reference_model(*ref_inputs) - neuron_inputs = ref_inputs + neuron_inputs = tuple(inputs.values()) else: ref_outputs = reference_model(**ref_inputs) - neuron_inputs = tuple(config.flatten_inputs(ref_inputs).values()) + neuron_inputs = tuple(config.flatten_inputs(inputs).values()) # Neuron outputs neuron_model = torch.jit.load(neuron_model_path) @@ -236,30 +237,38 @@ def validate_model_outputs( # Check the shape and values match shape_failures = [] value_failures = [] - for i, (name, output) in enumerate(zip(neuron_output_names_list, neuron_outputs)): - if isinstance(output, torch.Tensor): + for i, (name, neuron_output) in enumerate(zip(neuron_output_names_list, neuron_outputs)): + if isinstance(neuron_output, torch.Tensor): ref_output = ref_outputs[name].numpy() if isinstance(ref_outputs, dict) else ref_outputs[i].numpy() - output = output.numpy() - elif isinstance(output, tuple): # eg. `hidden_states` of `AutoencoderKL` is a tuple of tensors. + neuron_output = neuron_output.numpy() + elif isinstance(neuron_output, tuple): # eg. `hidden_states` of `AutoencoderKL` is a tuple of tensors; ref_output = torch.stack(ref_outputs[name]).numpy() - output = torch.stack(output).numpy() + neuron_output = torch.stack(neuron_output).numpy() + elif isinstance(neuron_output, list): + ref_output = [output.numpy() for output in ref_outputs[name]] + neuron_output = [output.numpy() for output in neuron_output] logger.info(f'\t- Validating Neuron Model output "{name}":') # Shape - if not output.shape == ref_output.shape: - logger.error(f"\t\t-[x] shape {output.shape} doesn't match {ref_output.shape}") - shape_failures.append((name, ref_output.shape, output.shape)) - else: - logger.info(f"\t\t-[✓] {output.shape} matches {ref_output.shape}") + output_list = ( + neuron_output if isinstance(neuron_output, list) else [neuron_output] + ) # eg. `down_block_res_samples` of `ControlNet` is a list of tensors. 
+ ref_output_list = ref_output if isinstance(ref_output, list) else [ref_output] + for output, ref_output in zip(output_list, ref_output_list): + if not output.shape == ref_output.shape: + logger.error(f"\t\t-[x] shape {output.shape} doesn't match {ref_output.shape}") + shape_failures.append((name, ref_output.shape, output.shape)) + else: + logger.info(f"\t\t-[✓] {output.shape} matches {ref_output.shape}") - # Values - if not np.allclose(ref_output, output, atol=atol): - max_diff = np.amax(np.abs(ref_output - output)) - logger.error(f"\t\t-[x] values not close enough, max diff: {max_diff} (atol: {atol})") - value_failures.append((name, max_diff)) - else: - logger.info(f"\t\t-[✓] all values close (atol: {atol})") + # Values + if not np.allclose(ref_output, output, atol=atol): + max_diff = np.amax(np.abs(ref_output - output)) + logger.error(f"\t\t-[x] values not close enough, max diff: {max_diff} (atol: {atol})") + value_failures.append((name, max_diff)) + else: + logger.info(f"\t\t-[✓] all values close (atol: {atol})") if shape_failures: msg = "\n".join(f"- {t[0]}: got {t[1]} (reference) and {t[2]} (neuron)" for t in shape_failures) @@ -286,7 +295,7 @@ def export_models( compiler_kwargs: Optional[Dict[str, Any]] = {}, configs: Optional[Dict[str, Any]] = {}, model_name_or_path: Optional[str] = None, -) -> Tuple[List[List[str]], List[List[str]]]: +) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]: """ Exports a Pytorch model with multiple component models to separate files. @@ -316,10 +325,11 @@ def export_models( model_name_or_path (`Optional[str]`, defaults to `None`): Path to pretrained model or model identifier from the Hugging Face Hub. Returns: - `Tuple[List[List[str]], List[List[str]]]`: A tuple with an ordered list of the model's inputs, and the named + `Tuple[Dict[str, List[str]], Dict[str, List[str]]]`: A tuple with two dictionaries containing ordered list of the model's inputs, and the named outputs from the Neuron configuration. 
""" - outputs = [] + all_inputs = {} + all_outputs = {} if compiler_workdir is not None: compiler_workdir = Path(compiler_workdir) @@ -362,7 +372,8 @@ def export_models( compilation_time = time.time() - start_time total_compilation_time += compilation_time logger.info(f"[Compilation Time] {np.round(compilation_time, 2)} seconds.") - outputs.append((neuron_inputs, neuron_outputs)) + all_inputs[model_name] = neuron_inputs + all_outputs[model_name] = neuron_outputs # Add neuron specific configs to model components' original config if hasattr(submodel, "config"): model_config = submodel.config @@ -415,8 +426,7 @@ def export_models( output_file_names.pop(model_name) models_and_neuron_configs.pop(model_name) - outputs = list(map(list, zip(*outputs))) - return outputs + return all_inputs, all_outputs def export( @@ -581,11 +591,11 @@ def add_stable_diffusion_compiler_args(config, compiler_args): identifier = getattr(config._config, "_name_or_path", "") + " " + getattr(config._config, "_class_name", "") identifier = identifier.lower() - sd_components = ["text_encoder", "vae", "vae_encoder", "vae_decoder"] + sd_components = ["text_encoder", "vae", "vae_encoder", "vae_decoder", "controlnet"] if any(component in identifier for component in sd_components): compiler_args.append("--enable-fast-loading-neuron-binaries") - # unet - if "unet" in identifier: + # unet or controlnet + if "unet" in identifier or "controlnet" in identifier: # SDXL unet doesn't support fast loading neuron binaries if not getattr(config, "is_sdxl", False): compiler_args.append("--enable-fast-loading-neuron-binaries") @@ -597,11 +607,11 @@ def improve_stable_diffusion_loading(config, neuron_model): # Combine the model name and its path to identify which is the subcomponent in Stable Diffusion pipeline identifier = getattr(config._config, "_name_or_path", "") + " " + getattr(config._config, "_class_name", "") identifier = identifier.lower() - sd_components = ["text_encoder", "unet", "vae", "vae_encoder", "vae_decoder"] + sd_components = ["text_encoder", "unet", "vae", "vae_encoder", "vae_decoder", "controlnet"] if any(component in identifier for component in sd_components): neuronx.async_load(neuron_model) # unet - if "unet" in identifier: + if "unet" in identifier or "controlnet" in identifier: neuronx.lazy_load(neuron_model) diff --git a/optimum/exporters/neuron/model_configs.py b/optimum/exporters/neuron/model_configs.py index b622d7634..f65792853 100644 --- a/optimum/exporters/neuron/model_configs.py +++ b/optimum/exporters/neuron/model_configs.py @@ -19,7 +19,7 @@ import torch -from ...neuron.utils import DummyBeamValuesGenerator, DummyMaskedPosGenerator +from ...neuron.utils import DummyBeamValuesGenerator, DummyControNetInputGenerator, DummyMaskedPosGenerator from ...utils import ( DummyInputGenerator, DummySeq2SeqDecoderTextInputGenerator, @@ -43,6 +43,7 @@ VisionNeuronConfig, ) from .model_wrappers import ( + ControlNetNeuronWrapper, NoCacheModelWrapper, SentenceTransformersCLIPNeuronWrapper, SentenceTransformersTransformerNeuronWrapper, @@ -404,7 +405,7 @@ def outputs(self) -> List[str]: @register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers") class UNetNeuronConfig(VisionNeuronConfig): ATOL_FOR_VALIDATION = 1e-3 - INPUT_ARGS = ("batch_size", "sequence_length", "num_channels", "width", "height") + INPUT_ARGS = ("batch_size", "sequence_length", "num_channels", "width", "height", "vae_scale_factor") MODEL_TYPE = "unet" CUSTOM_MODEL_WRAPPER = UnetNeuronWrapper NORMALIZED_CONFIG_CLASS = 
NormalizedConfig.with_args( @@ -420,6 +421,7 @@ class UNetNeuronConfig(VisionNeuronConfig): DummyVisionInputGenerator, DummyTimestepInputGenerator, DummySeq2SeqDecoderTextInputGenerator, + DummyControNetInputGenerator, ) @property @@ -434,6 +436,10 @@ def inputs(self) -> List[str]: if getattr(self._normalized_config, "time_cond_proj_dim", None) is not None: common_inputs.append("timestep_cond") + if self.with_controlnet: + # outputs of controlnet + common_inputs += ["down_block_additional_residuals", "mid_block_additional_residual"] + return common_inputs @property @@ -445,6 +451,16 @@ def generate_dummy_inputs(self, return_tuple: bool = False, **kwargs): dummy_inputs["timestep"] = dummy_inputs["timestep"].float() dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0] + # break down down_block_additional_residuals + num_down_block_outputs = len(self._normalized_config.down_block_types) * ( + self._normalized_config.layers_per_block + 1 + ) + down_block_additional_residuals = dummy_inputs.pop("down_block_additional_residuals", None) + + if down_block_additional_residuals: + for idx in range(num_down_block_outputs): + dummy_inputs[f"down_block_additional_residuals_{idx}"] = down_block_additional_residuals[idx] + if getattr(self._normalized_config, "addition_embed_type", None) == "text_time": dummy_inputs["added_cond_kwargs"] = { "text_embeds": dummy_inputs.pop("text_embeds"), @@ -467,6 +483,55 @@ def is_sdxl(self) -> bool: def is_sdxl(self, is_sdxl: bool): self._is_sdxl = is_sdxl + @property + def with_controlnet(self) -> bool: + return self._with_controlnet + + @with_controlnet.setter + def with_controlnet(self, with_controlnet: bool): + self._with_controlnet = with_controlnet + + +@register_in_tasks_manager("controlnet", *["semantic-segmentation"], library_name="diffusers") +class ControlNetNeuronConfig(VisionNeuronConfig): + ATOL_FOR_VALIDATION = 1e-3 + INPUT_ARGS = ( + "batch_size", + "sequence_length", + "num_channels", + "height", + "width", + "vae_scale_factor", + "encoder_hidden_size", + ) + MODEL_TYPE = "controlnet" + CUSTOM_MODEL_WRAPPER = ControlNetNeuronWrapper + NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args( + height="height", + width="width", + num_channels="in_channels", + hidden_size="cross_attention_dim", + vocab_size="norm_num_groups", + allow_new=True, + ) + + DUMMY_INPUT_GENERATOR_CLASSES = ( + DummyVisionInputGenerator, + DummyControNetInputGenerator, + ) + + @property + def inputs(self) -> List[str]: + common_inputs = ["sample", "timestep", "encoder_hidden_states", "controlnet_cond", "conditioning_scale"] + return common_inputs + + @property + def outputs(self) -> List[str]: + return ["down_block_res_samples", "mid_block_res_sample"] + + def patch_model_for_export(self, model, dummy_inputs): + return self.CUSTOM_MODEL_WRAPPER(model, list(dummy_inputs.keys())) + @register_in_tasks_manager("vae-encoder", *["semantic-segmentation"], library_name="diffusers") class VaeEncoderNeuronConfig(VisionNeuronConfig): diff --git a/optimum/exporters/neuron/model_wrappers.py b/optimum/exporters/neuron/model_wrappers.py index 261701c76..8a47c779e 100644 --- a/optimum/exporters/neuron/model_wrappers.py +++ b/optimum/exporters/neuron/model_wrappers.py @@ -45,12 +45,63 @@ def forward(self, *inputs): } sample = ordered_inputs.pop("sample", None) timestep = ordered_inputs.pop("timestep").float().expand((sample.shape[0],)) + encoder_hidden_states = ordered_inputs.pop("encoder_hidden_states", None) + + # Re-build down_block_additional_residual + 
down_block_additional_residuals = () + down_block_additional_residuals_names = [ + name for name in ordered_inputs.keys() if "down_block_additional_residuals" in name + ] + for name in down_block_additional_residuals_names: + value = ordered_inputs.pop(name) + down_block_additional_residuals += (value,) + + mid_block_additional_residual = ordered_inputs.pop("mid_block_additional_residual", None) out_tuple = self.model( sample=sample, timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + down_block_additional_residuals=( + down_block_additional_residuals if down_block_additional_residuals else None + ), + mid_block_additional_residual=mid_block_additional_residual, added_cond_kwargs=added_cond_kwargs, return_dict=False, + ) + + return out_tuple + + +class ControlNetNeuronWrapper(torch.nn.Module): + def __init__(self, model, input_names: List[str]): + super().__init__() + self.model = model + self.input_names = input_names + + def forward(self, *inputs): + if len(inputs) != len(self.input_names): + raise ValueError( + f"The model needs {len(self.input_names)} inputs: {self.input_names}." + f" But only {len(input)} inputs are passed." + ) + + ordered_inputs = dict(zip(self.input_names, inputs)) + + sample = ordered_inputs.pop("sample", None) + timestep = ordered_inputs.pop("timestep", None) + encoder_hidden_states = ordered_inputs.pop("encoder_hidden_states", None) + controlnet_cond = ordered_inputs.pop("controlnet_cond", None) + conditioning_scale = ordered_inputs.pop("conditioning_scale", None) + + out_tuple = self.model( + sample=sample, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=controlnet_cond, + conditioning_scale=conditioning_scale, + guess_mode=False, # TODO: support guess mode of ControlNet + return_dict=False, **ordered_inputs, ) diff --git a/optimum/exporters/neuron/utils.py b/optimum/exporters/neuron/utils.py index e3628935f..2a8d3df22 100644 --- a/optimum/exporters/neuron/utils.py +++ b/optimum/exporters/neuron/utils.py @@ -23,6 +23,7 @@ from ...neuron.utils import ( DECODER_NAME, + DIFFUSION_MODEL_CONTROLNET_NAME, DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, DIFFUSION_MODEL_TEXT_ENCODER_NAME, DIFFUSION_MODEL_UNET_NAME, @@ -52,7 +53,7 @@ f"We found an older version of diffusers {_diffusers_version} but we require diffusers to be >= {DIFFUSERS_MINIMUM_VERSION}. " "Please update diffusers by running `pip install --upgrade diffusers`" ) - from diffusers import UNet2DConditionModel + from diffusers import ControlNetModel, UNet2DConditionModel from diffusers.models.attention_processor import ( Attention, AttnProcessor, @@ -122,6 +123,8 @@ def get_stable_diffusion_models_for_export( lora_weight_names: Optional[List[str]] = None, lora_adapter_names: Optional[List[str]] = None, lora_scales: Optional[List[float]] = None, + controlnets: Optional[List["ControlNetModel"]] = None, + controlnet_input_shapes: Optional[Dict[str, int]] = None, ) -> Dict[str, Tuple[Union["PreTrainedModel", "ModelMixin"], "NeuronDefaultConfig"]]: """ Returns the components of a Stable Diffusion model and their subsequent neuron configs. @@ -154,6 +157,11 @@ def get_stable_diffusion_models_for_export( List of adapter names to be used for referencing the loaded adapter models. lora_scales (`Optional[List[float]]`, defaults to `None`): List of scaling factors for lora adapters. + controlnets (`Optional[List["ControlNetModel"]]]`, defaults to `None`): + One or multiple ControlNets providing additional conditioning to the `unet` during the denoising process. 
If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined additional conditioning. + controlnet_input_shapes (`Optional[Dict[str, int]]`, defaults to `None`): + Static shapes used for compiling ControlNets. Returns: `Dict[str, Tuple[Union[`PreTrainedModel`, `ModelMixin`], `NeuronDefaultConfig`]`: A Dict containing the model and @@ -222,6 +230,9 @@ def get_stable_diffusion_models_for_export( ) if task == "stable-diffusion-xl": unet_neuron_config.is_sdxl = True + + unet_neuron_config.with_controlnet = True if controlnets else False + models_for_export[DIFFUSION_MODEL_UNET_NAME] = (unet, unet_neuron_config) # VAE Encoder @@ -258,6 +269,27 @@ def get_stable_diffusion_models_for_export( ) models_for_export[DIFFUSION_MODEL_VAE_DECODER_NAME] = (vae_decoder, vae_decoder_neuron_config) + # ControlNet + if controlnets: + for idx, controlnet in enumerate(controlnets): + controlnet_config_constructor = TasksManager.get_exporter_config_constructor( + model=controlnet, + exporter="neuron", + task="semantic-segmentation", + model_type="controlnet", + library_name=library_name, + ) + controlnet_neuron_config = controlnet_config_constructor( + controlnet.config, + task="semantic-segmentation", + dynamic_batch_size=dynamic_batch_size, + **controlnet_input_shapes, + ) + models_for_export[DIFFUSION_MODEL_CONTROLNET_NAME + "_" + str(idx)] = ( + controlnet, + controlnet_neuron_config, + ) + return models_for_export @@ -304,6 +336,17 @@ def _load_lora_weights_to_pipeline( return pipeline +def load_controlnets(controlnet_ids: Optional[Union[str, List[str]]] = None): + contronets = [] + if controlnet_ids: + if isinstance(controlnet_ids, str): + controlnet_ids = [controlnet_ids] + for model_id in controlnet_ids: + model = ControlNetModel.from_pretrained(model_id) + contronets.append(model) + return contronets + + def get_submodels_for_export_stable_diffusion( pipeline: Union["StableDiffusionPipeline", "StableDiffusionXLImg2ImgPipeline"], task: str, @@ -318,6 +361,7 @@ def get_submodels_for_export_stable_diffusion( """ is_sdxl = "xl" in task + # Lora pipeline = _load_lora_weights_to_pipeline( pipeline=pipeline, lora_model_ids=lora_model_ids, diff --git a/optimum/neuron/__init__.py b/optimum/neuron/__init__.py index cfcb76d51..b48426e43 100644 --- a/optimum/neuron/__init__.py +++ b/optimum/neuron/__init__.py @@ -44,6 +44,7 @@ "NeuronModelForObjectDetection", ], "modeling_diffusion": [ + "NeuronStableDiffusionPipelineBase", "NeuronStableDiffusionPipeline", "NeuronStableDiffusionImg2ImgPipeline", "NeuronStableDiffusionInpaintPipeline", @@ -51,6 +52,7 @@ "NeuronStableDiffusionXLPipeline", "NeuronStableDiffusionXLImg2ImgPipeline", "NeuronStableDiffusionXLInpaintPipeline", + "NeuronStableDiffusionControlNetPipeline", ], "modeling_decoder": ["NeuronDecoderModel"], "modeling_seq2seq": ["NeuronModelForSeq2SeqLM"], @@ -83,9 +85,11 @@ from .modeling_decoder import NeuronDecoderModel from .modeling_diffusion import ( NeuronLatentConsistencyModelPipeline, + NeuronStableDiffusionControlNetPipeline, NeuronStableDiffusionImg2ImgPipeline, NeuronStableDiffusionInpaintPipeline, NeuronStableDiffusionPipeline, + NeuronStableDiffusionPipelineBase, NeuronStableDiffusionXLImg2ImgPipeline, NeuronStableDiffusionXLInpaintPipeline, NeuronStableDiffusionXLPipeline, diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py index f59f8f9a8..c13ded908 100644 --- a/optimum/neuron/modeling_diffusion.py +++ b/optimum/neuron/modeling_diffusion.py @@ -23,7 
+23,7 @@ from collections import OrderedDict from pathlib import Path from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union import torch from huggingface_hub import snapshot_download @@ -42,6 +42,7 @@ from ..utils import is_diffusers_available from .modeling_traced import NeuronTracedModel from .utils import ( + DIFFUSION_MODEL_CONTROLNET_NAME, DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, DIFFUSION_MODEL_TEXT_ENCODER_NAME, DIFFUSION_MODEL_UNET_NAME, @@ -73,6 +74,7 @@ if is_diffusers_available(): from diffusers import ( + ControlNetModel, DDIMScheduler, LCMScheduler, LMSDiscreteScheduler, @@ -82,14 +84,18 @@ ) from diffusers.configuration_utils import FrozenDict from diffusers.image_processor import VaeImageProcessor + from diffusers.models.controlnet import ControlNetOutput + from diffusers.pipelines.controlnet import MultiControlNetModel from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.utils import CONFIG_NAME, is_invisible_watermark_available from .pipelines import ( NeuronLatentConsistencyPipelineMixin, + NeuronStableDiffusionControlNetPipelineMixin, NeuronStableDiffusionImg2ImgPipelineMixin, NeuronStableDiffusionInpaintPipelineMixin, NeuronStableDiffusionPipelineMixin, + NeuronStableDiffusionXLControlNetPipelineMixin, NeuronStableDiffusionXLImg2ImgPipelineMixin, NeuronStableDiffusionXLInpaintPipelineMixin, NeuronStableDiffusionXLPipelineMixin, @@ -120,11 +126,19 @@ def __init__( neuron_configs: Dict[str, "NeuronDefaultConfig"], tokenizer: CLIPTokenizer, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, LCMScheduler], - data_parallel_mode: str, + data_parallel_mode: Literal["none", "unet", "all"], vae_encoder: Optional[Union[torch.jit._script.ScriptModule, "NeuronModelVaeEncoder"]] = None, text_encoder_2: Optional[Union[torch.jit._script.ScriptModule, "NeuronModelTextEncoder"]] = None, tokenizer_2: Optional[CLIPTokenizer] = None, feature_extractor: Optional[CLIPFeatureExtractor] = None, + controlnet: Optional[ + Union[ + torch.jit._script.ScriptModule, + List[torch.jit._script.ScriptModule], + "NeuronControlNetModel", + "NeuronMultiControlNetModel", + ] + ] = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, model_and_config_save_paths: Optional[Dict[str, Tuple[str, Path]]] = None, ): @@ -148,13 +162,15 @@ def __init__( [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). scheduler (`Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]`): A scheduler to be used in combination with the U-NET component to denoise the encoded image latents. - data_parallel_mode (`str`): + data_parallel_mode (`Literal["none", "unet", "all"]`): Mode to decide what components to load into both NeuronCores of a Neuron device. Can be "none"(no data parallel), "unet"(only load unet into both cores of each device), "all"(load the whole pipeline into both cores). vae_encoder (`Optional[torch.jit._script.ScriptModule]`, defaults to `None`): The Neuron TorchScript module associated to the VAE encoder. text_encoder_2 (`Optional[torch.jit._script.ScriptModule]`, defaults to `None`): The Neuron TorchScript module associated to the second frozen text encoder. 
Stable Diffusion XL uses the text and pool portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), specifically the [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) variant. + controlnet (`Optional[Union[torch.jit._script.ScriptModule, List[torch.jit._script.ScriptModule], "NeuronControlNetModel", "NeuronMultiControlNetModel"]]`, defaults to `None`): + The Neuron TorchScript module(s) associated to the ControlNet(s). tokenizer_2 (`Optional[CLIPTokenizer]`, defaults to `None`): Second tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). @@ -217,6 +233,25 @@ def __init__( else: self.vae_decoder = vae_decoder + if ( + controlnet + and not isinstance(controlnet, NeuronControlNetModel) + and not isinstance(controlnet, NeuronMultiControlNetModel) + ): + controlnet_cls = ( + NeuronMultiControlNetModel + if isinstance(controlnet, list) and len(controlnet) > 1 + else NeuronControlNetModel + ) + self.controlnet = controlnet_cls( + controlnet, + self, + self.configs[DIFFUSION_MODEL_CONTROLNET_NAME], + self.neuron_configs[DIFFUSION_MODEL_CONTROLNET_NAME], + ) + else: + self.controlnet = controlnet + self.tokenizer = tokenizer self.tokenizer_2 = tokenizer_2 self.scheduler = scheduler @@ -259,6 +294,9 @@ def __init__( self.num_images_per_prompt = 1 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) @staticmethod def is_lcm(unet_config): @@ -269,12 +307,13 @@ def is_lcm(unet_config): @staticmethod @requires_torch_neuronx def load_model( - data_parallel_mode: Optional[str], + data_parallel_mode: Optional[Literal["none", "unet", "all"]], text_encoder_path: Union[str, Path], unet_path: Union[str, Path], vae_decoder_path: Optional[Union[str, Path]] = None, vae_encoder_path: Optional[Union[str, Path]] = None, text_encoder_2_path: Optional[Union[str, Path]] = None, + controlnet_paths: Optional[List[Path]] = None, dynamic_batch_size: bool = False, to_neuron: bool = False, ): @@ -283,7 +322,7 @@ def load_model( one or multiple [NeuronCore](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-hardware/neuroncores-arch.html). Args: - data_parallel_mode (`Optional[str]`): + data_parallel_mode (`Optional[Literal["none", "unet", "all"]]`): Mode to decide what components to load into both NeuronCores of a Neuron device. Can be "none"(no data parallel), "unet"(only load unet into both cores of each device), "all"(load the whole pipeline into both cores). text_encoder_path (`Union[str, Path]`): @@ -296,6 +335,8 @@ def load_model( Path of the compiled VAE encoder. It is optional, only used for tasks taking images as input. text_encoder_2_path (`Optional[Union[str, Path]]`, defaults to `None`): Path of the compiled second frozen text encoder. SDXL only. + controlnet_paths (`Optional[List[Path]]`, defaults to `None`): + Path of the compiled controlnets. dynamic_batch_size (`bool`, defaults to `False`): Whether enable dynamic batch size for neuron compiled model. If `True`, the input batch size can be a multiple of the batch size during the compilation. 
to_neuron (`bool`, defaults to `False`): @@ -307,6 +348,7 @@ def load_model( "vae_decoder": vae_decoder_path, "vae_encoder": vae_encoder_path, "text_encoder_2": text_encoder_2_path, + "controlnet": controlnet_paths, } # DataParallel class to use (to remove after neuron sdk 2.20) if to_neuron: @@ -320,26 +362,35 @@ def load_model( if data_parallel_mode == "all": logger.info("Loading the whole pipeline into both Neuron Cores...") - for submodel_name, submodel_path in submodels.items(): - if submodel_path is not None and submodel_path.is_file(): - submodel = NeuronTracedModel.load_model( - submodel_path, to_neuron=False - ) # No need to load to neuron manually when dp - submodels[submodel_name] = dp_cls( - submodel, - [0, 1], - set_dynamic_batching=dynamic_batch_size, - ) + for submodel_name, submodel_paths in submodels.items(): + if not isinstance(submodel_paths, list): + submodel_paths = [submodel_paths] + submodels_list = [] + for submodel_path in submodel_paths: + if submodel_path is not None and submodel_path.is_file(): + submodel = NeuronTracedModel.load_model( + submodel_path, to_neuron=False + ) # No need to load to neuron manually when dp + submodel = dp_cls( + submodel, + [0, 1], + set_dynamic_batching=dynamic_batch_size, + ) + submodels_list.append(submodel) + if submodels_list: + submodels[submodel_name] = submodels_list if len(submodels_list) > 1 else submodels_list[0] else: submodels[submodel_name] = None elif data_parallel_mode == "unet": logger.info("Loading only U-Net into both Neuron Cores...") submodels.pop("unet") + submodels.pop("controlnet") # controlnet takes inputs with the same batch_size as the unet for submodel_name, submodel_path in submodels.items(): if submodel_path is not None and submodel_path.is_file(): submodels[submodel_name] = NeuronTracedModel.load_model(submodel_path, to_neuron=to_neuron) else: submodels[submodel_name] = None + # load unet unet = NeuronTracedModel.load_model( unet_path, to_neuron=False ) # No need to load to neuron manually when dp @@ -348,11 +399,31 @@ def load_model( [0, 1], set_dynamic_batching=dynamic_batch_size, ) + # load controlnets + if controlnet_paths: + controlnets = [] + for controlnet_path in controlnet_paths: + if controlnet_path.is_file(): + controlnet = NeuronTracedModel.load_model( + controlnet_path, to_neuron=False + ) # No need to load to neuron manually when dp + controlnets.append(dp_cls(controlnet, [0, 1], set_dynamic_batching=dynamic_batch_size)) + if controlnets: + submodels["controlnet"] = controlnets if len(controlnets) > 1 else controlnets[0] + else: + submodels["controlnet"] = None elif data_parallel_mode == "none": logger.info("Loading the pipeline without any data parallelism...") - for submodel_name, submodel_path in submodels.items(): - if submodel_path is not None and submodel_path.is_file(): - submodels[submodel_name] = NeuronTracedModel.load_model(submodel_path, to_neuron=to_neuron) + for submodel_name, submodel_paths in submodels.items(): + if not isinstance(submodel_paths, list): + submodel_paths = [submodel_paths] + submodels_list = [] + for submodel_path in submodel_paths: + if submodel_path is not None and submodel_path.is_file(): + submodel = NeuronTracedModel.load_model(submodel_path, to_neuron=to_neuron) + submodels_list.append(submodel) + if submodels_list: + submodels[submodel_name] = submodels_list if len(submodels_list) > 1 else submodels_list[0] else: submodels[submodel_name] = None else: @@ -386,6 +457,7 @@ def _save_pretrained( unet_file_name: str = NEURON_FILE_NAME, vae_encoder_file_name: 
str = NEURON_FILE_NAME, vae_decoder_file_name: str = NEURON_FILE_NAME, + controlnet_file_name: str = NEURON_FILE_NAME, ): """ Saves the model to the serialized format optimized for Neuron devices. @@ -406,6 +478,12 @@ def _save_pretrained( if not self.model_and_config_save_paths.get(DIFFUSION_MODEL_TEXT_ENCODER_2_NAME)[0].is_file(): self.model_and_config_save_paths.pop(DIFFUSION_MODEL_TEXT_ENCODER_2_NAME) + if not self.model_and_config_save_paths.get(DIFFUSION_MODEL_CONTROLNET_NAME)[0]: + self.model_and_config_save_paths.pop(DIFFUSION_MODEL_CONTROLNET_NAME) + num_controlnet = 0 + else: + num_controlnet = len(self.model_and_config_save_paths.get(DIFFUSION_MODEL_CONTROLNET_NAME)[0]) + logger.info(f"Saving the {tuple(self.model_and_config_save_paths.keys())}...") dst_paths = { @@ -423,20 +501,34 @@ def _save_pretrained( / DIFFUSION_MODEL_VAE_DECODER_NAME / vae_decoder_file_name, } - model_src_to_dst_path = { - self.model_and_config_save_paths[model_name][0]: dst_paths[model_name] - for model_name in set(self.model_and_config_save_paths.keys()).intersection(dst_paths.keys()) - } - # save - config_src_to_dst_path = { - self.model_and_config_save_paths[model_name][1]: dst_paths[model_name].parent / CONFIG_NAME - for model_name in set(self.model_and_config_save_paths.keys()).intersection(dst_paths.keys()) - } - - src_paths = list(model_src_to_dst_path.keys()) + list(config_src_to_dst_path.keys()) - dst_paths = list(model_src_to_dst_path.values()) + list(config_src_to_dst_path.values()) - - for src_path, dst_path in zip(src_paths, dst_paths): + dst_paths[DIFFUSION_MODEL_CONTROLNET_NAME] = [ + save_directory / (DIFFUSION_MODEL_CONTROLNET_NAME + f"_{str(idx)}") / controlnet_file_name + for idx in range(num_controlnet) + ] + + src_paths_list = [] + dst_paths_list = [] + for model_name in set(self.model_and_config_save_paths.keys()).intersection(dst_paths.keys()): + model_src_path = self.model_and_config_save_paths[model_name][0] + if isinstance(model_src_path, list): + # neuron model + src_paths_list += model_src_path + dst_paths_list += dst_paths[model_name] + + # config + src_paths_list += self.model_and_config_save_paths[model_name][1] + dst_paths_list += [model_path.parent / CONFIG_NAME for model_path in dst_paths[model_name]] + + else: + # neuron model + src_paths_list.append(model_src_path) + dst_paths_list.append(dst_paths[model_name]) + + # config + src_paths_list.append(self.model_and_config_save_paths[model_name][1]) + dst_paths_list.append(dst_paths[model_name].parent / CONFIG_NAME) + + for src_path, dst_path in zip(src_paths_list, dst_paths_list): dst_path.parent.mkdir(parents=True, exist_ok=True) if src_path.is_file(): shutil.copyfile(src_path, dst_path) @@ -464,17 +556,19 @@ def _from_pretrained( unet_file_name: Optional[str] = NEURON_FILE_NAME, vae_encoder_file_name: Optional[str] = NEURON_FILE_NAME, vae_decoder_file_name: Optional[str] = NEURON_FILE_NAME, + controlnet_file_name: Optional[str] = NEURON_FILE_NAME, text_encoder_2: Optional["NeuronModelTextEncoder"] = None, vae_encoder: Optional["NeuronModelVaeEncoder"] = None, vae_decoder: Optional["NeuronModelVaeDecoder"] = None, + controlnet: Optional[Union["NeuronControlNetModel", "NeuronMultiControlNetModel"]] = None, local_files_only: bool = False, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, - data_parallel_mode: Optional[str] = None, + data_parallel_mode: Optional[Literal["none", "unet", "all"]] = None, **kwargs, # To share kwargs only available for `_from_transformers` ): model_id = str(model_id) 
patterns = set(config.keys()) - sub_models_to_load = patterns.intersection({"feature_extractor", "tokenizer", "tokenizer_2", "scheduler"}) + processors_to_load = patterns.intersection({"feature_extractor", "tokenizer", "tokenizer_2", "scheduler"}) if not os.path.isdir(model_id): patterns.update({DIFFUSION_MODEL_VAE_ENCODER_NAME, DIFFUSION_MODEL_VAE_DECODER_NAME}) @@ -486,6 +580,7 @@ def _from_pretrained( unet_file_name, vae_encoder_file_name, vae_decoder_file_name, + controlnet_file_name, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name, @@ -505,7 +600,7 @@ def _from_pretrained( new_model_save_dir = Path(model_id) sub_models = {} - for name in sub_models_to_load: + for name in processors_to_load: library_name, library_classes = config[name] if library_classes is not None: library = importlib.import_module(library_name) @@ -540,17 +635,35 @@ def _from_pretrained( ), } + # Add ControlNet paths + controlnet_model_paths = [] + controlnet_config_paths = [] + for path in new_model_save_dir.iterdir(): + if path.is_dir() and path.name.startswith("controlnet"): + controlnet_model_paths.append(path / controlnet_file_name) + controlnet_config_paths.append(path / cls.sub_component_config_name) + model_and_config_save_paths["controlnet"] = (controlnet_model_paths, controlnet_config_paths) + # Re-build pretrained configs and neuron configs configs, neuron_configs = {}, {} - for name, file_paths in model_and_config_save_paths.items(): - if file_paths[1].is_file(): - model_config = DiffusersPretrainedConfig.from_json_file(file_paths[1]) - configs[name] = model_config - neuron_configs[name] = cls._neuron_config_init(model_config) - inline_weights_to_neff = all( - neuron_config._config.neuron.get("inline_weights_to_neff", False) - for _, neuron_config in neuron_configs.items() - ) + inline_weights_to_neff = True + for name, (_, config_paths) in model_and_config_save_paths.items(): + if not isinstance(config_paths, list): + config_paths = [config_paths] + sub_model_configs = [] + sub_neuron_configs = [] + for config_path in config_paths: + if config_path.is_file(): + model_config = DiffusersPretrainedConfig.from_json_file(config_path) + neuron_config = cls._neuron_config_init(model_config) + inline_weights_to_neff = inline_weights_to_neff and neuron_config._config.neuron.get( + "inline_weights_to_neff", True + ) + sub_model_configs.append(model_config) + sub_neuron_configs.append(neuron_config) + if sub_model_configs and sub_neuron_configs: + configs[name] = sub_model_configs if len(sub_model_configs) > 1 else sub_model_configs[0] + neuron_configs[name] = sub_neuron_configs if len(sub_neuron_configs) > 1 else sub_neuron_configs[0] if data_parallel_mode is None: data_parallel_mode = cls.set_default_dp_mode(configs["unet"]) @@ -562,6 +675,11 @@ def _from_pretrained( vae_decoder_path=model_and_config_save_paths["vae_decoder"][0] if vae_decoder is None else None, vae_encoder_path=model_and_config_save_paths["vae_encoder"][0] if vae_encoder is None else None, text_encoder_2_path=model_and_config_save_paths["text_encoder_2"][0] if text_encoder_2 is None else None, + controlnet_paths=( + model_and_config_save_paths["controlnet"][0] + if controlnet is None and model_and_config_save_paths["controlnet"][0] + else None + ), dynamic_batch_size=neuron_configs[DIFFUSION_MODEL_UNET_NAME].dynamic_batch_size, to_neuron=not inline_weights_to_neff, ) @@ -578,6 +696,7 @@ def _from_pretrained( scheduler=sub_models.get("scheduler"), vae_encoder=vae_encoder or pipe.get("vae_encoder"), text_encoder_2=text_encoder_2 or 
pipe.get("text_encoder_2"), + controlnet=controlnet or pipe.get("controlnet"), tokenizer_2=sub_models.get("tokenizer_2", None), feature_extractor=sub_models.get("feature_extractor", None), data_parallel_mode=data_parallel_mode, @@ -616,11 +735,12 @@ def _export( auto_cast_type: Optional[str] = "bf16", dynamic_batch_size: bool = False, output_hidden_states: bool = False, - data_parallel_mode: Optional[str] = None, + data_parallel_mode: Optional[Literal["none", "unet", "all"]] = None, lora_model_ids: Optional[Union[str, List[str]]] = None, lora_weight_names: Optional[Union[str, List[str]]] = None, lora_adapter_names: Optional[Union[str, List[str]]] = None, lora_scales: Optional[Union[float, List[float]]] = None, + controlnet_ids: Optional[Union[str, List[str]]] = None, **kwargs_shapes, ) -> "NeuronStableDiffusionPipelineBase": """ @@ -679,7 +799,7 @@ def _export( batch size during the compilation, but it comes with a potential tradeoff in terms of latency. output_hidden_states (`bool`, defaults to `False`): Whether or not for the traced text encoders to return the hidden states of all layers. - data_parallel_mode (`Optional[str]`, defaults to `None`): + data_parallel_mode (`Optional[Literal["none", "unet", "all"]]`, defaults to `None`): Mode to decide what components to load into both NeuronCores of a Neuron device. Can be "none"(no data parallel), "unet"(only load unet into both cores of each device), "all"(load the whole pipeline into both cores). lora_model_ids (`Optional[Union[str, List[str]]]`, defaults to `None`): @@ -690,6 +810,8 @@ def _export( Adapter names to be used for referencing the loaded adapter models. lora_scales (`Optional[List[float]]`, defaults to `None`): Lora adapters scaling factors. + controlnet_ids (`Optional[Union[str, List[str]]]`, defaults to `None`): + List of ControlNet model ids (eg. `thibaud/controlnet-openpose-sdxl-1.0`)." kwargs_shapes (`Dict[str, int]`): Shapes to use during inference. This argument allows to override the default shapes used during the export. 
""" @@ -747,6 +869,7 @@ def _export( lora_weight_names=lora_weight_names, lora_adapter_names=lora_adapter_names, lora_scales=lora_scales, + controlnet_ids=controlnet_ids, **input_shapes_copy, ) @@ -822,6 +945,7 @@ def _export( lora_weight_names=lora_weight_names, lora_adapter_names=lora_adapter_names, lora_scales=lora_scales, + controlnet_ids=controlnet_ids, library_name=cls.library_name, **input_shapes, ) @@ -926,20 +1050,24 @@ def forward( encoder_hidden_states: torch.Tensor, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, timestep_cond: Optional[torch.Tensor] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, ): timestep = timestep.float().expand((sample.shape[0],)) - inputs = { - "sample": sample, - "timestep": timestep, - "encoder_hidden_states": encoder_hidden_states, - } + inputs = (sample, timestep, encoder_hidden_states) if timestep_cond is not None: - inputs["timestep_cond"] = timestep_cond + inputs = inputs + (timestep_cond,) if added_cond_kwargs is not None: - inputs["text_embeds"] = added_cond_kwargs.pop("text_embeds", None) - inputs["time_ids"] = added_cond_kwargs.pop("time_ids", None) + text_embeds = added_cond_kwargs.pop("text_embeds", None) + time_ids = added_cond_kwargs.pop("time_ids", None) + inputs = inputs + (text_embeds, time_ids) + if mid_block_additional_residual is not None: + inputs = inputs + (mid_block_additional_residual,) + if down_block_additional_residuals is not None: + for idx in range(len(down_block_additional_residuals)): + inputs = inputs + (down_block_additional_residuals[idx],) - outputs = self.model(*tuple(inputs.values())) + outputs = self.model(*inputs) return outputs @@ -986,6 +1114,98 @@ def forward( return tuple(output for output in outputs.values()) +class NeuronControlNetModel(_NeuronDiffusionModelPart): + auto_model_class = ControlNetModel + library_name = "diffusers" + base_model_prefix = "neuron_model" + config_name = "model_index.json" + sub_component_config_name = "config.json" + + def __init__( + self, + model: torch.jit._script.ScriptModule, + parent_model: NeuronTracedModel, + config: Optional[DiffusersPretrainedConfig] = None, + neuron_config: Optional[Dict[str, str]] = None, + ): + super().__init__(model, parent_model, config, neuron_config, DIFFUSION_MODEL_CONTROLNET_NAME) + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union["ControlNetOutput", Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: + inputs = (sample, timestep, encoder_hidden_states, controlnet_cond, conditioning_scale) + outputs = self.model(*inputs) + + if guess_mode: + logger.info( + "Guess mode is not yet supported. File us an issue on: https://github.com/huggingface/optimum-neuron/issues." 
+ ) + + if return_dict: + outputs = ControlNetOutput(dict(zip(self.neuron_config.outputs, outputs))) + + return outputs + + +class NeuronMultiControlNetModel(_NeuronDiffusionModelPart): + auto_model_class = MultiControlNetModel + library_name = "diffusers" + base_model_prefix = "neuron_model" + config_name = "model_index.json" + sub_component_config_name = "config.json" + + def __init__( + self, + models: List[torch.jit._script.ScriptModule], + parent_model: NeuronTracedModel, + config: Optional[DiffusersPretrainedConfig] = None, + neuron_config: Optional[Dict[str, str]] = None, + ): + self.nets = models + self.parent_model = parent_model + self.config = config + self.neuron_config = neuron_config + self.model_type = DIFFUSION_MODEL_CONTROLNET_NAME + self.device = None + + def forward( + self, + sample: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + return_dict: bool = True, + ) -> Union["ControlNetOutput", Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: + for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.model)): + inputs = (sample, timestep, encoder_hidden_states, image, scale) + down_samples, mid_sample = controlnet(*inputs) + + # merge samples + if i == 0: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + ] + mid_block_res_sample += mid_sample + + if return_dict: + return ControlNetOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + return down_block_res_samples, mid_block_res_sample + + class NeuronStableDiffusionPipeline(NeuronStableDiffusionPipelineBase, NeuronStableDiffusionPipelineMixin): __call__ = NeuronStableDiffusionPipelineMixin.__call__ @@ -1006,6 +1226,12 @@ class NeuronLatentConsistencyModelPipeline(NeuronStableDiffusionPipelineBase, Ne __call__ = NeuronLatentConsistencyPipelineMixin.__call__ +class NeuronStableDiffusionControlNetPipeline( + NeuronStableDiffusionPipelineBase, NeuronStableDiffusionControlNetPipelineMixin +): + __call__ = NeuronStableDiffusionControlNetPipelineMixin.__call__ + + class NeuronStableDiffusionXLPipelineBase(NeuronStableDiffusionPipelineBase): # `TasksManager` registered img2ime pipeline for `stable-diffusion-xl`: https://github.com/huggingface/optimum/blob/v1.12.0/optimum/exporters/tasks.py#L174 auto_model_class = StableDiffusionXLImg2ImgPipeline @@ -1018,11 +1244,19 @@ def __init__( config: Dict[str, Any], tokenizer: CLIPTokenizer, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - data_parallel_mode: str, + data_parallel_mode: Literal["none", "unet", "all"], vae_encoder: Optional[torch.jit._script.ScriptModule] = None, text_encoder_2: Optional[torch.jit._script.ScriptModule] = None, tokenizer_2: Optional[CLIPTokenizer] = None, feature_extractor: Optional[CLIPFeatureExtractor] = None, + controlnet: Optional[ + Union[ + torch.jit._script.ScriptModule, + List[torch.jit._script.ScriptModule], + "NeuronControlNetModel", + "NeuronMultiControlNetModel", + ] + ] = None, configs: Optional[Dict[str, "PretrainedConfig"]] = None, neuron_configs: Optional[Dict[str, "NeuronDefaultConfig"]] = None, model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, @@ -1041,6 +1275,7 @@ def __init__( text_encoder_2=text_encoder_2, 
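Each compiled ControlNet returns one set of down-block residuals and one mid-block residual; `NeuronMultiControlNetModel` merges them by element-wise summation before they are handed to the UNet. A standalone sketch of that merging rule with toy tensors (the shapes below are assumptions, not the exact UNet residual shapes):

```python
import torch


def merge_controlnet_outputs(per_net_outputs):
    """Sum residuals coming from several ControlNets, mirroring NeuronMultiControlNetModel."""
    down_block_res_samples, mid_block_res_sample = None, None
    for i, (down_samples, mid_sample) in enumerate(per_net_outputs):
        if i == 0:
            down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
        else:
            down_block_res_samples = [
                prev + curr for prev, curr in zip(down_block_res_samples, down_samples)
            ]
            mid_block_res_sample = mid_block_res_sample + mid_sample
    return down_block_res_samples, mid_block_res_sample


# Toy check with two "ControlNets", each returning three down residuals and one mid residual.
outputs = [
    ([torch.ones(1, 320, 64, 64)] * 3, torch.ones(1, 1280, 8, 8)),
    ([torch.ones(1, 320, 64, 64)] * 3, torch.ones(1, 1280, 8, 8)),
]
down, mid = merge_controlnet_outputs(outputs)
print(len(down), down[0].max().item(), mid.max().item())  # 3 2.0 2.0
```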
tokenizer_2=tokenizer_2, feature_extractor=feature_extractor, + controlnet=controlnet, configs=configs, neuron_configs=neuron_configs, model_save_dir=model_save_dir, @@ -1077,6 +1312,12 @@ class NeuronStableDiffusionXLInpaintPipeline( __call__ = NeuronStableDiffusionXLInpaintPipelineMixin.__call__ +class NeuronStableDiffusionXLControlNetPipeline( + NeuronStableDiffusionPipelineBase, NeuronStableDiffusionXLControlNetPipelineMixin +): + __call__ = NeuronStableDiffusionXLControlNetPipelineMixin.__call__ + + if is_neuronx_available(): # TO REMOVE: This class will be included directly in the DDP API of Neuron SDK 2.20 class WeightSeparatedDataParallel(torch_neuronx.DataParallel): diff --git a/optimum/neuron/pipelines/__init__.py b/optimum/neuron/pipelines/__init__.py index 41312ce82..aa5366fc8 100644 --- a/optimum/neuron/pipelines/__init__.py +++ b/optimum/neuron/pipelines/__init__.py @@ -25,18 +25,22 @@ "NeuronStableDiffusionImg2ImgPipelineMixin", "NeuronStableDiffusionInpaintPipelineMixin", "NeuronLatentConsistencyPipelineMixin", + "NeuronStableDiffusionControlNetPipelineMixin", "NeuronStableDiffusionXLPipelineMixin", "NeuronStableDiffusionXLImg2ImgPipelineMixin", "NeuronStableDiffusionXLInpaintPipelineMixin", + "NeuronStableDiffusionXLControlNetPipelineMixin", ], } if TYPE_CHECKING: from .diffusers import ( NeuronLatentConsistencyPipelineMixin, + NeuronStableDiffusionControlNetPipelineMixin, NeuronStableDiffusionImg2ImgPipelineMixin, NeuronStableDiffusionInpaintPipelineMixin, NeuronStableDiffusionPipelineMixin, + NeuronStableDiffusionXLControlNetPipelineMixin, NeuronStableDiffusionXLImg2ImgPipelineMixin, NeuronStableDiffusionXLInpaintPipelineMixin, NeuronStableDiffusionXLPipelineMixin, diff --git a/optimum/neuron/pipelines/diffusers/__init__.py b/optimum/neuron/pipelines/diffusers/__init__.py index b30664695..b39843657 100644 --- a/optimum/neuron/pipelines/diffusers/__init__.py +++ b/optimum/neuron/pipelines/diffusers/__init__.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .pipeline_controlnet import NeuronStableDiffusionControlNetPipelineMixin +from .pipeline_controlnet_sd_xl import NeuronStableDiffusionXLControlNetPipelineMixin from .pipeline_latent_consistency_text2img import NeuronLatentConsistencyPipelineMixin from .pipeline_stable_diffusion import NeuronStableDiffusionPipelineMixin from .pipeline_stable_diffusion_img2img import NeuronStableDiffusionImg2ImgPipelineMixin diff --git a/optimum/neuron/pipelines/diffusers/pipeline_controlnet.py b/optimum/neuron/pipelines/diffusers/pipeline_controlnet.py new file mode 100644 index 000000000..641123635 --- /dev/null +++ b/optimum/neuron/pipelines/diffusers/pipeline_controlnet.py @@ -0,0 +1,580 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Override some diffusers API for NeuronStableDiffusionControlNetPipeline""" + +import logging +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from diffusers import StableDiffusionControlNetPipeline +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import PipelineImageInput +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps + +from .pipeline_utils import StableDiffusionPipelineMixin + + +logger = logging.getLogger(__name__) + + +class NeuronStableDiffusionControlNetPipelineMixin(StableDiffusionPipelineMixin, StableDiffusionControlNetPipeline): + # Adapted from https://github.com/huggingface/diffusers/blob/de9528ebc7725012cf097e43f565aeff24940eda/src/diffusers/pipelines/controlnet/pipeline_controlnet.py#L594 + # Replace class types with Neuron ones + def check_inputs( + self, + prompt, + image, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` + if self.controlnet.__class__.__name__ == "NeuronControlNetModel": + self.check_image(image, prompt, prompt_embeds) + elif self.controlnet.__class__.__name__ == "NeuronMultiControlNetModel": + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. 
[[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + transposed_image = [list(t) for t in zip(*image)] + if len(transposed_image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: if you pass`image` as a list of list, each sublist must have the same length as the number of controlnets, but the sublists in `image` got {len(transposed_image)} images and {len(self.controlnet.nets)} ControlNets." + ) + for image_ in transposed_image: + self.check_image(image_, prompt, prompt_embeds) + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + else: + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if self.controlnet.__class__.__name__ == "NeuronControlNetModel": + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif self.controlnet.__class__.__name__ == "NeuronMultiControlNetModel": + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError( + "A single batch of varying conditioning scale settings (e.g. [[1.0, 0.5], [0.2, 0.8]]) is not supported at the moment. " + "The conditioning scale must be fixed across the batch." + ) + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + if self.controlnet.__class__.__name__ == "NeuronMultiControlNetModel": + if len(control_guidance_start) != len(self.controlnet.nets): + raise ValueError( + f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
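When several ControlNets were compiled, `check_inputs` expects one conditioning image per ControlNet (or a nested list of such images per prompt) and, optionally, one conditioning scale and one guidance window per ControlNet. A hedged sketch of that call layout, assuming compiled artifacts saved under a hypothetical `sd_neuron_multi_controlnet` directory; the black placeholder images only stand in for real conditioning images:

```python
import numpy as np
from PIL import Image

from optimum.neuron import NeuronStableDiffusionControlNetPipeline

# Placeholder conditioning images, one per compiled ControlNet (e.g. canny and openpose).
cond_image_1 = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))
cond_image_2 = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))

# Hypothetical directory holding a checkpoint compiled with two ControlNets.
pipe = NeuronStableDiffusionControlNetPipeline.from_pretrained("sd_neuron_multi_controlnet")

image = pipe(
    "the mona lisa",
    image=[cond_image_1, cond_image_2],        # one image per ControlNet
    controlnet_conditioning_scale=[0.8, 0.5],  # one scale per ControlNet
    control_guidance_start=[0.0, 0.0],
    control_guidance_end=[1.0, 0.8],
).images[0]
```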
+ ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + output_type: str = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`Optional["PipelineImageInput"]`, defaults to `None`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single + ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple + ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet. + num_inference_steps (`int`, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`Optional[List[int]]`, defaults to `None`): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`Optional[List[int]]`, defaults to `None`): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. 
If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`Optional[Union[str, List[str]]]`, defaults to `None`): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, defaults to 1): + The number of images to generate per prompt. If it is different from the batch size used for the compiltaion, + it will be overriden by the static batch size of neuron (except for dynamic batching). + eta (`float`, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`, defaults to `None`): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`Optional[torch.Tensor]`, defaults to `None`): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`Optional[torch.Tensor]`, defaults to `None`): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`Optional[torch.Tensor]`, defaults to `None`): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`Optional[PipelineImageInput]`, defaults to `None`): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`Optional[List[torch.Tensor]]`, defaults to `None`): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`Union[float, List[float]]`, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. 
If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, defaults to `False`): + The ControlNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`Union[float, List[float]]`, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`Union[float, List[float]]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + clip_skip (`Optional[int]`, defaults to `None`): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Optional[Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]]`, defaults to `None`): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List[str]`, defaults to `["latents"]`): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Returns: + [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + controlnet = self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if controlnet.__class__.__name__ == "NeuronMultiControlNetModel" else 1 + control_guidance_start, control_guidance_end = ( + mult * [control_guidance_start], + mult * [control_guidance_end], + ) + + # 1. Check inputs. 
Raise error if not correct + self.check_inputs( + prompt=prompt, + image=image, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ip_adapter_image=ip_adapter_image, + ip_adapter_image_embeds=ip_adapter_image_embeds, + controlnet_conditioning_scale=controlnet_conditioning_scale, + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if controlnet.__class__.__name__ == "NeuronMultiControlNetModel" and isinstance( + controlnet_conditioning_scale, float + ): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if controlnet.__class__.__name__ == "NeuronControlNetModel" + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + # TODO: support guess mode of ControlNet + if guess_mode: + logger.info("Disabling the guess mode as this is not supported yet.") + guess_mode = False + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + do_classifier_free_guidance = guidance_scale > 1.0 and ( + self.dynamic_batch_size or self.data_parallel_mode == "unet" + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + # TODO: support ip adapter + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + logger.info( + "IP adapter is not supported yet, `ip_adapter_image` and `ip_adapter_image_embeds` will be ignored." + ) + + # 4. 
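Note that, unlike the eager diffusers pipeline, classifier-free guidance is only kept when the doubled (negative + positive) prompt batch can actually be absorbed by the static Neuron graph, that is with dynamic batching or when the UNet is replicated on both NeuronCores. A toy illustration of that condition, with assumed attribute values:

```python
# Assumed values standing in for the pipeline attributes of the same names.
guidance_scale = 7.5
dynamic_batch_size = False
data_parallel_mode = "unet"

do_classifier_free_guidance = guidance_scale > 1.0 and (
    dynamic_batch_size or data_parallel_mode == "unet"
)
print(do_classifier_free_guidance)  # True: negative and positive embeddings form one batch of 2
```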
Prepare image + height = self.vae_encoder.config.neuron["static_height"] + width = self.vae_encoder.config.neuron["static_width"] + if controlnet.__class__.__name__ == "NeuronControlNetModel": + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=None, + dtype=None, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif controlnet.__class__.__name__ == "NeuronMultiControlNetModel": + images = [] + + # Nested lists as ControlNet condition + if isinstance(image[0], list): + # Transpose the nested image list + image = [list(t) for t in zip(*image)] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=None, + dtype=None, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + scheduler=self.scheduler, + num_inference_steps=num_inference_steps, + device=None, + timesteps=timesteps, + sigmas=sigmas, + ) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=None, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # TODO: 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = None + + # 7.2 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if controlnet.__class__.__name__ == "NeuronControlNetModel" else keeps) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. 
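The `controlnet_keep` schedule built in step 7.2 zeroes out a ControlNet's contribution for denoising steps that fall outside its `[control_guidance_start, control_guidance_end]` window. A small worked example with assumed values:

```python
num_steps = 10
control_guidance_start = [0.0, 0.2]
control_guidance_end = [1.0, 0.8]

controlnet_keep = []
for i in range(num_steps):
    keeps = [
        1.0 - float(i / num_steps < s or (i + 1) / num_steps > e)
        for s, e in zip(control_guidance_start, control_guidance_end)
    ]
    controlnet_keep.append(keeps)

print(controlnet_keep[0])  # [1.0, 0.0] -> the second ControlNet is skipped at the first step
print(controlnet_keep[5])  # [1.0, 1.0] -> both ControlNets apply half-way through denoising
```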
+ control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + controlnet_cond_scale = controlnet_conditioning_scale + if isinstance(controlnet_cond_scale, list): + controlnet_cond_scale = controlnet_cond_scale[0] + cond_scale = controlnet_cond_scale * controlnet_keep[i] + + # Duplicate inputs for ddp + t = torch.tensor([t] * 2) if self.data_parallel_mode == "unet" else t + cond_scale = ( + torch.tensor([cond_scale]).repeat(2) + if self.data_parallel_mode == "unet" + else torch.tensor(cond_scale) + ) + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # De-Duplicate inputs for ddp + t = t[0] if self.data_parallel_mode == "unet" else t + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + image = self.vae_decoder(latents / getattr(self.vae_decoder.config, "scaling_factor", 0.18215))[0] + image, has_nsfw_concept = self.run_safety_checker(image, dtype=prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, 
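When the UNet is loaded on both NeuronCores (`data_parallel_mode == "unet"`), each core receives one half of the classifier-free-guidance batch, so scalar inputs such as the timestep and the conditioning scale are duplicated before the ControlNet and UNet calls, then de-duplicated for the scheduler step. A toy reproduction with assumed values:

```python
import torch

data_parallel_mode = "unet"  # assumed value
t = 981
cond_scale = 0.7

t = torch.tensor([t] * 2) if data_parallel_mode == "unet" else t
cond_scale = (
    torch.tensor([cond_scale]).repeat(2)
    if data_parallel_mode == "unet"
    else torch.tensor(cond_scale)
)
print(t)           # tensor([981, 981])
print(cond_scale)  # tensor([0.7000, 0.7000])

# The scheduler step expects a scalar timestep again, hence the de-duplication afterwards.
t = t[0] if data_parallel_mode == "unet" else t
print(t)           # tensor(981)
```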
nsfw_content_detected=has_nsfw_concept) diff --git a/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py b/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py new file mode 100644 index 000000000..69e80292f --- /dev/null +++ b/optimum/neuron/pipelines/diffusers/pipeline_controlnet_sd_xl.py @@ -0,0 +1,22 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Override some diffusers API for NeuronStableDiffusionXLControlNetPipelineMixin""" + + +class NeuronStableDiffusionXLControlNetPipelineMixin: + def __call__(self): + raise NotImplementedError( + "`NeuronStableDiffusionXLControlNetPipelineMixin` is not yet supported but will come soon." + ) diff --git a/optimum/neuron/utils/__init__.py b/optimum/neuron/utils/__init__.py index 5b8bae48e..e2253306d 100644 --- a/optimum/neuron/utils/__init__.py +++ b/optimum/neuron/utils/__init__.py @@ -27,6 +27,7 @@ "DIFFUSION_MODEL_UNET_NAME", "DIFFUSION_MODEL_VAE_DECODER_NAME", "DIFFUSION_MODEL_VAE_ENCODER_NAME", + "DIFFUSION_MODEL_CONTROLNET_NAME", "ENCODER_NAME", "NEURON_FILE_NAME", ], @@ -40,7 +41,7 @@ "is_torch_xla_available", "is_transformers_neuronx_available", ], - "input_generators": ["DummyBeamValuesGenerator", "DummyMaskedPosGenerator"], + "input_generators": ["DummyBeamValuesGenerator", "DummyMaskedPosGenerator", "DummyControNetInputGenerator"], "misc": [ "DiffusersPretrainedConfig", "check_if_weights_replacable", @@ -74,6 +75,7 @@ from .argument_utils import convert_neuronx_compiler_args_to_neuron, store_compilation_config from .constant import ( DECODER_NAME, + DIFFUSION_MODEL_CONTROLNET_NAME, DIFFUSION_MODEL_TEXT_ENCODER_2_NAME, DIFFUSION_MODEL_TEXT_ENCODER_NAME, DIFFUSION_MODEL_UNET_NAME, @@ -92,7 +94,7 @@ is_torch_xla_available, is_transformers_neuronx_available, ) - from .input_generators import DummyBeamValuesGenerator, DummyMaskedPosGenerator + from .input_generators import DummyBeamValuesGenerator, DummyControNetInputGenerator, DummyMaskedPosGenerator from .misc import ( DiffusersPretrainedConfig, check_if_weights_replacable, diff --git a/optimum/neuron/utils/constant.py b/optimum/neuron/utils/constant.py index edc6eebb8..82f8f134f 100644 --- a/optimum/neuron/utils/constant.py +++ b/optimum/neuron/utils/constant.py @@ -22,5 +22,6 @@ DIFFUSION_MODEL_UNET_NAME = "unet" DIFFUSION_MODEL_VAE_ENCODER_NAME = "vae_encoder" DIFFUSION_MODEL_VAE_DECODER_NAME = "vae_decoder" +DIFFUSION_MODEL_CONTROLNET_NAME = "controlnet" NEURON_BINARIES_PATH = "/opt/aws/neuron/bin" diff --git a/optimum/neuron/utils/input_generators.py b/optimum/neuron/utils/input_generators.py index 9c1b52518..7fbdd38d1 100644 --- a/optimum/neuron/utils/input_generators.py +++ b/optimum/neuron/utils/input_generators.py @@ -14,6 +14,8 @@ # limitations under the License. 
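SDXL ControlNet support is only stubbed out at this point: the new `NeuronStableDiffusionXLControlNetPipelineMixin` raises until the feature is implemented. A quick sketch of the current behaviour:

```python
from optimum.neuron.pipelines import NeuronStableDiffusionXLControlNetPipelineMixin

mixin = NeuronStableDiffusionXLControlNetPipelineMixin()
try:
    mixin()
except NotImplementedError as error:
    print(error)  # `NeuronStableDiffusionXLControlNetPipelineMixin` is not yet supported but will come soon.
```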
"""Dummy input generation classes.""" +from typing import Optional + import torch from ...utils import ( @@ -73,3 +75,91 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int return masked_pos elif input_name == "bool_masked_pos": return masked_pos.bool() + + +class DummyControNetInputGenerator(DummyInputGenerator): + SUPPORTED_INPUT_NAMES = ( + # ControlNet inputs + "timestep", + "encoder_hidden_states", # depending on the hidden_size of text encoder + "controlnet_cond", + "conditioning_scale", + # ControlNet outputs -> UNet inputs + "down_block_additional_residuals", + "mid_block_additional_residual", + ) + + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int, + sequence_length: Optional[int] = None, + num_channels: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + vae_scale_factor: Optional[int] = None, + encoder_hidden_size: Optional[int] = None, + **kwargs, + ): + self.task = task + self.normalized_config = normalized_config + self.batch_size = batch_size + self.sequence_length = sequence_length + self.num_channels = num_channels + self.height = height + self.width = width + self.vae_scale_factor = vae_scale_factor + self.text_encoder_hidden_size = encoder_hidden_size + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "timestep": + shape = [self.batch_size] + return self.random_int_tensor(shape, max_value=999, framework=framework, dtype=int_dtype) + elif input_name == "encoder_hidden_states": + shape = (self.batch_size, self.sequence_length, self.text_encoder_hidden_size) + return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + elif input_name == "controlnet_cond": + num_channels = getattr( + self.normalized_config, "conditioning_channels", 3 + ) # num_channels = 3 since `do_convert_rgb=True` + shape = ( + self.batch_size, + num_channels, + self.height * self.vae_scale_factor, + self.width * self.vae_scale_factor, + ) + return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + elif input_name == "conditioning_scale": + return torch.tensor([1.0]) + elif input_name == "down_block_additional_residuals": + sample_shape = (self.batch_size, self.normalized_config.block_out_channels[0], self.height, self.width) + sample = self.random_float_tensor(sample_shape, framework=framework, dtype=float_dtype) + down_block_res_samples = (sample,) + num_past_cross_attn_blocks = 0 + height = self.height + width = self.width + for idx, down_block_type in enumerate(self.normalized_config.down_block_types): + res_samples = () + shape = (self.batch_size, self.normalized_config.block_out_channels[idx], height, width) + for _ in range(self.normalized_config.layers_per_block): + res_samples += (self.random_float_tensor(shape, framework=framework, dtype=float_dtype),) + if idx != len(self.normalized_config.down_block_types) - 1: + # add output of downsampler + num_past_cross_attn_blocks += 1 + height = height // 2 + width = width // 2 + shape = (self.batch_size, self.normalized_config.block_out_channels[idx], height, width) + res_samples += (self.random_float_tensor(shape, framework=framework, dtype=float_dtype),) + down_block_res_samples += res_samples + return down_block_res_samples + elif input_name == "mid_block_additional_residual": + num_cross_attn_blocks = self.normalized_config.down_block_types.count("CrossAttnDownBlock2D") + out_channels = 
self.normalized_config.block_out_channels[-1] + shape = ( + self.batch_size, + out_channels, + self.height // 2**num_cross_attn_blocks, + self.width // 2**num_cross_attn_blocks, + ) + return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) diff --git a/optimum/neuron/utils/misc.py b/optimum/neuron/utils/misc.py index 5f399380c..c2c137898 100644 --- a/optimum/neuron/utils/misc.py +++ b/optimum/neuron/utils/misc.py @@ -669,7 +669,6 @@ def to_dict(self): def get_stable_diffusion_configs( models_for_export: Dict[str, Union["PreTrainedModel", "ModelMixin"]], - # submodels: Optional[Dict[str, Union[Path, str]]] = None, ): subfolders = ["text_encoder", "text_encoder_2", "unet", "vae"] configs = {} diff --git a/setup.py b/setup.py index 49fdcbf4a..0d4759b21 100644 --- a/setup.py +++ b/setup.py @@ -29,12 +29,13 @@ "sentencepiece", "datasets", "sacremoses", - "diffusers >= 0.26.1", + "diffusers>=0.28.0, <0.29.0", "safetensors", "sentence-transformers >= 2.2.0", "peft", "compel", "rjieba", + "opencv-python-headless", ] QUALITY_REQUIRES = [ @@ -64,7 +65,7 @@ "torchvision==0.16.*", "neuronx_distributed==0.7.0", ], - "diffusers": ["diffusers ~= 0.26.1", "peft"], + "diffusers": ["diffusers>=0.28.0, <0.29.0", "peft"], "sentence-transformers": ["sentence-transformers >= 2.2.0"], } diff --git a/tests/cli/test_export_cli.py b/tests/cli/test_export_cli.py index c82306ab9..72f84a50c 100644 --- a/tests/cli/test_export_cli.py +++ b/tests/cli/test_export_cli.py @@ -214,6 +214,40 @@ def test_stable_diffusion_multi_lora(self): check=True, ) + @requires_neuronx + def test_stable_diffusion_single_controlnet(self): + model_id = "hf-internal-testing/tiny-stable-diffusion-torch" + controlnet_id = "hf-internal-testing/tiny-controlnet" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + [ + "optimum-cli", + "export", + "neuron", + "--model", + model_id, + "--task", + "stable-diffusion", + "--batch_size", + "1", + "--height", + "64", + "--width", + "64", + "--controlnet_ids", + controlnet_id, + "--num_images_per_prompt", + "1", + "--auto_cast", + "matmul", + "--auto_cast_type", + "bf16", + tempdir, + ], + shell=False, + check=True, + ) + @requires_neuronx def test_stable_diffusion_xl(self): model_id = "echarlaix/tiny-random-stable-diffusion-xl" diff --git a/tests/inference/test_stable_diffusion_pipeline.py b/tests/inference/test_stable_diffusion_pipeline.py index 174eac1f1..3006be23b 100644 --- a/tests/inference/test_stable_diffusion_pipeline.py +++ b/tests/inference/test_stable_diffusion_pipeline.py @@ -16,12 +16,17 @@ import copy import unittest +import cv2 +import numpy as np import PIL from compel import Compel, ReturnedEmbeddingsType +from diffusers import UniPCMultistepScheduler +from diffusers.utils import load_image from parameterized import parameterized from optimum.neuron import ( NeuronLatentConsistencyModelPipeline, + NeuronStableDiffusionControlNetPipeline, NeuronStableDiffusionImg2ImgPipeline, NeuronStableDiffusionInpaintPipeline, NeuronStableDiffusionPipeline, @@ -30,10 +35,11 @@ NeuronStableDiffusionXLPipeline, ) from optimum.neuron.modeling_diffusion import ( + NeuronControlNetModel, NeuronModelTextEncoder, NeuronModelUnet, NeuronModelVaeDecoder, - NeuronModelVaeEncoder, # noqa + NeuronModelVaeEncoder, ) from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx from optimum.utils import logging @@ -188,6 +194,43 @@ def test_compatibility_with_compel(self, model_arch): image = pipe(prompt_embeds=prompt_embeds, 
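The down-block residual shapes produced by `DummyControNetInputGenerator` follow the UNet's down path: one tensor for the input convolution, `layers_per_block` tensors per down block, and one extra tensor after each downsampler, with the spatial size halved at every downsampling. A worked example of the resulting shapes, assuming the standard Stable Diffusion 1.5 UNet hyper-parameters and a 64x64 latent:

```python
# Assumed Stable Diffusion 1.5 UNet hyper-parameters.
block_out_channels = [320, 640, 1280, 1280]
down_block_types = ["CrossAttnDownBlock2D"] * 3 + ["DownBlock2D"]
layers_per_block = 2
batch_size, height, width = 1, 64, 64  # latent resolution for 512x512 images

shapes = [(batch_size, block_out_channels[0], height, width)]  # residual of the input convolution
for idx, _ in enumerate(down_block_types):
    shapes += [(batch_size, block_out_channels[idx], height, width)] * layers_per_block
    if idx != len(down_block_types) - 1:
        height, width = height // 2, width // 2
        shapes.append((batch_size, block_out_channels[idx], height, width))  # downsampler output

print(len(shapes))  # 12 residual tensors for this configuration
print(shapes[-1])   # (1, 1280, 8, 8), matching the mid-block residual resolution
```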
num_inference_steps=2).images[0] self.assertIsInstance(image, PIL.Image.Image) + @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) + def test_export_and_inference_with_single_controlnet(self, model_arch): + input_shapes = copy.deepcopy(self.STATIC_INPUTS_SHAPES) + input_shapes.update({"num_images_per_prompt": 1}) + controlnet_id = "hf-internal-testing/tiny-controlnet" + neuron_pipeline = NeuronStableDiffusionControlNetPipeline.from_pretrained( + MODEL_NAMES[model_arch], + controlnet_ids=controlnet_id, + export=True, + **input_shapes, + **self.COMPILER_ARGS, + ) + self.assertIsInstance(neuron_pipeline.text_encoder, NeuronModelTextEncoder) + self.assertIsInstance(neuron_pipeline.unet, NeuronModelUnet) + self.assertIsInstance(neuron_pipeline.vae_encoder, NeuronModelVaeEncoder) + self.assertIsInstance(neuron_pipeline.vae_decoder, NeuronModelVaeDecoder) + self.assertIsInstance(neuron_pipeline.controlnet, NeuronControlNetModel) + + prompt = "the mona lisa" + # prepare canny image + original_image = load_image( + "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + ) + + image = np.array(original_image) + + low_threshold = 100 + high_threshold = 200 + + image = cv2.Canny(image, low_threshold, high_threshold) + image = image[:, :, None] + image = np.concatenate([image, image, image], axis=2) + canny_image = PIL.Image.fromarray(image) + image = neuron_pipeline(prompt, image=canny_image).images[0] + neuron_pipeline.scheduler = UniPCMultistepScheduler.from_config(neuron_pipeline.scheduler.config) + self.assertIsInstance(image, PIL.Image.Image) + @is_inferentia_test @requires_neuronx