Skip to content

Commit

Permalink
🐛 Fix IP-Adapter face id plus v2 (#2448)
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei authored Jan 11, 2024
1 parent d6eeff7 commit 31b2b18
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
21 changes: 15 additions & 6 deletions scripts/controlmodel_ipadapter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
import torch
import torch.nn as nn

from scripts.logging import logger

# attention_channels of input, output, middle
SD_V12_CHANNELS = [320] * 4 + [640] * 4 + [1280] * 4 + [1280] * 6 + [640] * 6 + [320] * 6 + [1280] * 2
Expand Down Expand Up @@ -347,13 +347,13 @@ def get_image_embeds(self, clip_vision_output):
return image_prompt_embeds, uncond_image_prompt_embeds

@torch.inference_mode()
def get_image_embeds_faceid_plus(self, face_embed, clip_vision_output):
def get_image_embeds_faceid_plus(self, face_embed, clip_vision_output, is_v2: bool):
face_embed = face_embed['image_embeds'].to(self.device)
from annotator.clipvision import clip_vision_h_uc
clip_embed = clip_vision_output['hidden_states'][-2].to(device='cpu', dtype=torch.float32)
return (
self.image_proj_model(face_embed, clip_embed),
self.image_proj_model(torch.zeros_like(face_embed), clip_vision_h_uc.to(clip_embed)),
self.image_proj_model(face_embed, clip_embed, shortcut=is_v2),
self.image_proj_model(torch.zeros_like(face_embed), clip_vision_h_uc.to(clip_embed), shortcut=is_v2),
)


Expand Down Expand Up @@ -437,8 +437,14 @@ def clear_all_ip_adapter():


class PlugableIPAdapter(torch.nn.Module):
def __init__(self, state_dict):
def __init__(self, state_dict, is_v2: bool = False):
"""
Arguments:
- state_dict: model state_dict.
- is_v2: whether "v2" is in model name.
"""
super().__init__()
self.is_v2 = is_v2
self.is_full = "proj.3.weight" in state_dict['image_proj']
self.is_faceid = "0.to_q_lora.down.weight" in state_dict["ip_adapter"]
self.is_plus = (
Expand All @@ -449,6 +455,8 @@ def __init__(self, state_dict):
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[1]
self.sdxl = cross_attention_dim == 2048
self.sdxl_plus = self.sdxl and self.is_plus
if self.is_faceid and self.is_v2 and self.is_plus:
logger.info("IP-Adapter faceid plus v2 detected.")

if self.is_faceid:
if self.is_plus:
Expand Down Expand Up @@ -504,7 +512,8 @@ def hook(self, model, clip_vision_output, weight, start, end, dtype=torch.float3
# Note: FaceID plus uses both face_embed and clip_embed.
# This should be the return value from preprocessor.
assert isinstance(clip_vision_output, (list, tuple))
self.image_emb, self.uncond_image_emb = self.ipadapter.get_image_embeds_faceid_plus(*clip_vision_output)
assert len(clip_vision_output) == 2
self.image_emb, self.uncond_image_emb = self.ipadapter.get_image_embeds_faceid_plus(*clip_vision_output, is_v2=self.is_v2)
else:
self.image_emb, self.uncond_image_emb = self.ipadapter.get_image_embeds(clip_vision_output)

Expand Down
4 changes: 2 additions & 2 deletions scripts/controlnet_model_guess.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def state_dict_prefix_replace(state_dict, replace_prefix):
return state_dict


def build_model_by_guess(state_dict, unet, model_path):
def build_model_by_guess(state_dict, unet, model_path: str):
if "lora_controlnet" in state_dict:
is_sdxl = "input_blocks.11.0.in_layers.0.weight" not in state_dict
logger.info(f"Using ControlNet lora ({'SDXL' if is_sdxl else 'SD15'})")
Expand Down Expand Up @@ -241,7 +241,7 @@ def build_model_by_guess(state_dict, unet, model_path):
return network

if 'ip_adapter' in state_dict:
network = PlugableIPAdapter(state_dict)
network = PlugableIPAdapter(state_dict, is_v2='v2' in model_path)
network.to('cpu')
return network

Expand Down

0 comments on commit 31b2b18

Please sign in to comment.