From a84637651635004d613863dc91c4892846d8c62b Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Wed, 10 Jan 2024 14:35:03 +0000 Subject: [PATCH] IP-Adapter face id (#2434) * faceid * remove unused file * Update To_KV * faceid_plus * Remove unncessary dep * Change download path * Add install script * nit * nits * Fixes controlmode visibility --- install.py | 73 ++++++++++- requirements.txt | 1 + scripts/controlmodel_ipadapter.py | 202 +++++++++++++++++++++++++----- scripts/global_state.py | 3 + scripts/processor.py | 38 ++++++ 5 files changed, 284 insertions(+), 33 deletions(-) diff --git a/install.py b/install.py index a1cb0362f..4febb9764 100644 --- a/install.py +++ b/install.py @@ -2,6 +2,11 @@ import git # git is part of A1111 dependency. import pkg_resources import os +import sys +import platform +import requests +import tempfile +import subprocess from pathlib import Path from typing import Tuple, Optional @@ -19,8 +24,10 @@ def sync_submodules(): repo.submodule_update() except Exception as e: print(e) - print("Warning: ControlNet failed to sync submodules. Please try run " - "`git submodule init` and `git submodule update` manually.") + print( + "Warning: ControlNet failed to sync submodules. Please try run " + "`git submodule init` and `git submodule update` manually." + ) def comparable_version(version: str) -> Tuple: @@ -35,9 +42,9 @@ def get_installed_version(package: str) -> Optional[str]: def extract_base_package(package_string: str) -> str: - """ trimesh[easy] -> trimesh """ + """trimesh[easy] -> trimesh""" # Split the string on '[' and take the first part - base_package = package_string.split('[')[0] + base_package = package_string.split("[")[0] return base_package @@ -76,7 +83,65 @@ def install_requirements(req_file): ) +def try_install_insight_face(): + """Attempt to install insightface library. The library is necessary to use ip-adapter faceid. + Note: Building insightface library from source requires compiling C++ code, which should be avoided + in principle. Here the solution is to download a precompiled wheel. """ + if get_installed_version("insightface") is not None: + return + + def download_file(url, temp_dir): + """ Download a file from a given URL to a temporary directory """ + local_filename = url.split('/')[-1] + response = requests.get(url, stream=True) + response.raise_for_status() + + filepath = f"{temp_dir}/{local_filename}" + with open(filepath, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + return filepath + + def install_wheel(wheel_path): + """Install the wheel using pip""" + subprocess.run(["pip", "install", wheel_path], check=True) + + wheel_url = "https://github.com/Gourieff/Assets/raw/main/Insightface/insightface-0.7.3-cp310-cp310-win_amd64.whl" + + system = platform.system().lower() + architecture = platform.machine().lower() + python_version = sys.version_info + if ( + system == "windows" + and "amd64" in architecture + and python_version.major == 3 + and python_version.minor == 10 + ): + try: + with tempfile.TemporaryDirectory() as temp_dir: + print( + "Downloading the prebuilt wheel for Windows amd64 to a temporary directory..." + ) + wheel_path = download_file(wheel_url, temp_dir) + print(f"Download complete. File saved to {wheel_path}") + + print("Installing the wheel...") + install_wheel(wheel_path) + print("Installation complete.") + except Exception as e: + print( + "ControlNet init warning: Unable to install insightface automatically. 
" + e + ) + else: + print( + "ControlNet init warning: Unable to install insightface automatically. " + "Please try run `pip install insightface` manually." + ) + + sync_submodules() install_requirements(main_req_file) if os.path.exists(hand_refiner_req_file): install_requirements(hand_refiner_req_file) +try_install_insight_face() diff --git a/requirements.txt b/requirements.txt index f9f438aba..4fc55f8bd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ mediapipe svglib fvcore +onnxruntime opencv-python>=4.8.0 diff --git a/scripts/controlmodel_ipadapter.py b/scripts/controlmodel_ipadapter.py index 5503f23e9..0ca4e4f6f 100644 --- a/scripts/controlmodel_ipadapter.py +++ b/scripts/controlmodel_ipadapter.py @@ -22,6 +22,105 @@ def forward(self, image_embeds): clip_extra_context_tokens = self.proj(image_embeds) return clip_extra_context_tokens + +class MLPProjModelFaceId(torch.nn.Module): + """ MLPProjModel used for FaceId. + Source: https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter_faceid.py + """ + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, id_embeds): + clip_extra_context_tokens = self.proj(id_embeds) + clip_extra_context_tokens = clip_extra_context_tokens.reshape(-1, self.num_tokens, self.cross_attention_dim) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + + + +class FacePerceiverResampler(torch.nn.Module): + """ Source: https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter_faceid.py """ + def __init__( + self, + *, + dim=768, + depth=4, + dim_head=64, + heads=16, + embedding_dim=1280, + output_dim=768, + ff_mult=4, + ): + super().__init__() + + self.proj_in = torch.nn.Linear(embedding_dim, dim) + self.proj_out = torch.nn.Linear(dim, output_dim) + self.norm_out = torch.nn.LayerNorm(output_dim) + self.layers = torch.nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + torch.nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, latents, x): + x = self.proj_in(x) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + latents = self.proj_out(latents) + return self.norm_out(latents) + + +class ProjModelFaceIdPlus(torch.nn.Module): + """ Source: https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter_faceid.py """ + def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.num_tokens = num_tokens + + self.proj = torch.nn.Sequential( + torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), + torch.nn.GELU(), + torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), + ) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + self.perceiver_resampler = FacePerceiverResampler( + dim=cross_attention_dim, + depth=4, + dim_head=64, + heads=cross_attention_dim // 64, + embedding_dim=clip_embeddings_dim, + output_dim=cross_attention_dim, + ff_mult=4, + ) + 
+ def forward(self, id_embeds, clip_embeds, scale=1.0, shortcut=False): + x = self.proj(id_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.norm(x) + out = self.perceiver_resampler(x, clip_embeds) + if shortcut: + out = x + scale * out + return out + + class ImageProjModel(torch.nn.Module): """Projection Model""" @@ -43,17 +142,14 @@ def forward(self, image_embeds): # Cross Attention to_k, to_v for IPAdapter class To_KV(torch.nn.Module): - def __init__(self, cross_attention_dim): + def __init__(self, state_dict): super().__init__() - channels = SD_XL_CHANNELS if cross_attention_dim == 2048 else SD_V12_CHANNELS - self.to_kvs = torch.nn.ModuleList( - [torch.nn.Linear(cross_attention_dim, channel, bias=False) for channel in channels]) - - def load_state_dict(self, state_dict): - # input -> output -> middle - for i, key in enumerate(state_dict.keys()): - self.to_kvs[i].weight.data = state_dict[key] + self.to_kvs = nn.ModuleDict() + for key, value in state_dict.items(): + k = key.replace(".weight", "").replace(".", "_") + self.to_kvs[k] = nn.Linear(value.shape[1], value.shape[0], bias=False) + self.to_kvs[k].weight.data = value def FeedForward(dim, mult=4): @@ -172,24 +268,27 @@ def forward(self, x): class IPAdapterModel(torch.nn.Module): - def __init__(self, state_dict, clip_embeddings_dim, cross_attention_dim, is_plus, sdxl_plus, is_full): + def __init__(self, state_dict, clip_embeddings_dim, cross_attention_dim, + is_plus, sdxl_plus, is_full, is_faceid: bool): super().__init__() self.device = "cpu" + self.clip_embeddings_dim = clip_embeddings_dim self.cross_attention_dim = cross_attention_dim self.is_plus = is_plus self.sdxl_plus = sdxl_plus self.is_full = is_full + self.clip_extra_context_tokens = 16 if self.is_plus else 4 - if self.is_plus: + if is_faceid: + self.image_proj_model = self.init_proj_faceid() + elif self.is_plus: if self.is_full: self.image_proj_model = MLPProjModel( cross_attention_dim=cross_attention_dim, clip_embeddings_dim=clip_embeddings_dim ) else: - self.clip_extra_context_tokens = 16 - self.image_proj_model = Resampler( dim=1280 if sdxl_plus else cross_attention_dim, depth=4, @@ -211,10 +310,25 @@ def __init__(self, state_dict, clip_embeddings_dim, cross_attention_dim, is_plus self.load_ip_adapter(state_dict) + def init_proj_faceid(self): + if self.is_plus: + image_proj_model = ProjModelFaceIdPlus( + cross_attention_dim=self.cross_attention_dim, + id_embeddings_dim=512, + clip_embeddings_dim=self.clip_embeddings_dim, + num_tokens=4, + ) + else: + image_proj_model = MLPProjModelFaceId( + cross_attention_dim=self.cross_attention_dim, + id_embeddings_dim=512, + num_tokens=self.clip_extra_context_tokens, + ) + return image_proj_model + def load_ip_adapter(self, state_dict): self.image_proj_model.load_state_dict(state_dict["image_proj"]) - self.ip_layers = To_KV(self.cross_attention_dim) - self.ip_layers.load_state_dict(state_dict["ip_adapter"]) + self.ip_layers = To_KV(state_dict["ip_adapter"]) @torch.inference_mode() def get_image_embeds(self, clip_vision_output): @@ -232,6 +346,16 @@ def get_image_embeds(self, clip_vision_output): uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) return image_prompt_embeds, uncond_image_prompt_embeds + @torch.inference_mode() + def get_image_embeds_faceid_plus(self, face_embed, clip_vision_output): + face_embed = face_embed['image_embeds'].to(self.device) + from annotator.clipvision import clip_vision_h_uc + clip_embed = 
clip_vision_output['hidden_states'][-2].to(device='cpu', dtype=torch.float32) + return ( + self.image_proj_model(face_embed, clip_embed), + self.image_proj_model(torch.zeros_like(face_embed), clip_vision_h_uc.to(clip_embed)), + ) + def get_block(model, flag): return { @@ -315,13 +439,24 @@ def clear_all_ip_adapter(): class PlugableIPAdapter(torch.nn.Module): def __init__(self, state_dict): super().__init__() - self.is_full = "proj.0.weight" in state_dict['image_proj'] - self.is_plus = self.is_full or "latents" in state_dict["image_proj"] + self.is_full = "proj.3.weight" in state_dict['image_proj'] + self.is_faceid = "0.to_q_lora.down.weight" in state_dict["ip_adapter"] + self.is_plus = ( + self.is_full or + "latents" in state_dict["image_proj"] or + "perceiver_resampler.proj_in.weight" in state_dict["image_proj"] + ) cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[1] self.sdxl = cross_attention_dim == 2048 self.sdxl_plus = self.sdxl and self.is_plus - if self.is_plus: + if self.is_faceid: + if self.is_plus: + clip_embeddings_dim = 1280 + else: + # Plain faceid does not use clip_embeddings_dim. + clip_embeddings_dim = None + elif self.is_plus: if self.sdxl_plus: clip_embeddings_dim = int(state_dict["image_proj"]["latents"].shape[2]) elif self.is_full: @@ -336,7 +471,8 @@ def __init__(self, state_dict): cross_attention_dim=cross_attention_dim, is_plus=self.is_plus, sdxl_plus=self.sdxl_plus, - is_full=self.is_full) + is_full=self.is_full, + is_faceid=self.is_faceid) self.disable_memory_management = True self.dtype = None self.weight = 1.0 @@ -364,7 +500,13 @@ def hook(self, model, clip_vision_output, weight, start, end, dtype=torch.float3 self.dtype = dtype self.ipadapter.to(device, dtype=self.dtype) - self.image_emb, self.uncond_image_emb = self.ipadapter.get_image_embeds(clip_vision_output) + if self.is_faceid and self.is_plus: + # Note: FaceID plus uses both face_embed and clip_embed. + # This should be the return value from preprocessor. 
+ assert isinstance(clip_vision_output, (list, tuple)) + self.image_emb, self.uncond_image_emb = self.ipadapter.get_image_embeds_faceid_plus(*clip_vision_output) + else: + self.image_emb, self.uncond_image_emb = self.ipadapter.get_image_embeds(clip_vision_output) self.image_emb = self.image_emb.to(device, dtype=self.dtype) self.uncond_image_emb = self.uncond_image_emb.to(device, dtype=self.dtype) @@ -397,16 +539,16 @@ def hook(self, model, clip_vision_output, weight, start, end, dtype=torch.float3 return - def call_ip(self, number, feat, device): - if number in self.cache: - return self.cache[number] + def call_ip(self, key: str, feat, device): + if key in self.cache: + return self.cache[key] else: - ip = self.ipadapter.ip_layers.to_kvs[number](feat).to(device) - self.cache[number] = ip + ip = self.ipadapter.ip_layers.to_kvs[key](feat).to(device) + self.cache[key] = ip return ip @torch.no_grad() - def patch_forward(self, number): + def patch_forward(self, number: int): @torch.no_grad() def forward(attn_blk, x, q): batch_size, sequence_length, inner_dim = x.shape @@ -419,8 +561,10 @@ def forward(attn_blk, x, q): cond_mark = current_model.cond_mark[:, :, :, 0].to(self.image_emb) cond_uncond_image_emb = self.image_emb * cond_mark + self.uncond_image_emb * (1 - cond_mark) - ip_k = self.call_ip(number * 2, cond_uncond_image_emb, device=q.device) - ip_v = self.call_ip(number * 2 + 1, cond_uncond_image_emb, device=q.device) + k_key = f"{number * 2 + 1}_to_k_ip" + v_key = f"{number * 2 + 1}_to_v_ip" + ip_k = self.call_ip(k_key, cond_uncond_image_emb, device=q.device) + ip_v = self.call_ip(v_key, cond_uncond_image_emb, device=q.device) ip_k, ip_v = map( lambda t: t.view(batch_size, -1, h, head_dim).transpose(1, 2), @@ -433,7 +577,7 @@ def forward(attn_blk, x, q): if q.dtype != ip_k.dtype: ip_k = ip_k.to(dtype=q.dtype) ip_v = ip_v.to(dtype=q.dtype) - + ip_out = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False) ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, h * head_dim) diff --git a/scripts/global_state.py b/scripts/global_state.py index 9abb02f66..e7330db0e 100644 --- a/scripts/global_state.py +++ b/scripts/global_state.py @@ -73,6 +73,8 @@ def unified_preprocessor(preprocessor_name: str, *args, **kwargs): "ip-adapter_clip_sd15": functools.partial(clip, config='clip_h'), "ip-adapter_clip_sdxl_plus_vith": functools.partial(clip, config='clip_h'), "ip-adapter_clip_sdxl": functools.partial(clip, config='clip_g'), + "ip-adapter_face_id": g_insight_face_model.run_model, + "ip-adapter_face_id_plus": face_id_plus, "color": color, "pidinet": pidinet, "pidinet_safe": pidinet_safe, @@ -117,6 +119,7 @@ def unified_preprocessor(preprocessor_name: str, *args, **kwargs): "revision_ignore_prompt": functools.partial(unload_clip, config='clip_g'), "ip-adapter_clip_sd15": functools.partial(unload_clip, config='clip_h'), "ip-adapter_clip_sdxl_plus_vith": functools.partial(unload_clip, config='clip_h'), + "ip-adapter_face_id_plus": functools.partial(unload_clip, config='clip_h'), "ip-adapter_clip_sdxl": functools.partial(unload_clip, config='clip_g'), "depth": unload_midas, "depth_leres": unload_leres, diff --git a/scripts/processor.py b/scripts/processor.py index 54110ccf8..b6046e4a9 100644 --- a/scripts/processor.py +++ b/scripts/processor.py @@ -1,6 +1,7 @@ import os import cv2 import numpy as np +import torch from annotator.util import HWC3 from typing import Callable, Tuple @@ -646,6 +647,41 @@ def unload_anime_face_segment(): 
model_anime_face_segment.unload_model() +class InsightFaceModel: + def __init__(self): + self.model = None + + def load_model(self): + if self.model is None: + from insightface.app import FaceAnalysis + from annotator.annotator_path import models_path + self.model = FaceAnalysis( + name="buffalo_l", + providers=['CUDAExecutionProvider', 'CPUExecutionProvider'], + root=os.path.join(models_path, "insightface"), + ) + self.model.prepare(ctx_id=0, det_size=(640, 640)) + + def run_model(self, img, **kwargs): + self.load_model() + img = HWC3(img) + faces = self.model.get(img) + faceid_embeds = { + "image_embeds": torch.from_numpy(faces[0].normed_embedding).unsqueeze(0) + } + return faceid_embeds, False + + +g_insight_face_model = InsightFaceModel() + + +def face_id_plus(img, **kwargs): + """ FaceID plus uses both face_embeding from insightface and clip_embeding from clip. """ + face_embed, _ = g_insight_face_model.run_model(img) + clip_embed, _ = clip(img, config='clip_h') + return (face_embed, clip_embed), False + + class HandRefinerModel: def __init__(self): self.model = None @@ -710,6 +746,8 @@ def run_model(self, img, res=512, **kwargs): "ip-adapter_clip_sd15", "ip-adapter_clip_sdxl", "t2ia_style_clipvision" + "ip-adapter_face_id", + "ip-adapter_face_id_plus", ] flag_preprocessor_resolution = "Preprocessor Resolution"
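
For reference, the checkpoint-type detection added to `PlugableIPAdapter.__init__` can be exercised in isolation. The sketch below is a minimal approximation, not part of the extension: it reuses the exact key names from the diff (`proj.3.weight`, `0.to_q_lora.down.weight`, `latents`, `perceiver_resampler.proj_in.weight`, `1.to_k_ip.weight`) against a fabricated state dict, and the helper name `detect_ip_adapter_variant` plus the toy tensors are illustrative only.

```python
import torch


def detect_ip_adapter_variant(state_dict: dict) -> dict:
    """Classify an IP-Adapter checkpoint the same way the patched
    PlugableIPAdapter.__init__ does (key names taken from the diff)."""
    image_proj = state_dict["image_proj"]
    ip_adapter = state_dict["ip_adapter"]

    is_full = "proj.3.weight" in image_proj
    is_faceid = "0.to_q_lora.down.weight" in ip_adapter
    is_plus = (
        is_full
        or "latents" in image_proj
        or "perceiver_resampler.proj_in.weight" in image_proj
    )
    # SD15 checkpoints use a 768-d cross-attention context; SDXL uses 2048.
    cross_attention_dim = ip_adapter["1.to_k_ip.weight"].shape[1]
    return {
        "is_full": is_full,
        "is_faceid": is_faceid,
        "is_plus": is_plus,
        "sdxl": cross_attention_dim == 2048,
    }


# Toy example: a fabricated plain FaceID checkpoint for SD15.
fake_state_dict = {
    "image_proj": {"proj.0.weight": torch.zeros(1024, 512)},
    "ip_adapter": {
        "0.to_q_lora.down.weight": torch.zeros(4, 768),
        "1.to_k_ip.weight": torch.zeros(320, 768),
    },
}
print(detect_ip_adapter_variant(fake_state_dict))
# -> {'is_full': False, 'is_faceid': True, 'is_plus': False, 'sdxl': False}
```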
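
On the preprocessor side, `InsightFaceModel.run_model` wraps `insightface.app.FaceAnalysis`. The standalone sketch below follows roughly the same extraction path, assuming `insightface` and `onnxruntime` are installed and that `face.jpg` (a hypothetical path) contains a detectable face; unlike the extension, it reads the image with `cv2.imread`, so the exact color handling may differ from what the webui pipeline hands to the preprocessor.

```python
import cv2
import torch
from insightface.app import FaceAnalysis

# Build the same detector/recognizer bundle the patch uses ("buffalo_l"),
# preferring CUDA and falling back to CPU via the onnxruntime providers.
app = FaceAnalysis(
    name="buffalo_l",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
app.prepare(ctx_id=0, det_size=(640, 640))

img = cv2.imread("face.jpg")  # hypothetical input; cv2 loads an HWC uint8 array
faces = app.get(img)
if not faces:
    raise RuntimeError("No face detected; ip-adapter_face_id needs at least one face.")

# run_model in the patch keeps only the first detected face and wraps the
# 512-d normalized identity embedding the way IPAdapterModel expects it.
faceid_embeds = {
    "image_embeds": torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
}
print(faceid_embeds["image_embeds"].shape)  # torch.Size([1, 512])
```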
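
The plain FaceID path then projects that 512-d identity embedding into cross-attention tokens with the new `MLPProjModelFaceId`. The following self-contained sketch re-declares the module as it appears in the diff and pushes a dummy embedding through it to make the shape contract explicit (SD15 defaults: 4 tokens of width 768); the dummy tensors are placeholders.

```python
import torch


class MLPProjModelFaceId(torch.nn.Module):
    """Projection used for plain FaceID (as added in controlmodel_ipadapter.py)."""

    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
        )
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        tokens = self.proj(id_embeds)  # (B, num_tokens * cross_attention_dim)
        tokens = tokens.reshape(-1, self.num_tokens, self.cross_attention_dim)
        return self.norm(tokens)       # (B, num_tokens, cross_attention_dim)


# A (1, 512) stand-in for insightface's `faces[0].normed_embedding`.
face_embed = torch.randn(1, 512)
proj = MLPProjModelFaceId()
print(proj(face_embed).shape)                    # torch.Size([1, 4, 768])
# The unconditional embedding is the projection of a zero vector,
# mirroring IPAdapterModel.get_image_embeds in the patch.
print(proj(torch.zeros_like(face_embed)).shape)  # torch.Size([1, 4, 768])
```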
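
Finally, `To_KV` no longer builds a fixed per-channel list of layers; it derives one bias-free `Linear` per entry of the checkpoint's `ip_adapter` dict, keyed by the weight name with `.weight` stripped and dots replaced by underscores, which is why `patch_forward` now looks up `f"{number * 2 + 1}_to_k_ip"` instead of indexing by position. A small sketch of that key scheme, using a fabricated two-entry state dict:

```python
import torch
import torch.nn as nn

# Two fabricated cross-attention projection weights, named the way an
# IP-Adapter checkpoint names them: "<index>.to_k_ip.weight" / "<index>.to_v_ip.weight".
ip_adapter_sd = {
    "1.to_k_ip.weight": torch.randn(320, 768),
    "1.to_v_ip.weight": torch.randn(320, 768),
}

# Same construction as the patched To_KV.__init__: one bias-free Linear per entry,
# stored in an nn.ModuleDict under a sanitized key.
to_kvs = nn.ModuleDict()
for key, value in ip_adapter_sd.items():
    k = key.replace(".weight", "").replace(".", "_")
    to_kvs[k] = nn.Linear(value.shape[1], value.shape[0], bias=False)
    to_kvs[k].weight.data = value

print(list(to_kvs.keys()))           # ['1_to_k_ip', '1_to_v_ip']

# patch_forward builds the same keys from the attention-block index:
number = 0
k_key = f"{number * 2 + 1}_to_k_ip"  # -> '1_to_k_ip'
ip_k = to_kvs[k_key](torch.randn(2, 4, 768))
print(ip_k.shape)                    # torch.Size([2, 4, 320])
```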