Commit 0aa316f: faceid

huchenlei committed Jan 8, 2024
1 parent bb9483d commit 0aa316f
Showing 5 changed files with 351 additions and 5 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,4 +1,5 @@
mediapipe
svglib
fvcore
scikit-image
opencv-python>=4.8.0
143 changes: 138 additions & 5 deletions scripts/controlmodel_ipadapter.py
@@ -22,6 +22,105 @@ def forward(self, image_embeds):
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens


class MLPProjModelFaceId(torch.nn.Module):
""" MLPProjModel used for FaceId.
Source: https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter_faceid.py
"""
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
super().__init__()

self.cross_attention_dim = cross_attention_dim
self.num_tokens = num_tokens

self.proj = torch.nn.Sequential(
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
torch.nn.GELU(),
torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
)
self.norm = torch.nn.LayerNorm(cross_attention_dim)

def forward(self, id_embeds):
clip_extra_context_tokens = self.proj(id_embeds)
clip_extra_context_tokens = clip_extra_context_tokens.reshape(-1, self.num_tokens, self.cross_attention_dim)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens



class FacePerceiverResampler(torch.nn.Module):
""" Source: https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter_faceid.py """
def __init__(
self,
*,
dim=768,
depth=4,
dim_head=64,
heads=16,
embedding_dim=1280,
output_dim=768,
ff_mult=4,
):
super().__init__()

self.proj_in = torch.nn.Linear(embedding_dim, dim)
self.proj_out = torch.nn.Linear(dim, output_dim)
self.norm_out = torch.nn.LayerNorm(output_dim)
self.layers = torch.nn.ModuleList([])
for _ in range(depth):
self.layers.append(
torch.nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)

def forward(self, latents, x):
x = self.proj_in(x)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)


class ProjModelFaceIdPlus(torch.nn.Module):
""" Source: https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter_faceid.py """
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
super().__init__()

self.cross_attention_dim = cross_attention_dim
self.num_tokens = num_tokens

self.proj = torch.nn.Sequential(
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
torch.nn.GELU(),
torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
)
self.norm = torch.nn.LayerNorm(cross_attention_dim)

self.perceiver_resampler = FacePerceiverResampler(
dim=cross_attention_dim,
depth=4,
dim_head=64,
heads=cross_attention_dim // 64,
embedding_dim=clip_embeddings_dim,
output_dim=cross_attention_dim,
ff_mult=4,
)

def forward(self, id_embeds, clip_embeds, scale=1.0, shortcut=False):
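# id_embeds: face identity embedding; clip_embeds: CLIP image features.
# With shortcut=True, the projected id tokens are added back to the
# resampler output, scaled by `scale`.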
x = self.proj(id_embeds)
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
x = self.norm(x)
out = self.perceiver_resampler(x, clip_embeds)
if shortcut:
out = x + scale * out
return out


class ImageProjModel(torch.nn.Module):
"""Projection Model"""

@@ -172,7 +271,8 @@ def forward(self, x):


class IPAdapterModel(torch.nn.Module):
def __init__(self, state_dict, clip_embeddings_dim, cross_attention_dim, is_plus, sdxl_plus, is_full):
def __init__(self, state_dict, clip_embeddings_dim, cross_attention_dim,
is_plus, sdxl_plus, is_full, is_faceid: bool):
super().__init__()
self.device = "cpu"

@@ -181,7 +281,9 @@ def __init__(self, state_dict, clip_embeddings_dim, cross_attention_dim, is_plus
self.sdxl_plus = sdxl_plus
self.is_full = is_full

if self.is_plus:
if is_faceid:
self.image_proj_model = self.init_proj_faceid()
elif self.is_plus:
if self.is_full:
self.image_proj_model = MLPProjModel(
cross_attention_dim=cross_attention_dim,
@@ -211,6 +313,22 @@ def __init__(self, state_dict, clip_embeddings_dim, cross_attention_dim, is_plus

self.load_ip_adapter(state_dict)

def init_proj_faceid(self):
if self.is_plus:
image_proj_model = ProjModelFaceIdPlus(
cross_attention_dim=self.cross_attention_dim,
id_embeddings_dim=512,
clip_embeddings_dim=1280,
num_tokens=4,
)
else:
image_proj_model = MLPProjModelFaceId(
cross_attention_dim=self.cross_attention_dim,
id_embeddings_dim=512,
num_tokens=self.clip_extra_context_tokens,
)
return image_proj_model

def load_ip_adapter(self, state_dict):
self.image_proj_model.load_state_dict(state_dict["image_proj"])
self.ip_layers = To_KV(self.cross_attention_dim)
@@ -232,6 +350,13 @@ def get_image_embeds(self, clip_vision_output):
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds

@torch.inference_mode()
def get_image_embeds_faceid_plus(self, face_embed, clip_embed):
return (
self.image_proj_model(face_embed, clip_embed),
self.image_proj_model(torch.zeros_like(face_embed), torch.zeros_like(clip_embed)),
)


def get_block(model, flag):
return {
@@ -316,6 +441,7 @@ class PlugableIPAdapter(torch.nn.Module):
def __init__(self, state_dict):
super().__init__()
self.is_full = "proj.0.weight" in state_dict['image_proj']
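# FaceID checkpoints ship LoRA weights (e.g. "0.to_q_lora.down.weight") inside
# the "ip_adapter" dict, which distinguishes them from other IP-Adapter variants.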
self.is_faceid = "0.to_q_lora.down.weight" in state_dict["ip_adapter"]
self.is_plus = self.is_full or "latents" in state_dict["image_proj"]
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[1]
self.sdxl = cross_attention_dim == 2048
@@ -336,7 +462,8 @@ def __init__(self, state_dict):
cross_attention_dim=cross_attention_dim,
is_plus=self.is_plus,
sdxl_plus=self.sdxl_plus,
is_full=self.is_full)
is_full=self.is_full,
is_faceid=self.is_faceid)
self.disable_memory_management = True
self.dtype = None
self.weight = 1.0
@@ -364,7 +491,13 @@ def hook(self, model, clip_vision_output, weight, start, end, dtype=torch.float3
self.dtype = dtype

self.ipadapter.to(device, dtype=self.dtype)
self.image_emb, self.uncond_image_emb = self.ipadapter.get_image_embeds(clip_vision_output)
if self.is_faceid and self.is_plus:
# Note: FaceID Plus uses both face_embed and clip_embed; for this path,
# clip_vision_output is expected to be the (face_embed, clip_embed) pair
# returned by the preprocessor.
assert isinstance(clip_vision_output, (list, tuple))
self.image_emb, self.uncond_image_emb = self.ipadapter.get_image_embeds_faceid_plus(*clip_vision_output)
else:
self.image_emb, self.uncond_image_emb = self.ipadapter.get_image_embeds(clip_vision_output)

self.image_emb = self.image_emb.to(device, dtype=self.dtype)
self.uncond_image_emb = self.uncond_image_emb.to(device, dtype=self.dtype)
@@ -433,7 +566,7 @@ def forward(attn_blk, x, q):
if q.dtype != ip_k.dtype:
ip_k = ip_k.to(dtype=q.dtype)
ip_v = ip_v.to(dtype=q.dtype)

ip_out = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v, attn_mask=None, dropout_p=0.0, is_causal=False)
ip_out = ip_out.transpose(1, 2).reshape(batch_size, -1, h * head_dim)

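For reference, a minimal illustrative shape sketch of the FaceID projection path added above. It assumes MLPProjModelFaceId is importable from scripts/controlmodel_ipadapter.py and uses the defaults shown in the diff (512-d face embedding, 4 tokens, 768-d cross-attention):

import torch
from scripts.controlmodel_ipadapter import MLPProjModelFaceId  # assumed import path

# Project a single 512-d face identity embedding into 4 cross-attention tokens.
proj = MLPProjModelFaceId(cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4)
face_embed = torch.randn(1, 512)
tokens = proj(face_embed)
assert tokens.shape == (1, 4, 768)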
187 changes: 187 additions & 0 deletions scripts/controlmodel_ipadapter_face_id.py
@@ -0,0 +1,187 @@
import os
from typing import List

import torch
from safetensors import safe_open
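
# Note: LoRAAttnProcessor and LoRAIPAttnProcessor are referenced below but are
# not defined or imported in this file; they are assumed to come from the
# IP-Adapter project's attention processor module.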


class MLPProjModel(torch.nn.Module):
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
super().__init__()

self.cross_attention_dim = cross_attention_dim
self.num_tokens = num_tokens

self.proj = torch.nn.Sequential(
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim * 2),
torch.nn.GELU(),
torch.nn.Linear(id_embeddings_dim * 2, cross_attention_dim * num_tokens),
)
self.norm = torch.nn.LayerNorm(cross_attention_dim)

def forward(self, id_embeds):
x = self.proj(id_embeds)
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
x = self.norm(x)
return x


class IPAdapterFaceID:
def __init__(self, sd_pipe, ip_ckpt, lora_rank=128, num_tokens=4):
self.device = "cpu"
self.ip_ckpt = ip_ckpt
self.lora_rank = lora_rank
self.num_tokens = num_tokens

self.pipe = sd_pipe.to(self.device)
self.set_ip_adapter()

# image proj model
self.image_proj_model = self.init_proj()

self.load_ip_adapter()

def init_proj(self):
image_proj_model = MLPProjModel(
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
id_embeddings_dim=512,
num_tokens=self.num_tokens,
).to(self.device, dtype=torch.float16)
return image_proj_model

def set_ip_adapter(self):
unet = self.pipe.unet
attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = (
None
if name.endswith("attn1.processor")
else unet.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=self.lora_rank,
).to(self.device, dtype=torch.float16)
else:
attn_procs[name] = LoRAIPAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
rank=self.lora_rank,
num_tokens=self.num_tokens,
).to(self.device, dtype=torch.float16)
unet.set_attn_processor(attn_procs)

def load_ip_adapter(self):
if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][
key.replace("image_proj.", "")
] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][
key.replace("ip_adapter.", "")
] = f.get_tensor(key)
else:
state_dict = torch.load(self.ip_ckpt, map_location="cpu")
self.image_proj_model.load_state_dict(state_dict["image_proj"])
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
ip_layers.load_state_dict(state_dict["ip_adapter"])

@torch.inference_mode()
def get_image_embeds(self, faceid_embeds):
faceid_embeds = faceid_embeds.to(self.device, dtype=torch.float16)
image_prompt_embeds = self.image_proj_model(faceid_embeds)
uncond_image_prompt_embeds = self.image_proj_model(
torch.zeros_like(faceid_embeds)
)
return image_prompt_embeds, uncond_image_prompt_embeds

def set_scale(self, scale):
for attn_processor in self.pipe.unet.attn_processors.values():
if isinstance(attn_processor, LoRAIPAttnProcessor):
attn_processor.scale = scale

def generate(
self,
faceid_embeds=None,
prompt=None,
negative_prompt=None,
scale=1.0,
num_samples=4,
seed=None,
guidance_scale=7.5,
num_inference_steps=30,
**kwargs,
):
self.set_scale(scale)

num_prompts = faceid_embeds.size(0)

if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = (
"monochrome, lowres, bad anatomy, worst quality, low quality"
)

if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts

image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
faceid_embeds
)

bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(
bs_embed * num_samples, seq_len, -1
)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(
1, num_samples, 1
)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(
bs_embed * num_samples, seq_len, -1
)

with torch.inference_mode():
prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
prompt,
device=self.device,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat(
[negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1
)

generator = (
torch.Generator(self.device).manual_seed(seed) if seed is not None else None
)
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
**kwargs,
).images

return images
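
To make the prompt assembly in generate() concrete, a small standalone shape sketch; the text length (77 tokens) and embedding width (768, SD1.5) are assumed example values:

import torch

num_samples = 4
prompt_embeds_ = torch.randn(num_samples, 77, 768)   # text embeddings from encode_prompt
image_prompt_embeds = torch.randn(1, 4, 768)         # 4 FaceID tokens from the proj model
# Tile the face tokens per sample, then append them to the text tokens.
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1).view(num_samples, 4, -1)
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
assert prompt_embeds.shape == (num_samples, 81, 768)  # 77 text + 4 face tokens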