[LoRA] Implement hot-swapping of LoRA #9453


Merged
merged 55 commits into main from lora-hot-swapping on Apr 8, 2025
Changes from 1 commit
d3fbd7b
[WIP][LoRA] Implement hot-swapping of LoRA
BenjaminBossan Sep 17, 2024
84bae62
Reviewer feedback
BenjaminBossan Sep 18, 2024
63ece9d
Reviewer feedback, adjust test
BenjaminBossan Oct 16, 2024
94c669c
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Oct 16, 2024
c7378ed
Fix, doc
BenjaminBossan Oct 16, 2024
7c67b38
Make fix
BenjaminBossan Oct 16, 2024
ea12e0d
Fix for possible g++ error
BenjaminBossan Oct 16, 2024
ec4b0d5
Add test for recompilation w/o hotswapping
BenjaminBossan Oct 18, 2024
e07323a
Merge branch 'main' into lora-hot-swapping
sayakpaul Oct 18, 2024
529a523
Merge branch 'main' into lora-hot-swapping
sayakpaul Oct 22, 2024
ac1346d
Merge branch 'main' into lora-hot-swapping
sayakpaul Oct 25, 2024
58b35ba
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 6, 2025
d21a988
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 6, 2025
488f2f0
Make hotswap work
BenjaminBossan Feb 7, 2025
ece3d0f
Merge branch 'main' into lora-hot-swapping
sayakpaul Feb 8, 2025
5ab1460
Address reviewer feedback:
BenjaminBossan Feb 10, 2025
bc157e6
Change order of test decorators
BenjaminBossan Feb 10, 2025
bd1da66
Split model and pipeline tests
BenjaminBossan Feb 11, 2025
119a8ed
Reviewer feedback: Move decorator to test classes
BenjaminBossan Feb 12, 2025
53c2f84
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 12, 2025
a715559
Apply suggestions from code review
BenjaminBossan Feb 13, 2025
e40390d
Reviewer feedback: version check, TODO comment
BenjaminBossan Feb 13, 2025
1b834ec
Add enable_lora_hotswap method
BenjaminBossan Feb 14, 2025
4b01401
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 14, 2025
2cd3665
Reviewer feedback: check _lora_loadable_modules
BenjaminBossan Feb 17, 2025
efbd820
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 18, 2025
e735ac2
Revert changes in unet.py
BenjaminBossan Feb 18, 2025
69b637d
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 21, 2025
3a6677c
Add possibility to ignore enabled at wrong time
BenjaminBossan Feb 21, 2025
a96f3fd
Fix docstrings
BenjaminBossan Feb 21, 2025
deab0eb
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Feb 27, 2025
2c6b435
Log possible PEFT error, test
BenjaminBossan Feb 27, 2025
ccb45f7
Raise helpful error if hotswap not supported
BenjaminBossan Feb 27, 2025
09e2ec7
Formatting
BenjaminBossan Feb 27, 2025
67ab6bf
More linter
BenjaminBossan Feb 27, 2025
f03fe6b
More ruff
BenjaminBossan Feb 27, 2025
2d407ca
Doc-builder complaint
BenjaminBossan Feb 27, 2025
6b59ecf
Update docstring:
BenjaminBossan Mar 3, 2025
f14146f
Merge branch 'main' into lora-hot-swapping
yiyixuxu Mar 3, 2025
a79876d
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Mar 5, 2025
c3c1bdf
Fix error in docstring
BenjaminBossan Mar 5, 2025
387ddf6
Update more methods with hotswap argument
BenjaminBossan Mar 7, 2025
7f72d0b
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Mar 7, 2025
dec4d10
Add hotswap argument to load_lora_into_transformer
BenjaminBossan Mar 11, 2025
204f521
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Mar 11, 2025
716f446
Extend docstrings
BenjaminBossan Mar 12, 2025
4d82111
Add version guards to tests
BenjaminBossan Mar 12, 2025
425cb39
Formatting
BenjaminBossan Mar 12, 2025
115c77d
Fix LoRA loading call to add prefix=None
BenjaminBossan Mar 12, 2025
5d90753
Run make fix-copies
BenjaminBossan Mar 12, 2025
62c1c13
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Mar 12, 2025
d6d23b8
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Mar 17, 2025
366632d
Add hot swap documentation to the docs
BenjaminBossan Mar 17, 2025
b181a47
Apply suggestions from code review
BenjaminBossan Mar 18, 2025
f2a6146
Merge branch 'main' into lora-hot-swapping
BenjaminBossan Apr 8, 2025
17 changes: 14 additions & 3 deletions src/diffusers/loaders/lora_pipeline.py
@@ -63,7 +63,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
text_encoder_name = TEXT_ENCODER_NAME

def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name=None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
@@ -88,6 +92,7 @@ def load_lora_weights(
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
hotswap TODO
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -109,6 +114,7 @@ def load_lora_weights(
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
adapter_name=adapter_name,
_pipeline=self,
hotswap=hotswap,
)
self.load_lora_into_text_encoder(
state_dict,
@@ -232,7 +238,7 @@ def lora_state_dict(
return state_dict, network_alphas

@classmethod
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, hotswap: bool = False):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.

@@ -250,6 +256,7 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
hotswap TODO
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -263,7 +270,11 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None
# Load the layers corresponding to UNet.
logger.info(f"Loading {cls.unet_name}.")
unet.load_attn_procs(
state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
state_dict,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
hotswap=hotswap,
)

@classmethod
114 changes: 109 additions & 5 deletions src/diffusers/loaders/unet.py
@@ -66,7 +66,7 @@ class UNet2DConditionLoadersMixin:
unet_name = UNET_NAME

@validate_hf_hub_args
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], hotswap: bool = False, **kwargs):
r"""
Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
defined in
@@ -115,6 +115,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
`default_{i}` where i is the total number of adapters being loaded.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
hotswap TODO

Example:

@@ -209,6 +210,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
hotswap=hotswap,
)
else:
raise ValueError(
@@ -268,7 +270,7 @@ def _process_custom_diffusion(self, state_dict):

return attn_processors

def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, hotswap: bool = False):
# This method does the following things:
# 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
# format. For legacy format no filtering is applied.
@@ -299,10 +301,12 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter
state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict

if len(state_dict_to_be_used) > 0:
if adapter_name in getattr(self, "peft_config", {}):
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
raise ValueError(
f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
)
elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
raise ValueError(f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name.")
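The two new branches form a small truth table: reusing an existing adapter name is only valid when hot-swapping, and hot-swapping requires the name to already exist. A self-contained sketch of that rule (the function name and the plain-dict `peft_config` are illustrative, not part of the diff):

```python
def validate_hotswap_request(peft_config: dict, adapter_name: str, hotswap: bool) -> None:
    # Fresh load: the adapter name must be new. Hot swap: it must already exist.
    exists = adapter_name in peft_config
    if exists and not hotswap:
        raise ValueError(
            f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
        )
    if not exists and hotswap:
        raise ValueError(
            f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name."
        )
```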

state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used)

@@ -336,8 +340,108 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)

inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)

def _check_hotswap_configs_compatible(config0, config1):
# To hot-swap two adapters, their configs must be compatible. Otherwise, the results could be incorrect. E.g. if they
# use different alpha values, after hot-swapping, the alphas from the first adapter would still be used with the
# weights from the 2nd adapter, which would result in incorrect behavior. There is probably a way to swap these
# values as well, but that's not implemented yet, and it would trigger a re-compilation if the model is compiled.

# TODO: This is a very rough check at the moment and there are probably better ways than to error out
config_keys_to_check = ["lora_alpha", "use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]
config0 = config0.to_dict()
config1 = config1.to_dict()
for key in config_keys_to_check:
val0 = config0[key]
val1 = config1[key]
if val0 != val1:
raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}")
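A runnable sketch of this check, with a minimal stand-in for peft's `LoraConfig` carrying only the compared fields (the dataclass is illustrative, not the real peft API):

```python
from dataclasses import asdict, dataclass, field


@dataclass
class MiniLoraConfig:
    # Stand-in for peft.LoraConfig with just the fields the check reads.
    lora_alpha: int = 8
    use_rslora: bool = False
    lora_dropout: float = 0.0
    alpha_pattern: dict = field(default_factory=dict)
    use_dora: bool = False

    def to_dict(self):
        return asdict(self)


def check_hotswap_configs_compatible(config0, config1):
    config_keys_to_check = ["lora_alpha", "use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]
    config0, config1 = config0.to_dict(), config1.to_dict()
    for key in config_keys_to_check:
        if config0[key] != config1[key]:
            raise ValueError(f"Configs are incompatible: for {key}, {config0[key]} != {config1[key]}")
```

Note that the rank `r` is absent from the checked keys; a rank mismatch would presumably surface later, when the new weights fail to copy into tensors of the old shape.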

def _hotswap_adapter_from_state_dict(model, state_dict, adapter_name):
Review comment (Member): I guess this will go away once huggingface/peft#2120 is merged?
"""
Swap out the LoRA weights from the model with the weights from state_dict.

It is assumed that the existing adapter and the new adapter are compatible.

Args:
model: nn.Module
The model with the loaded adapter.
state_dict: dict[str, torch.Tensor]
The state dict of the new adapter, which needs to be compatible (targeting same modules etc.).
adapter_name: Optional[str]
The name of the adapter that should be hot-swapped.

Raises:
RuntimeError
If the old and the new adapter are not compatible, a RuntimeError is raised.
"""
from operator import attrgetter

#######################
# INSERT ADAPTER NAME #
#######################

remapped_state_dict = {}
expected_str = adapter_name + "."
for key, val in state_dict.items():
if expected_str not in key:
prefix, _, suffix = key.rpartition(".")
key = f"{prefix}.{adapter_name}.{suffix}"
remapped_state_dict[key] = val
state_dict = remapped_state_dict
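Because `rpartition` splits at the last dot, the adapter name is spliced in just before the final component (`weight`/`bias`), matching peft's per-adapter parameter layout. The remapping as a standalone function:

```python
def insert_adapter_name(state_dict: dict, adapter_name: str) -> dict:
    # e.g. "blocks.0.lora_A.weight" -> "blocks.0.lora_A.default_0.weight";
    # keys that already contain the adapter name pass through unchanged.
    remapped_state_dict = {}
    expected_str = adapter_name + "."
    for key, val in state_dict.items():
        if expected_str not in key:
            prefix, _, suffix = key.rpartition(".")
            key = f"{prefix}.{adapter_name}.{suffix}"
        remapped_state_dict[key] = val
    return remapped_state_dict
```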

####################
# CHECK STATE_DICT #
####################

# Ensure that all the keys of the new adapter correspond exactly to the keys of the old adapter, otherwise
# hot-swapping is not possible
parameter_prefix = "lora_" # hard-coded for now
is_compiled = hasattr(model, "_orig_mod")
# TODO: there is probably a more precise way to identify the adapter keys
missing_keys = {k for k in model.state_dict() if (parameter_prefix in k) and (adapter_name in k)}
unexpected_keys = set()

# first: dry run, not swapping anything
for key, new_val in state_dict.items():
try:
old_val = attrgetter(key)(model)
except AttributeError:
unexpected_keys.add(key)
continue

if is_compiled:
missing_keys.remove("_orig_mod." + key)
else:
missing_keys.remove(key)

if missing_keys or unexpected_keys:
msg = "Hot swapping the adapter did not succeed."
if missing_keys:
msg += f" Missing keys: {', '.join(sorted(missing_keys))}."
if unexpected_keys:
msg += f" Unexpected keys: {', '.join(sorted(unexpected_keys))}."
raise RuntimeError(msg)
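The dry run leans on `attrgetter` resolving a dotted state-dict key to a nested attribute, so an unknown key surfaces as `AttributeError` and is recorded before anything is mutated. In miniature, with `SimpleNamespace` standing in for an `nn.Module` tree:

```python
from operator import attrgetter
from types import SimpleNamespace

# A toy "model" with one adapter parameter at a dotted path.
model = SimpleNamespace(down=SimpleNamespace(lora_A=SimpleNamespace(weight=1.0)))

# A valid dotted key resolves to the leaf attribute.
assert attrgetter("down.lora_A.weight")(model) == 1.0

# A key with no matching attribute raises, which the dry run collects as
# "unexpected" instead of failing halfway through the actual swap.
unexpected = set()
try:
    attrgetter("down.lora_B.weight")(model)
except AttributeError:
    unexpected.add("down.lora_B.weight")
```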

###################
# ACTUAL SWAPPING #
###################

for key, new_val in state_dict.items():
# no need to account for potential _orig_mod in key here, as torch handles that
old_val = attrgetter(key)(model)
old_val.data = new_val.data.to(device=old_val.device)
Review comment (Collaborator) to @BenjaminBossan, suggested change:
- old_val.data = new_val.data.to(device=old_val.device)
+ old_val.data.copy_(new_val.data.to(device=old_val.device))
# TODO: wanted to use swap_tensors but this somehow does not work on nn.Parameter
# torch.utils.swap_tensors(old_val.data, new_val.data)

if hotswap:
_check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
_hotswap_adapter_from_state_dict(self, state_dict, adapter_name)
# the hotswap function raises if there are incompatible keys, so if we reach this point we can set it to None
incompatible_keys = None
else:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)

if incompatible_keys is not None:
# check only for unexpected keys
38 changes: 38 additions & 0 deletions tests/pipelines/test_pipelines.py
@@ -18,6 +18,7 @@
import os
import random
import shutil
import subprocess
import sys
import tempfile
import traceback
@@ -2014,3 +2015,40 @@ def test_ddpm_ddim_equality_batched(self):

# the values aren't exactly equal, but the images look the same visually
assert np.abs(ddpm_images - ddim_images).max() < 1e-1


class TestLoraHotSwapping:
def test_hotswapping_peft_config_incompatible_raises(self):
# TODO
pass

def test_hotswapping_no_existing_adapter_raises(self):
# TODO
pass

def test_hotswapping_works(self):
# TODO
pass

def test_hotswapping_compiled_model_does_not_trigger_recompilation(self):
# TODO: kinda slow, should it get a slow marker?
env = {**os.environ, "TORCH_LOGS": "guards,recompiles"}
here = os.path.dirname(__file__)
file_name = os.path.join(here, "run_compiled_model_hotswap.py")

process = subprocess.Popen(
[sys.executable, file_name],
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE
)

# Communicate will read the output and error streams, preventing deadlock
stdout, stderr = process.communicate()
exit_code = process.returncode

# sanity check:
assert exit_code == 0

# check that the recompilation message is not present
assert "__recompiles" not in stderr.decode()
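One detail that matters for this subprocess-based test: passing `env=` to `Popen` replaces the child's entire environment rather than extending it, so the `TORCH_LOGS` override is safest merged on top of `os.environ`. A minimal, self-contained demonstration of the merge pattern (independent of diffusers):

```python
import os
import subprocess
import sys

# Merge the override onto the parent environment; passing only the override
# would strip PATH, HOME, etc. from the child process.
env = {**os.environ, "TORCH_LOGS": "guards,recompiles"}
result = subprocess.run(
    [sys.executable, "-c", "import os; print(os.environ['TORCH_LOGS'])"],
    env=env,
    capture_output=True,
    text=True,
)
```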