From f8b52594a5783b65c32e7ffb07f8272b32f948b2 Mon Sep 17 00:00:00 2001 From: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com> Date: Wed, 18 Sep 2024 15:04:45 +0200 Subject: [PATCH] Add support for multiple controlnet (#691) * add sd multi controlnet support * add tests * add doc * make style * add cv2 dependencies * fix ci error introduced by BatchFeature in transformers * fix * fix changes introduced with transformers/#31980 * pin libneuronxla * Update docs/source/inference_tutorials/stable_diffusion.mdx * Update optimum/neuron/modeling_diffusion.py --- .github/workflows/test_inf2_inference.yml | 3 + .../inference_tutorials/stable_diffusion.mdx | 62 +++++++++++++++++++ optimum/exporters/neuron/__main__.py | 8 +-- optimum/neuron/generation/utils.py | 2 +- optimum/neuron/modeling.py | 21 +++++++ optimum/neuron/modeling_decoder.py | 38 ++++-------- optimum/neuron/modeling_diffusion.py | 23 ++++--- optimum/neuron/modeling_seq2seq.py | 10 +-- optimum/neuron/modeling_traced.py | 32 +++------- .../diffusers/pipeline_controlnet.py | 21 +++++-- .../transformers/sentence_transformers.py | 2 +- setup.py | 3 + tests/decoder/test_decoder_hub.py | 2 +- tests/generation/test_hub.py | 4 +- tests/inference/inference_utils.py | 2 +- .../test_stable_diffusion_pipeline.py | 52 ++++++++++++---- 16 files changed, 196 insertions(+), 89 deletions(-) diff --git a/.github/workflows/test_inf2_inference.yml b/.github/workflows/test_inf2_inference.yml index ae9c3b8c0..7e2ce8d30 100644 --- a/.github/workflows/test_inf2_inference.yml +++ b/.github/workflows/test_inf2_inference.yml @@ -34,6 +34,9 @@ jobs: sudo apt-get update -y sudo apt-get install aws-neuronx-tools=2.17.1.0 aws-neuronx-runtime-lib=2.20.22.0-1b3ca6425 aws-neuronx-collectives=2.20.22.0-c101c322e -y export PATH=/opt/aws/neuron/bin:$PATH + - name: Install cv2 dependencies + run: | + sudo apt-get install ffmpeg libsm6 libxext6 -y - name: Checkout uses: actions/checkout@v2 - name: Install python dependencies diff --git a/docs/source/inference_tutorials/stable_diffusion.mdx b/docs/source/inference_tutorials/stable_diffusion.mdx index df9530875..b4c64a457 100644 --- a/docs/source/inference_tutorials/stable_diffusion.mdx +++ b/docs/source/inference_tutorials/stable_diffusion.mdx @@ -635,6 +635,68 @@ compare.save("compare.png") /> +### MultiControlNet + +With Optimum Neuron, you can also compose multiple ControlNet conditionings from different image inputs: + +* Compile multiple ControlNet for SD1.5 + +```bash +optimum-cli export neuron --inline-weights-neff --model jyoung105/stable-diffusion-v1-5 --task stable-diffusion --auto_cast matmul --auto_cast_type bf16 --batch_size 1 --num_images_per_prompt 1 --controlnet_ids lllyasviel/control_v11p_sd15_openpose lllyasviel/control_v11f1p_sd15_depth --height 512 --width 512 sd15-512x512-bf16-openpose-depth +``` + +* Run SD1.5 with OpenPose and Depth conditionings: + +```python +import numpy as np +import torch +from PIL import Image + +from controlnet_aux import OpenposeDetector +from transformers import pipeline +from diffusers import UniPCMultistepScheduler +from diffusers.utils import load_image +from optimum.neuron import NeuronStableDiffusionControlNetPipeline + + +# OpenPose+Depth ControlNet +model_id = "sd15-512x512-bf16-openpose-depth" + +# Load ControlNet images + +# 1. openpose +image = load_image("https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/images/input.png") +processor = OpenposeDetector.from_pretrained('lllyasviel/ControlNet') +openpose_image = processor(image) + +# 2. depth +image = load_image("https://huggingface.co/lllyasviel/control_v11p_sd15_depth/resolve/main/images/input.png") +depth_estimator = pipeline('depth-estimation') +image = depth_estimator(image)['depth'] +image = np.array(image) +image = image[:, :, None] +image = np.concatenate([image, image, image], axis=2) +depth_image = Image.fromarray(image) + +images = [openpose_image.resize((512, 512)), depth_image.resize((512, 512))] + +# 3. inference +pipe = NeuronStableDiffusionControlNetPipeline.from_pretrained(model_id) +pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) +prompt = "a giant in a fantasy landscape, best quality" +negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" + +image = pipe(prompt=prompt, image=images).images[0] +image.save('out.png') +``` + +stable diffusion 1.5 generated image with OpenPose and Depth controlnet. + ## ControlNet with Stable Diffusion XL diff --git a/optimum/exporters/neuron/__main__.py b/optimum/exporters/neuron/__main__.py index 2fa28e68a..18f557214 100644 --- a/optimum/exporters/neuron/__main__.py +++ b/optimum/exporters/neuron/__main__.py @@ -472,7 +472,7 @@ def load_models_and_neuron_configs( revision: str, force_download: bool, local_files_only: bool, - use_auth_token: Optional[Union[bool, str]], + token: Optional[Union[bool, str]], submodels: Optional[Dict[str, Union[Path, str]]], lora_model_ids: Optional[Union[str, List[str]]], lora_weight_names: Optional[Union[str, List[str]]], @@ -494,7 +494,7 @@ def load_models_and_neuron_configs( "subfolder": subfolder, "revision": revision, "cache_dir": cache_dir, - "use_auth_token": use_auth_token, + "token": token, "local_files_only": local_files_only, "force_download": force_download, "trust_remote_code": trust_remote_code, @@ -544,7 +544,7 @@ def main_export( revision: str = "main", force_download: bool = False, local_files_only: bool = False, - use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, do_validation: bool = True, submodels: Optional[Dict[str, Union[Path, str]]] = None, output_attentions: bool = False, @@ -575,7 +575,7 @@ def main_export( revision=revision, force_download=force_download, local_files_only=local_files_only, - use_auth_token=use_auth_token, + token=token, submodels=submodels, output_attentions=output_attentions, output_hidden_states=output_hidden_states, diff --git a/optimum/neuron/generation/utils.py b/optimum/neuron/generation/utils.py index 4d9c4f802..77fad2a21 100644 --- a/optimum/neuron/generation/utils.py +++ b/optimum/neuron/generation/utils.py @@ -719,7 +719,7 @@ def generate( batch_size=batch_size, model_input_name=model_input_name, model_kwargs=model_kwargs, - decoder_start_token_id=generation_config.decoder_start_token_id, + decoder_start_token_id=generation_config._decoder_start_token_tensor, device=inputs_tensor.device, ) else: diff --git a/optimum/neuron/modeling.py b/optimum/neuron/modeling.py index 864bfdb5a..c426a09ac 100644 --- a/optimum/neuron/modeling.py +++ b/optimum/neuron/modeling.py @@ -684,6 +684,13 @@ class NeuronModelForImageClassification(NeuronTracedModel): auto_model_class = AutoModelForImageClassification + @property + def dtype(self) -> Optional["torch.dtype"]: + """ + Torch dtype of the inputs to avoid error in transformers on casting a BatchFeature to type None. + """ + return getattr(self.config.neuron, "input_dtype", torch.float32) + @add_start_docstrings_to_model_forward( NEURON_IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width") + IMAGE_CLASSIFICATION_EXAMPLE.format( @@ -763,6 +770,13 @@ class NeuronModelForSemanticSegmentation(NeuronTracedModel): auto_model_class = AutoModelForSemanticSegmentation + @property + def dtype(self) -> Optional["torch.dtype"]: + """ + Torch dtype of the inputs to avoid error in transformers on casting a BatchFeature to type None. + """ + return getattr(self.config.neuron, "input_dtype", torch.float32) + @add_start_docstrings_to_model_forward( NEURON_IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width") + SEMANTIC_SEGMENTATION_EXAMPLE.format( @@ -843,6 +857,13 @@ class NeuronModelForObjectDetection(NeuronTracedModel): auto_model_class = AutoModelForObjectDetection + @property + def dtype(self) -> Optional["torch.dtype"]: + """ + Torch dtype of the inputs to avoid error in transformers on casting a BatchFeature to type None. + """ + return getattr(self.config.neuron, "input_dtype", torch.float32) + @add_start_docstrings_to_model_forward( NEURON_IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width") + OBJECT_DETECTION_EXAMPLE.format( diff --git a/optimum/neuron/modeling_decoder.py b/optimum/neuron/modeling_decoder.py index a2efea9bf..b2e7cfcbf 100644 --- a/optimum/neuron/modeling_decoder.py +++ b/optimum/neuron/modeling_decoder.py @@ -25,8 +25,7 @@ from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Optional, Tuple, Union -from huggingface_hub import HfApi, get_token, snapshot_download -from huggingface_hub.utils import is_google_colab +from huggingface_hub import HfApi, snapshot_download from transformers import AutoConfig, AutoModel, GenerationConfig from ..exporters.neuron.model_configs import * # noqa: F403 @@ -225,7 +224,7 @@ def __init__( def _create_checkpoint( cls, model_id: str, - use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: Optional[str] = None, @@ -243,7 +242,7 @@ def _create_checkpoint( revision=revision, framework="pt", cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, @@ -269,7 +268,7 @@ def get_export_config( cls, model_id: str, config: "PretrainedConfig", - use_auth_token: Optional[str] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, task: Optional[str] = None, batch_size: Optional[int] = None, @@ -286,7 +285,7 @@ def get_export_config( else: checkpoint_id = model_id # Get the exact checkpoint revision (SHA1) - api = HfApi(token=use_auth_token) + api = HfApi(token=token) model_info = api.repo_info(model_id, revision=revision) checkpoint_revision = model_info.sha @@ -337,7 +336,7 @@ def _export( cls, model_id: str, config: "PretrainedConfig", - use_auth_token: Optional[str] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, task: Optional[str] = None, batch_size: Optional[int] = None, @@ -353,7 +352,7 @@ def _export( new_config = cls.get_export_config( model_id, config, - use_auth_token=use_auth_token, + token=token, revision=revision, task=task, batch_size=batch_size, @@ -396,7 +395,7 @@ def _from_pretrained( cls, model_id: Union[str, Path], config: "PretrainedConfig", - use_auth_token: Optional[str] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, **kwargs, ) -> "NeuronDecoderModel": @@ -411,7 +410,7 @@ def _from_pretrained( model_path = model_id if not os.path.isdir(model_id): - model_path = snapshot_download(model_id, token=use_auth_token, revision=revision) + model_path = snapshot_download(model_id, token=token, revision=revision) checkpoint_dir, compiled_dir = cls._get_neuron_dirs(model_path) if not os.path.isdir(checkpoint_dir): @@ -425,7 +424,7 @@ def _from_pretrained( checkpoint_id, task=task, revision=checkpoint_revision, - use_auth_token=use_auth_token, + token=token, **kwargs, ) assert os.path.isdir(compiled_dir) @@ -467,24 +466,13 @@ def push_to_hub( repository_id: str, private: Optional[bool] = None, revision: Optional[str] = None, - use_auth_token: Union[bool, str] = True, + token: Union[bool, str] = True, endpoint: Optional[str] = None, ) -> str: - if isinstance(use_auth_token, str): - huggingface_token = use_auth_token - elif use_auth_token: - huggingface_token = get_token() - else: - raise ValueError("You need to provide `use_auth_token` to be able to push to the hub") api = HfApi(endpoint=endpoint) - user = api.whoami(huggingface_token) - if is_google_colab(): - # Only in Google Colab to avoid the warning message - self.git_config_username_and_email(git_email=user["email"], git_user=user["fullname"]) - api.create_repo( - token=huggingface_token, + token=token, repo_id=repository_id, exist_ok=True, private=private, @@ -498,7 +486,7 @@ def push_to_hub( api.upload_folder( repo_id=repository_id, folder_path=save_directory, - token=huggingface_token, + token=token, revision=revision, ignore_patterns=ignore_patterns, ) diff --git a/optimum/neuron/modeling_diffusion.py b/optimum/neuron/modeling_diffusion.py index 30a310bf8..77e99d43c 100644 --- a/optimum/neuron/modeling_diffusion.py +++ b/optimum/neuron/modeling_diffusion.py @@ -548,7 +548,7 @@ def _from_pretrained( cls, model_id: Union[str, Path], config: Dict[str, Any], - use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: Optional[str] = None, @@ -592,7 +592,7 @@ def _from_pretrained( model_id, cache_dir=cache_dir, local_files_only=local_files_only, - use_auth_token=use_auth_token, + token=token, revision=revision, force_download=force_download, allow_patterns=allow_patterns, @@ -720,7 +720,7 @@ def _export( model_id: Union[str, Path], config: Dict[str, Any], unet_id: Optional[Union[str, Path]] = None, - use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: str = "main", force_download: bool = True, cache_dir: Optional[str] = None, @@ -758,9 +758,9 @@ def _export( configuration files of compatible classes. unet_id (`Optional[Union[str, Path]]`, defaults to `None`): A string or a path point to the U-NET model to replace the one in the original pipeline. - use_auth_token (`Optional[Union[bool, str]]`, defaults to `None`): + token (`Optional[Union[bool, str]]`, defaults to `None`): The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `transformers-cli login` (stored in `~/.huggingface`). + when running `huggingface-cli login` (stored in `huggingface_hub.constants.HF_TOKEN_PATH`). revision (`str`, defaults to `"main"`): The specific model version to use (can be a branch name, tag name or commit id). force_download (`bool`, defaults to `True`): @@ -837,7 +837,7 @@ def _export( framework="pt", library_name=cls.library_name, cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, @@ -863,7 +863,7 @@ def _export( revision=revision, force_download=force_download, local_files_only=local_files_only, - use_auth_token=use_auth_token, + token=token, submodels=submodels, output_hidden_states=output_hidden_states, lora_model_ids=lora_model_ids, @@ -938,7 +938,7 @@ def _export( revision=revision, force_download=force_download, local_files_only=local_files_only, - use_auth_token=use_auth_token, + token=token, do_validation=False, submodels={"unet": unet_id}, output_hidden_states=output_hidden_states, @@ -1189,9 +1189,14 @@ def forward( encoder_hidden_states: torch.Tensor, controlnet_cond: torch.Tensor, conditioning_scale: float = 1.0, + guess_mode: bool = False, return_dict: bool = True, ) -> Union["ControlNetOutput", Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]: - for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.model)): + if guess_mode: + logger.info( + "Guess mode is not yet supported. File us an issue on: https://github.com/huggingface/optimum-neuron/issues." + ) + for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): inputs = (sample, timestep, encoder_hidden_states, image, scale) down_samples, mid_sample = controlnet(*inputs) diff --git a/optimum/neuron/modeling_seq2seq.py b/optimum/neuron/modeling_seq2seq.py index 1332b44c8..2eca15285 100644 --- a/optimum/neuron/modeling_seq2seq.py +++ b/optimum/neuron/modeling_seq2seq.py @@ -157,7 +157,7 @@ def _from_pretrained( cls, model_id: Union[str, Path], config: "PretrainedConfig", - use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: Optional[str] = None, @@ -176,7 +176,7 @@ def _from_pretrained( model_id, cache_dir=cache_dir, local_files_only=local_files_only, - use_auth_token=use_auth_token, + token=token, revision=revision, force_download=force_download, ignore_patterns=["*.msgpack", "*.safetensors", "*.bin"], # only download *.neuron artifacts @@ -224,7 +224,7 @@ def _from_pretrained( cache_dir=cache_dir, force_download=force_download, local_files_only=local_files_only, - use_auth_token=use_auth_token, + token=token, revision=revision, subfolder=os.path.join(subfolder, DECODER_NAME), ) @@ -255,7 +255,7 @@ def _export( cls, model_id: str, config: "PretrainedConfig", - use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: str = "main", force_download: bool = True, cache_dir: Optional[str] = None, @@ -310,7 +310,7 @@ def _export( revision=revision, force_download=force_download, local_files_only=local_files_only, - use_auth_token=use_auth_token, + token=token, do_validation=False, output_attentions=output_attentions, output_hidden_states=output_hidden_states, diff --git a/optimum/neuron/modeling_traced.py b/optimum/neuron/modeling_traced.py index 5e05181f6..a22cc18c7 100644 --- a/optimum/neuron/modeling_traced.py +++ b/optimum/neuron/modeling_traced.py @@ -24,7 +24,6 @@ import torch from huggingface_hub import HfApi, HfFolder, hf_hub_download -from huggingface_hub.utils import is_google_colab from transformers import AutoConfig, AutoModel, GenerationMixin from ..exporters.neuron import main_export @@ -156,7 +155,7 @@ def _from_pretrained( cls, model_id: Union[str, Path], config: "PretrainedConfig", - use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: Optional[str] = None, @@ -173,10 +172,10 @@ def _from_pretrained( if model_path.is_dir(): neuron_files = list(model_path.glob("*.neuron")) else: - if isinstance(use_auth_token, bool): + if isinstance(token, bool): token = HfFolder().get_token() else: - token = use_auth_token + token = token repo_files = map(Path, HfApi().list_repo_files(model_id, revision=revision, token=token)) pattern = "*.neuron" if subfolder == "" else f"{subfolder}/*.neuron" neuron_files = [p for p in repo_files if p.match(pattern)] @@ -210,7 +209,7 @@ def _from_pretrained( repo_id=model_id, filename=file_name, subfolder=subfolder, - use_auth_token=use_auth_token, + token=token, revision=revision, cache_dir=cache_dir, force_download=force_download, @@ -246,7 +245,7 @@ def _export( cls, model_id: str, config: "PretrainedConfig", - use_auth_token: Optional[Union[bool, str]] = None, + token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, library_name: Optional[str] = None, force_download: bool = False, @@ -324,7 +323,7 @@ def _export( framework="pt", library_name=library_name, cache_dir=cache_dir, - use_auth_token=use_auth_token, + token=token, local_files_only=local_files_only, force_download=force_download, trust_remote_code=trust_remote_code, @@ -360,7 +359,7 @@ def _export( revision=revision, force_download=force_download, local_files_only=local_files_only, - use_auth_token=use_auth_token, + token=token, do_validation=False, library_name=library_name, **kwargs_shapes, @@ -375,24 +374,13 @@ def push_to_hub( repository_id: str, private: Optional[bool] = None, revision: Optional[str] = None, - use_auth_token: Union[bool, str] = True, + token: Optional[Union[bool, str]] = None, endpoint: Optional[str] = None, ) -> str: - if isinstance(use_auth_token, str): - huggingface_token = use_auth_token - elif use_auth_token: - huggingface_token = HfFolder.get_token() - else: - raise ValueError("You need to provide `use_auth_token` to be able to push to the hub") api = HfApi(endpoint=endpoint) - user = api.whoami(huggingface_token) - if is_google_colab(): - # Only in Google Colab to avoid the warning message - self.git_config_username_and_email(git_email=user["email"], git_user=user["fullname"]) - api.create_repo( - token=huggingface_token, + token=token, repo_id=repository_id, exist_ok=True, private=private, @@ -402,7 +390,7 @@ def push_to_hub( local_file_path = os.path.join(path, name) hub_file_path = os.path.relpath(local_file_path, save_directory) api.upload_file( - token=huggingface_token, + token=token, repo_id=repository_id, path_or_fileobj=os.path.join(os.getcwd(), local_file_path), path_in_repo=hub_file_path, diff --git a/optimum/neuron/pipelines/diffusers/pipeline_controlnet.py b/optimum/neuron/pipelines/diffusers/pipeline_controlnet.py index 690872c83..daaa4f9a4 100644 --- a/optimum/neuron/pipelines/diffusers/pipeline_controlnet.py +++ b/optimum/neuron/pipelines/diffusers/pipeline_controlnet.py @@ -349,7 +349,7 @@ def __call__( global_pool_conditions = ( controlnet.config.global_pool_conditions if controlnet.__class__.__name__ == "NeuronControlNetModel" - else controlnet.nets[0].config.global_pool_conditions + else controlnet.config[0].global_pool_conditions ) guess_mode = guess_mode or global_pool_conditions # TODO: support guess mode of ControlNet @@ -502,11 +502,20 @@ def __call__( # Duplicate inputs for ddp t = torch.tensor([t] * 2) if self.data_parallel_mode == "unet" else t - cond_scale = ( - torch.tensor([cond_scale]).repeat(2) - if self.data_parallel_mode == "unet" - else torch.tensor(cond_scale) - ) + if controlnet.__class__.__name__ == "NeuronControlNetModel": + cond_scale = ( + torch.tensor([cond_scale]).repeat(2) + if self.data_parallel_mode == "unet" + else torch.tensor(cond_scale) + ) + else: + for i, scale in enumerate(cond_scale): + new_scale = ( + torch.tensor([scale]).repeat(2) + if self.data_parallel_mode == "unet" + else torch.tensor(scale) + ) + cond_scale[i] = new_scale down_block_res_samples, mid_block_res_sample = self.controlnet( control_model_input, diff --git a/optimum/neuron/pipelines/transformers/sentence_transformers.py b/optimum/neuron/pipelines/transformers/sentence_transformers.py index bfcd16695..e7324932f 100644 --- a/optimum/neuron/pipelines/transformers/sentence_transformers.py +++ b/optimum/neuron/pipelines/transformers/sentence_transformers.py @@ -12,7 +12,7 @@ def is_sentence_transformer_model(model: str, token: str = None, revision: str = None): """Checks if the model is a sentence transformer model based on provided model id""" try: - _library_name = TasksManager.infer_library_from_model(model, use_auth_token=token, revision=revision) + _library_name = TasksManager.infer_library_from_model(model, token=token, revision=revision) return _library_name == "sentence_transformers" except ValueError: return False diff --git a/setup.py b/setup.py index a389de8dc..24ab00e8a 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,8 @@ "soundfile", "librosa", "opencv-python-headless", + "controlnet-aux", + "mediapipe", ] QUALITY_REQUIRES = [ @@ -68,6 +70,7 @@ "torch==2.1.2.*", "torchvision==0.16.*", "neuronx_distributed==0.8.0", + "libneuronxla==2.0.2335", ], "diffusers": ["diffusers>=0.28.0, <0.29.0", "peft"], "sentence-transformers": ["sentence-transformers >= 2.2.0"], diff --git a/tests/decoder/test_decoder_hub.py b/tests/decoder/test_decoder_hub.py index 566d9659a..b665ba39f 100644 --- a/tests/decoder/test_decoder_hub.py +++ b/tests/decoder/test_decoder_hub.py @@ -47,7 +47,7 @@ def test_decoder_push_to_hub(from_local): model_name = f"neuron-testing-{hostname}-decoder-push" model_name += "-from-local" if from_local else "-from-hub" repo_id = f"optimum-internal-testing/{model_name}" - model.push_to_hub(model_path, repo_id, use_auth_token=get_token()) + model.push_to_hub(model_path, repo_id, token=get_token()) api = HfApi() try: hub_files_path = api.list_repo_files(repo_id) diff --git a/tests/generation/test_hub.py b/tests/generation/test_hub.py index d94e0f0e6..847c25719 100644 --- a/tests/generation/test_hub.py +++ b/tests/generation/test_hub.py @@ -35,9 +35,9 @@ def test_seq2seq_model_from_hub(): def test_push_seq2seq_to_hub(neuron_seq2seq_greedy_path, neuron_push_seq2seq_id, staging): model = NeuronModelForSeq2SeqLM.from_pretrained(neuron_seq2seq_greedy_path) model.push_to_hub( - neuron_seq2seq_greedy_path, neuron_push_seq2seq_id, use_auth_token=staging.token, endpoint=ENDPOINT_STAGING + neuron_seq2seq_greedy_path, neuron_push_seq2seq_id, token=staging["token"], endpoint=ENDPOINT_STAGING ) - api = HfApi(endpoint=ENDPOINT_STAGING, token=staging.token) + api = HfApi(endpoint=ENDPOINT_STAGING, token=staging["token"]) try: hub_files_path = api.list_repo_files(neuron_push_seq2seq_id) for path, _, files in os.walk(neuron_seq2seq_greedy_path): diff --git a/tests/inference/inference_utils.py b/tests/inference/inference_utils.py index f215a3912..c38bb521e 100644 --- a/tests/inference/inference_utils.py +++ b/tests/inference/inference_utils.py @@ -107,7 +107,7 @@ def setUpClass(cls): cls.neuron_model_id = f"{cls.USER}/{cls.NEURON_MODEL_REPO}" if cls._token: - neuron_model.push_to_hub(model_dir, repository_id=cls.neuron_model_id, use_auth_token=cls._token) + neuron_model.push_to_hub(model_dir, repository_id=cls.neuron_model_id, token=cls._token) @classmethod def tearDownClass(cls): diff --git a/tests/inference/test_stable_diffusion_pipeline.py b/tests/inference/test_stable_diffusion_pipeline.py index 406cdb68c..baca33668 100644 --- a/tests/inference/test_stable_diffusion_pipeline.py +++ b/tests/inference/test_stable_diffusion_pipeline.py @@ -41,6 +41,7 @@ NeuronModelUnet, NeuronModelVaeDecoder, NeuronModelVaeEncoder, + NeuronMultiControlNetModel, ) from optimum.neuron.utils.testing_utils import is_inferentia_test, requires_neuronx from optimum.utils import logging @@ -212,6 +213,21 @@ def test_compatibility_with_compel(self, model_arch): image = pipe(prompt_embeds=prompt_embeds, num_inference_steps=2).images[0] self.assertIsInstance(image, PIL.Image.Image) + @staticmethod + def prepare_canny_image(image_url=None): + if image_url is None: + image_url = "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + original_image = load_image(image_url) + image = np.array(original_image) + low_threshold = 100 + high_threshold = 200 + image = cv2.Canny(image, low_threshold, high_threshold) + image = image[:, :, None] + image = np.concatenate([image, image, image], axis=2) + canny_image = PIL.Image.fromarray(image) + + return canny_image + @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) def test_export_and_inference_with_single_controlnet(self, model_arch): input_shapes = copy.deepcopy(self.STATIC_INPUTS_SHAPES) @@ -231,21 +247,33 @@ def test_export_and_inference_with_single_controlnet(self, model_arch): self.assertIsInstance(neuron_pipeline.controlnet, NeuronControlNetModel) prompt = "the mona lisa" - # prepare canny image - original_image = load_image( - "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" - ) + canny_image = NeuronStableDiffusionPipelineIntegrationTest.prepare_canny_image() + image = neuron_pipeline(prompt, image=canny_image).images[0] + neuron_pipeline.scheduler = UniPCMultistepScheduler.from_config(neuron_pipeline.scheduler.config) + self.assertIsInstance(image, PIL.Image.Image) - image = np.array(original_image) + @parameterized.expand(SUPPORTED_ARCHITECTURES, skip_on_empty=True) + def test_export_and_inference_with_multiple_controlnet(self, model_arch): + input_shapes = copy.deepcopy(self.STATIC_INPUTS_SHAPES) + input_shapes.update({"num_images_per_prompt": 1}) + controlnet_id = "hf-internal-testing/tiny-controlnet" - low_threshold = 100 - high_threshold = 200 + neuron_pipeline = NeuronStableDiffusionControlNetPipeline.from_pretrained( + MODEL_NAMES[model_arch], + controlnet_ids=[controlnet_id, controlnet_id], + export=True, + **input_shapes, + **self.COMPILER_ARGS, + ) + self.assertIsInstance(neuron_pipeline.text_encoder, NeuronModelTextEncoder) + self.assertIsInstance(neuron_pipeline.unet, NeuronModelUnet) + self.assertIsInstance(neuron_pipeline.vae_encoder, NeuronModelVaeEncoder) + self.assertIsInstance(neuron_pipeline.vae_decoder, NeuronModelVaeDecoder) + self.assertIsInstance(neuron_pipeline.controlnet, NeuronMultiControlNetModel) - image = cv2.Canny(image, low_threshold, high_threshold) - image = image[:, :, None] - image = np.concatenate([image, image, image], axis=2) - canny_image = PIL.Image.fromarray(image) - image = neuron_pipeline(prompt, image=canny_image).images[0] + prompt = "the mona lisa" + canny_image = NeuronStableDiffusionPipelineIntegrationTest.prepare_canny_image() + image = neuron_pipeline(prompt, image=[canny_image, canny_image]).images[0] neuron_pipeline.scheduler = UniPCMultistepScheduler.from_config(neuron_pipeline.scheduler.config) self.assertIsInstance(image, PIL.Image.Image)