Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache IPAdapter instances to avoid expensive KV extraction on every generation #335

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from modules.ui_components import InputAccordion
from modules.api.api import decode_base64_to_image
import gradio as gr
import time

from lib_controlnet import global_state, external_code
from lib_controlnet.external_code import ControlNetUnit
Expand Down Expand Up @@ -494,11 +495,17 @@ def process_unit_before_every_sampling(self,
params.model.positive_advanced_weighting = soft_weighting.copy()
params.model.negative_advanced_weighting = soft_weighting.copy()

model_process_start_time = time.perf_counter()
cond, mask = params.preprocessor.process_before_every_sampling(p, cond, mask, *args, **kwargs)
model_process_end_time = time.perf_counter() - model_process_start_time
logger.debug(f"CN Preprocessor {params.preprocessor.name}: {model_process_end_time:.2f}s.")

params.model.advanced_mask_weighting = mask

model_process_start_time = time.perf_counter()
params.model.process_before_every_sampling(p, cond, mask, *args, **kwargs)
model_process_end_time = time.perf_counter() - model_process_start_time
logger.debug(f"CN Model {type(params.model).__name__}: {model_process_end_time:.2f}s.")

logger.info(f"ControlNet Method {params.preprocessor.name} patched.")
return
Expand Down Expand Up @@ -581,6 +588,8 @@ def on_ui_settings():
{"minimum": 1, "maximum": 10, "step": 1}, section=section))
shared.opts.add_option("control_net_model_cache_size", shared.OptionInfo(
5, "Model cache size (requires restart)", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}, section=section))
shared.opts.add_option("control_net_ipadapter_cache_size", shared.OptionInfo(
5, "IPAdapter cache size (requires restart)", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}, section=section))
shared.opts.add_option("control_net_no_detectmap", shared.OptionInfo(
False, "Do not append detectmap to output", gr.Checkbox, {"interactive": True}, section=section))
shared.opts.add_option("control_net_detectmap_autosaving", shared.OptionInfo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import contextlib
import os
import math
import time
from cachetools import LRUCache

import ldm_patched.modules.utils
import ldm_patched.modules.model_management
Expand All @@ -16,7 +18,11 @@
import torch.nn.functional as F
import torchvision.transforms as TT


from lib_ipadapter.resampler import PerceiverAttention, FeedForward, Resampler
from modules import shared

from lib_controlnet.logging import logger

# set the models directory backward compatible
GLOBAL_MODELS_DIR = os.path.join(folder_paths.models_dir, "ipadapter")
Expand Down Expand Up @@ -259,6 +265,29 @@ def NPToTensor(image):
return out

class IPAdapter(nn.Module):

_cache = LRUCache(maxsize=shared.opts.data.get("control_net_ipadapter_cache_size", 5))

# Factory method that caches off of the model filename
@classmethod
def create(cls, model_filename, ipadapter_model, cross_attention_dim=1024, output_cross_attention_dim=1024,
clip_embeddings_dim=1024, clip_extra_context_tokens=4,
is_sdxl=False, is_plus=False, is_full=False,
is_faceid=False, is_instant_id=False):
if model_filename in cls._cache:
logger.info(f"IPAdapter: Using cached layers for {model_filename}.")
return cls._cache[model_filename]
else:
logger.info(f"IPAdapter: Creating new layer instance for {model_filename}.")
instance = cls(ipadapter_model, cross_attention_dim, output_cross_attention_dim,
clip_embeddings_dim, clip_extra_context_tokens,
is_sdxl, is_plus, is_full, is_faceid, is_instant_id)

if ldm_patched.modules.model_management.enable_ipadapter_layer_cache():
cls._cache[model_filename] = instance

return instance

def __init__(self, ipadapter_model, cross_attention_dim=1024, output_cross_attention_dim=1024,
clip_embeddings_dim=1024, clip_extra_context_tokens=4,
is_sdxl=False, is_plus=False, is_full=False,
Expand Down Expand Up @@ -612,9 +641,10 @@ def INPUT_TYPES(s):
FUNCTION = "apply_ipadapter"
CATEGORY = "ipadapter"

def apply_ipadapter(self, ipadapter, model, weight, clip_vision=None, image=None, weight_type="original",
def apply_ipadapter(self, ipadapter, model_filename, model, weight, clip_vision=None, image=None, weight_type="original",
noise=None, embeds=None, attn_mask=None, start_at=0.0, end_at=1.0, unfold_batch=False,
insightface=None, faceid_v2=False, weight_v2=False, instant_id=False):
apply_ipadapter_start = time.perf_counter()

self.dtype = torch.float16 if ldm_patched.modules.model_management.should_use_fp16() else torch.float32
self.device = ldm_patched.modules.model_management.get_torch_device()
Expand Down Expand Up @@ -720,7 +750,8 @@ def apply_ipadapter(self, ipadapter, model, weight, clip_vision=None, image=None

clip_embeddings_dim = clip_embed.shape[-1]

self.ipadapter = IPAdapter(
self.ipadapter = IPAdapter.create(
model_filename,
ipadapter,
cross_attention_dim=cross_attention_dim,
output_cross_attention_dim=output_cross_attention_dim,
Expand Down Expand Up @@ -799,6 +830,8 @@ def modifier(cnet, x_noisy, t, cond, batched_number):
set_model_patch_replace(work_model, patch_kwargs, ("middle", 0, index))
patch_kwargs["number"] += 1

apply_ipadapter_time = time.perf_counter() - apply_ipadapter_start
logger.debug(f"IPAdapter apply_ipadapter time: {apply_ipadapter_time:.2f}s")
return (work_model, )

class IPAdapterApplyFaceID(IPAdapterApply):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,20 @@ def try_build_from_state_dict(state_dict, ckpt_path):
if "ip_adapter" not in model.keys() or len(model["ip_adapter"]) == 0:
return None

o = IPAdapterPatcher(model)

model_filename = Path(ckpt_path).name.lower()

o = IPAdapterPatcher(model, model_filename)

if 'v2' in model_filename:
o.faceid_v2 = True
o.weight_v2 = True

return o

def __init__(self, state_dict):
def __init__(self, state_dict, model_filename):
super().__init__()
self.ip_adapter = state_dict
self.model_filename = model_filename
self.faceid_v2 = False
self.weight_v2 = False
return
Expand All @@ -146,6 +148,7 @@ def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):

unet = opIPAdapterApply(
ipadapter=self.ip_adapter,
model_filename=self.model_filename,
model=unet,
weight=self.strength,
start_at=self.start_percent,
Expand Down
3 changes: 3 additions & 0 deletions ldm_patched/modules/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,9 @@ def free_memory(memory_required, device, keep_loaded=[]):
if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache()

def enable_ipadapter_layer_cache():
return vram_state == VRAMState.HIGH_VRAM

def load_models_gpu(models, memory_required=0):
global vram_state

Expand Down
1 change: 1 addition & 0 deletions requirements_versions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@ transformers==4.30.2
httpx==0.24.1
basicsr==1.4.2
diffusers==0.25.0
cachetools==5.3.2
Loading