diff --git a/optimum/commands/export/neuronx.py b/optimum/commands/export/neuronx.py index ab3a5e32c..ecd2ff82e 100644 --- a/optimum/commands/export/neuronx.py +++ b/optimum/commands/export/neuronx.py @@ -132,10 +132,40 @@ def parse_args_neuronx(parser: "ArgumentParser"): action="store_true", help=("Whether or not for the traced model to return the hidden states of all layers."), ) + optional_group.add_argument( + "--lora_model_ids", + default=None, + nargs="*", + type=str, + help=( + "List of model ids (eg. `ostris/super-cereal-sdxl-lora`) of pretrained lora models hosted on the Hub or paths to local directories containing the lora weights." + ), + ) + optional_group.add_argument( + "--lora_weight_names", + default=None, + nargs="*", + type=str, + help="List of lora weights file names.", + ) + optional_group.add_argument( + "--lora_adapter_names", + default=None, + nargs="*", + type=str, + help="List of the adapter names to be used for referencing the loaded adapter models.", + ) + optional_group.add_argument( + "--lora_scales", + default=None, + nargs="*", + type=float, + help="List of scaling factors for the lora adapters.", + ) optional_group.add_argument( "--output_attentions", action="store_true", - help=("Whether or not for the traced model to return the attentions tensors of all attention layers."), + help="Whether or not for the traced model to return the attentions tensors of all attention layers.", ) input_group = parser.add_argument_group("Input shapes") diff --git a/optimum/exporters/neuron/__main__.py b/optimum/exporters/neuron/__main__.py index 53884a289..9d19e7109 100644 --- a/optimum/exporters/neuron/__main__.py +++ b/optimum/exporters/neuron/__main__.py @@ -19,7 +19,7 @@ import os from argparse import ArgumentParser from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from requests.exceptions import ConnectionError as RequestsConnectionError from 
transformers import AutoConfig, PretrainedConfig @@ -245,6 +245,10 @@ def _get_submodels_and_neuron_configs( submodels: Optional[Dict[str, Union[Path, str]]] = None, output_attentions: bool = False, output_hidden_states: bool = False, + lora_model_ids: Optional[Union[str, List[str]]] = None, + lora_weight_names: Optional[Union[str, List[str]]] = None, + lora_adapter_names: Optional[Union[str, List[str]]] = None, + lora_scales: Optional[Union[float, List[float]]] = None, ): is_stable_diffusion = "stable-diffusion" in task is_encoder_decoder = ( @@ -258,12 +262,16 @@ def _get_submodels_and_neuron_configs( f"`output_attentions` and `output_hidden_states` are not supported by the {task} task yet." ) models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs_for_stable_diffusion( - model, - input_shapes, - task, - output, - dynamic_batch_size, - submodels, + model=model, + input_shapes=input_shapes, + task=task, + output=output, + dynamic_batch_size=dynamic_batch_size, + submodels=submodels, + lora_model_ids=lora_model_ids, + lora_weight_names=lora_weight_names, + lora_adapter_names=lora_adapter_names, + lora_scales=lora_scales, ) elif is_encoder_decoder: optional_outputs = {"output_attentions": output_attentions, "output_hidden_states": output_hidden_states} @@ -291,6 +299,26 @@ def _get_submodels_and_neuron_configs( return models_and_neuron_configs, output_model_names +def _normalize_lora_params(lora_model_ids, lora_weight_names, lora_adapter_names, lora_scales): + if isinstance(lora_model_ids, str): + lora_model_ids = [ + lora_model_ids, + ] + if isinstance(lora_weight_names, str): + lora_weight_names = [ + lora_weight_names, + ] + if isinstance(lora_adapter_names, str): + lora_adapter_names = [ + lora_adapter_names, + ] + if isinstance(lora_scales, float): + lora_scales = [ + lora_scales, + ] + return lora_model_ids, lora_weight_names, lora_adapter_names, lora_scales + + def _get_submodels_and_neuron_configs_for_stable_diffusion( model: 
Union["PreTrainedModel", "DiffusionPipeline"], input_shapes: Dict[str, int], @@ -298,6 +326,10 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion( output: Path, dynamic_batch_size: bool = False, submodels: Optional[Dict[str, Union[Path, str]]] = None, + lora_model_ids: Optional[Union[str, List[str]]] = None, + lora_weight_names: Optional[Union[str, List[str]]] = None, + lora_adapter_names: Optional[Union[str, List[str]]] = None, + lora_scales: Optional[Union[float, List[float]]] = None, ): check_compiler_compatibility_for_stable_diffusion() model = replace_stable_diffusion_submodels(model, submodels) @@ -317,10 +349,17 @@ def _get_submodels_and_neuron_configs_for_stable_diffusion( model.feature_extractor.save_pretrained(output.joinpath("feature_extractor")) model.save_config(output) + lora_model_ids, lora_weight_names, lora_adapter_names, lora_scales = _normalize_lora_params( + lora_model_ids, lora_weight_names, lora_adapter_names, lora_scales + ) models_and_neuron_configs = get_stable_diffusion_models_for_export( pipeline=model, task=task, dynamic_batch_size=dynamic_batch_size, + lora_model_ids=lora_model_ids, + lora_weight_names=lora_weight_names, + lora_adapter_names=lora_adapter_names, + lora_scales=lora_scales, **input_shapes, ) output_model_names = { @@ -395,6 +434,10 @@ def main_export( output_attentions: bool = False, output_hidden_states: bool = False, library_name: Optional[str] = None, + lora_model_ids: Optional[Union[str, List[str]]] = None, + lora_weight_names: Optional[Union[str, List[str]]] = None, + lora_adapter_names: Optional[Union[str, List[str]]] = None, + lora_scales: Optional[Union[float, List[float]]] = None, **input_shapes, ): output = Path(output) @@ -434,6 +477,10 @@ def main_export( submodels=submodels, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + lora_model_ids=lora_model_ids, + lora_weight_names=lora_weight_names, + lora_adapter_names=lora_adapter_names, + lora_scales=lora_scales, ) _, 
neuron_outputs = export_models( @@ -556,6 +603,10 @@ def main(): do_validation=not args.disable_validation, submodels=submodels, library_name=args.library_name, + lora_model_ids=getattr(args, "lora_model_ids", None), + lora_weight_names=getattr(args, "lora_weight_names", None), + lora_adapter_names=getattr(args, "lora_adapter_names", None), + lora_scales=getattr(args, "lora_scales", None), **optional_outputs, **input_shapes, ) diff --git a/optimum/exporters/neuron/utils.py b/optimum/exporters/neuron/utils.py index eb3d799d3..7e49381df 100644 --- a/optimum/exporters/neuron/utils.py +++ b/optimum/exporters/neuron/utils.py @@ -17,7 +17,7 @@ import copy import os from collections import OrderedDict -from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union import torch from transformers import PretrainedConfig @@ -135,6 +135,10 @@ def get_stable_diffusion_models_for_export( vae_encoder_input_shapes: Dict[str, int], vae_decoder_input_shapes: Dict[str, int], dynamic_batch_size: Optional[bool] = False, + lora_model_ids: Optional[List[str]] = None, + lora_weight_names: Optional[List[str]] = None, + lora_adapter_names: Optional[List[str]] = None, + lora_scales: Optional[List[float]] = None, ) -> Dict[str, Tuple[Union["PreTrainedModel", "ModelMixin"], "NeuronDefaultConfig"]]: """ Returns the components of a Stable Diffusion model and their subsequent neuron configs. @@ -157,12 +161,27 @@ def get_stable_diffusion_models_for_export( Static shapes used for compiling vae decoder. dynamic_batch_size (`bool`, defaults to `False`): Whether the Neuron compiled model supports dynamic batch size. + lora_model_ids (`Optional[List[str]]`, defaults to `None`): + List of model ids (eg. `ostris/super-cereal-sdxl-lora`) of pretrained lora models hosted on the Hub or paths to local directories containing the lora weights. 
+ lora_weight_names (`Optional[List[str]]`, defaults to `None`): + List of lora weights file names. + lora_adapter_names (`Optional[List[str]]`, defaults to `None`): + List of adapter names to be used for referencing the loaded adapter models. + lora_scales (`Optional[List[float]]`, defaults to `None`): + List of scaling factors for lora adapters. Returns: `Dict[str, Tuple[Union[`PreTrainedModel`, `ModelMixin`], `NeuronDefaultConfig`]`: A Dict containing the model and Neuron configs for the different components of the model. """ - models_for_export = _get_submodels_for_export_stable_diffusion(pipeline=pipeline, task=task) + models_for_export = _get_submodels_for_export_stable_diffusion( + pipeline=pipeline, + task=task, + lora_model_ids=lora_model_ids, + lora_weight_names=lora_weight_names, + lora_adapter_names=lora_adapter_names, + lora_scales=lora_scales, + ) library_name = "diffusers" # Text encoders @@ -255,15 +274,52 @@ get_stable_diffusion_models_for_export( return models_for_export +def _load_lora_weights_to_pipeline( + pipeline: Union["StableDiffusionPipeline", "StableDiffusionXLImg2ImgPipeline"], + lora_model_ids: Optional[List[str]] = None, + weight_names: Optional[List[str]] = None, + adapter_names: Optional[List[str]] = None, + lora_scales: Optional[List[float]] = None, +): + if lora_model_ids and weight_names: + if len(lora_model_ids) == 1: + pipeline.load_lora_weights(lora_model_ids[0], weight_name=weight_names[0]) + # For tracing the lora weights, we need to use PEFT to fuse adapters directly into the model weights. It won't work by passing the lora scale to the Neuron pipeline during the inference. + pipeline.fuse_lora(lora_scale=lora_scales[0] if lora_scales else 1.0) + elif len(lora_model_ids) > 1: + if not len(lora_model_ids) == len(weight_names) == len(adapter_names): + raise ValueError( + f"weight_name and adapter_name are required to fuse more than one lora. 
You have {len(lora_model_ids)} lora models to fuse, but you have {len(weight_names)} lora weight names and {len(adapter_names)} adapter names." + ) + for model_id, weight_name, adapter_name in zip(lora_model_ids, weight_names, adapter_names): + pipeline.load_lora_weights(model_id, weight_name=weight_name, adapter_name=adapter_name) + + if lora_scales: + pipeline.set_adapters(adapter_names, adapter_weights=lora_scales) + pipeline.fuse_lora() + + def _get_submodels_for_export_stable_diffusion( pipeline: Union["StableDiffusionPipeline", "StableDiffusionXLImg2ImgPipeline"], task: str, + lora_model_ids: Optional[List[str]] = None, + lora_weight_names: Optional[List[str]] = None, + lora_adapter_names: Optional[List[str]] = None, + lora_scales: Optional[List[float]] = None, ) -> Dict[str, Union["PreTrainedModel", "ModelMixin"]]: """ Returns the components of a Stable Diffusion model. """ is_sdxl = "xl" in task + _load_lora_weights_to_pipeline( + pipeline=pipeline, + lora_model_ids=lora_model_ids, + weight_names=lora_weight_names, + adapter_names=lora_adapter_names, + lora_scales=lora_scales, + ) + models_for_export = [] if hasattr(pipeline, "text_encoder_2"): projection_dim = pipeline.text_encoder_2.config.projection_dim diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py index ef00f377b..ee7408bb2 100644 --- a/optimum/neuron/modeling_diffusion.py +++ b/optimum/neuron/modeling_diffusion.py @@ -21,7 +21,7 @@ from abc import abstractmethod from pathlib import Path from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch from huggingface_hub import snapshot_download @@ -553,6 +553,10 @@ def _export( disable_fallback: bool = False, dynamic_batch_size: bool = False, data_parallel_mode: Optional[str] = None, + lora_model_ids: Optional[Union[str, List[str]]] = None, + lora_weight_names: 
Optional[Union[str, List[str]]] = None, + lora_adapter_names: Optional[Union[str, List[str]]] = None, + lora_scales: Optional[Union[float, List[float]]] = None, **kwargs_shapes, ) -> "NeuronStableDiffusionPipelineBase": """ @@ -615,6 +619,14 @@ def _export( data_parallel_mode (`Optional[str]`, defaults to `None`): Mode to decide what components to load into both NeuronCores of a Neuron device. Can be "none"(no data parallel), "unet"(only load unet into both cores of each device), "all"(load the whole pipeline into both cores). + lora_model_ids (`Optional[Union[str, List[str]]]`, defaults to `None`): + Lora model local paths or repo ids (eg. `ostris/super-cereal-sdxl-lora`) on the Hugging Face Hub. + lora_weight_names (`Optional[Union[str, List[str]]]`, defaults to `None`): + Lora weights file names. + lora_adapter_names (`Optional[List[str]]`, defaults to `None`): + Adapter names to be used for referencing the loaded adapter models. + lora_scales (`Optional[List[float]]`, defaults to `None`): + Lora adapters scaling factors. kwargs_shapes (`Dict[str, int]`): Shapes to use during inference. This argument allows to override the default shapes used during the export. """ @@ -654,6 +666,10 @@ def _export( use_auth_token=use_auth_token, do_validation=False, submodels={"unet": unet_id}, + lora_model_ids=lora_model_ids, + lora_weight_names=lora_weight_names, + lora_adapter_names=lora_adapter_names, + lora_scales=lora_scales, library_name=cls.library_name, **input_shapes, ) diff --git a/optimum/neuron/pipelines/diffusers/pipeline_latent_consistency_text2img.py b/optimum/neuron/pipelines/diffusers/pipeline_latent_consistency_text2img.py index 28dd245a8..8d3a08253 100644 --- a/optimum/neuron/pipelines/diffusers/pipeline_latent_consistency_text2img.py +++ b/optimum/neuron/pipelines/diffusers/pipeline_latent_consistency_text2img.py @@ -172,9 +172,11 @@ def __call__( self.check_num_images_per_prompt(batch_size, neuron_batch_size, num_images_per_prompt) # 3. 
Encode input prompt - lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None - ) + if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Lora scale need to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored." + ) + lora_scale = None # NOTE: when a LCM is distilled from an LDM via latent consistency distillation (Algorithm 1) with guided # distillation, the forward pass of the LCM learns to approximate sampling from the LDM using CFG with the diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion.py index ee93d8b70..572cea262 100644 --- a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion.py +++ b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion.py @@ -161,9 +161,11 @@ def __call__( ) # 3. Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) + if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Lora scale need to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored." + ) + text_encoder_lora_scale = None prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, num_images_per_prompt, diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_img2img.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_img2img.py index 3aaf704b7..d818d9605 100644 --- a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_img2img.py +++ b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_img2img.py @@ -212,9 +212,11 @@ def __call__( ) # 3. 
Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) + if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Lora scale need to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored." + ) + text_encoder_lora_scale = None prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, num_images_per_prompt, diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py index a00de9ba8..b757f5936 100644 --- a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py +++ b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_inpaint.py @@ -214,9 +214,11 @@ def __call__( ) # 3. Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) + if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Lora scale need to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored." + ) + text_encoder_lora_scale = None prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, num_images_per_prompt, diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl.py index 63c69632b..bd930e994 100644 --- a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl.py +++ b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl.py @@ -252,7 +252,11 @@ def __call__( ) # 3. 
Encode input prompt - lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Lora scale need to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored." + ) + lora_scale = None ( prompt_embeds, negative_prompt_embeds, diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py index 7b6550b69..75229be15 100644 --- a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py +++ b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl_img2img.py @@ -345,9 +345,11 @@ def __call__( ) # 3. Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) + if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Lora scale need to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored." + ) + text_encoder_lora_scale = None ( prompt_embeds, negative_prompt_embeds, diff --git a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl_inpaint.py b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl_inpaint.py index 03198e43c..c1ab0be5c 100644 --- a/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl_inpaint.py +++ b/optimum/neuron/pipelines/diffusers/pipeline_stable_diffusion_xl_inpaint.py @@ -374,9 +374,11 @@ def __call__( ) # 3. 
Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None - ) + if cross_attention_kwargs is not None and cross_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Lora scale need to be fused with model weights during the compilation. The scale passed through the pipeline during inference will be ignored." + ) + text_encoder_lora_scale = None ( prompt_embeds, diff --git a/setup.py b/setup.py index 33b573b99..5bb56c901 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,7 @@ "torchvision==0.14.*", "neuronx_distributed==0.6.0", ], - "diffusers": ["diffusers ~= 0.26.1"], + "diffusers": ["diffusers ~= 0.26.1", "peft"], "sentence-transformers": ["sentence-transformers >= 2.2.0"], } diff --git a/tests/cli/test_export_cli.py b/tests/cli/test_export_cli.py index 2bc38eaef..c63194a6d 100644 --- a/tests/cli/test_export_cli.py +++ b/tests/cli/test_export_cli.py @@ -172,6 +172,48 @@ def test_stable_diffusion(self): check=True, ) + @requires_neuronx + def test_stable_diffusion_multi_lora(self): + model_id = "hf-internal-testing/tiny-stable-diffusion-torch" + lora_model_id = "Jingya/tiny-stable-diffusion-lora-64" + lora_weight_name = "pytorch_lora_weights.safetensors" + adapter_name = "pokemon" + with tempfile.TemporaryDirectory() as tempdir: + subprocess.run( + [ + "optimum-cli", + "export", + "neuron", + "--model", + model_id, + "--task", + "stable-diffusion", + "--batch_size", + "1", + "--height", + "64", + "--width", + "64", + "--num_images_per_prompt", + "4", + "--lora_model_ids", + lora_model_id, + "--lora_weight_names", + lora_weight_name, + "--lora_adapter_names", + adapter_name, + "--lora_scales", + "0.9", + "--auto_cast", + "matmul", + "--auto_cast_type", + "bf16", + tempdir, + ], + shell=False, + check=True, + ) + @requires_neuronx + def test_stable_diffusion_xl(self): model_id = "echarlaix/tiny-random-stable-diffusion-xl" diff --git 
a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py index eba5d5bc4..e1b5e10f2 100644 --- a/tests/exporters/exporters_utils.py +++ b/tests/exporters/exporters_utils.py @@ -43,6 +43,10 @@ "latent-consistency": "echarlaix/tiny-random-latent-consistency", } +LORA_WEIGHTS_TINY = { + "stable-diffusion": ("Jingya/tiny-stable-diffusion-lora-64", "pytorch_lora_weights.safetensors", "pokemon"), +} + SENTENCE_TRANSFORMERS_MODELS = { "transformer": "sentence-transformers/all-MiniLM-L6-v2", "clip": "sentence-transformers/clip-ViT-B-32", diff --git a/tests/exporters/test_export.py b/tests/exporters/test_export.py index 7ef9ffd98..12593f56d 100644 --- a/tests/exporters/test_export.py +++ b/tests/exporters/test_export.py @@ -45,6 +45,7 @@ ENCODER_DECODER_MODELS_TINY, EXPORT_MODELS_TINY, EXTREA_DEFAULT_DUMMY_SHAPES, + LORA_WEIGHTS_TINY, SENTENCE_TRANSFORMERS_MODELS, STABLE_DIFFUSION_MODELS_TINY, WEIGHTS_NEFF_SEPARATION_UNSUPPORTED_ARCH, @@ -270,6 +271,41 @@ def test_export_for_stable_diffusion_xl_models(self, model_id): neuron_files_subpaths=output_model_names, ) + def test_export_sd_with_fused_lora_weights(self): + model_id = STABLE_DIFFUSION_MODELS_TINY["stable-diffusion"] + lora_params = LORA_WEIGHTS_TINY["stable-diffusion"] + set_seed(SEED) + + # prepare neuron config / models + model = StableDiffusionPipeline.from_pretrained(model_id) + input_shapes = build_stable_diffusion_components_mandatory_shapes( + **{"batch_size": 1, "height": 64, "width": 64, "num_images_per_prompt": 4} + ) + + with TemporaryDirectory() as tmpdirname: + models_and_neuron_configs, output_model_names = _get_submodels_and_neuron_configs( + model=model, + input_shapes=input_shapes, + task="stable-diffusion", + output=Path(tmpdirname), + model_name_or_path=model_id, + lora_model_ids=lora_params[0], + lora_weight_names=lora_params[1], + lora_adapter_names=lora_params[2], + lora_scales=0.9, + ) + _, neuron_outputs = export_models( + models_and_neuron_configs=models_and_neuron_configs, + 
output_dir=Path(tmpdirname), + output_file_names=output_model_names, + ) + validate_models_outputs( + models_and_neuron_configs=models_and_neuron_configs, + neuron_named_outputs=neuron_outputs, + output_dir=Path(tmpdirname), + neuron_files_subpaths=output_model_names, + ) + @is_inferentia_test @requires_neuronx diff --git a/tests/inference/inference_utils.py b/tests/inference/inference_utils.py index 33fc20da9..f64406a6a 100644 --- a/tests/inference/inference_utils.py +++ b/tests/inference/inference_utils.py @@ -50,6 +50,10 @@ "xlm-roberta": "hf-internal-testing/tiny-xlm-roberta", } +LORA_WEIGHTS_TINY = { + "stable-diffusion": ("Jingya/tiny-stable-diffusion-lora-64", "pytorch_lora_weights.safetensors", "pokemon"), +} + SENTENCE_TRANSFORMERS_MODEL_NAMES = { "transformer": "sentence-transformers/all-MiniLM-L6-v2", "clip": "sentence-transformers/clip-ViT-B-32", diff --git a/tests/inference/test_stable_diffusion_pipeline.py b/tests/inference/test_stable_diffusion_pipeline.py index bd800da1a..24490a347 100644 --- a/tests/inference/test_stable_diffusion_pipeline.py +++ b/tests/inference/test_stable_diffusion_pipeline.py @@ -38,7 +38,7 @@ from optimum.utils import logging from optimum.utils.testing_utils import require_diffusers -from .inference_utils import MODEL_NAMES, download_image +from .inference_utils import LORA_WEIGHTS_TINY, MODEL_NAMES, download_image logger = logging.get_logger() @@ -139,6 +139,32 @@ def test_lcm_export_and_inference(self, model_arch): image = neuron_pipeline(prompt, num_inference_steps=4, guidance_scale=8.0).images[0] self.assertIsInstance(image, PIL.Image.Image) + @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) + def test_export_and_inference_with_fused_lora(self, model_arch): + num_images_per_prompt = 4 + input_shapes = copy.deepcopy(self.STATIC_INPUTS_SHAPES) + input_shapes.update({"num_images_per_prompt": num_images_per_prompt}) + lora_params = LORA_WEIGHTS_TINY[model_arch] + neuron_pipeline = 
self.NEURON_MODEL_CLASS.from_pretrained( + MODEL_NAMES[model_arch], + export=True, + dynamic_batch_size=False, + lora_model_ids=lora_params[0], + lora_weight_names=lora_params[1], + lora_adapter_names=lora_params[2], + lora_scales=0.9, + **input_shapes, + **self.COMPILER_ARGS, + ) + self.assertIsInstance(neuron_pipeline.text_encoder, NeuronModelTextEncoder) + self.assertIsInstance(neuron_pipeline.unet, NeuronModelUnet) + self.assertIsInstance(neuron_pipeline.vae_encoder, NeuronModelVaeEncoder) + self.assertIsInstance(neuron_pipeline.vae_decoder, NeuronModelVaeDecoder) + + prompts = ["A cute brown bear eating a slice of pizza"] + image = neuron_pipeline(prompts, num_images_per_prompt=num_images_per_prompt).images[0] + self.assertIsInstance(image, PIL.Image.Image) + @is_inferentia_test @requires_neuronx