7 changes: 6 additions & 1 deletion invokeai/backend/model_manager/configs/lora.py
@@ -150,11 +150,16 @@ def _validate_base(cls, mod: ModelOnDisk) -> None:

     @classmethod
     def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
-        # First rule out ControlLoRA and Diffusers LoRA
+        # First rule out ControlLoRA
         flux_format = _get_flux_lora_format(mod)
         if flux_format in [FluxLoRAFormat.Control]:
             raise NotAMatchError("model looks like Control LoRA")

+        # If it's a recognized Flux LoRA format (Kohya, Diffusers, OneTrainer, AIToolkit, XLabs, etc.),
+        # it's valid and we skip the heuristic check
+        if flux_format is not None:
+            return
+
         # Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA.
         # Some main models have these keys, likely due to the creator merging in a LoRA.
         has_key_with_lora_prefix = state_dict_has_any_keys_starting_with(
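The net effect of this change: any state dict that _get_flux_lora_format() recognizes (other than Control) is now accepted immediately, and only unrecognized dicts reach the key-prefix heuristic. A minimal sketch of the resulting control flow (simplified stand-in names; the real method raises NotAMatchError and calls state_dict_has_any_keys_starting_with()):

# Sketch only: simplified stand-in for the validation order, not the real signature.
def looks_like_lora(flux_format: FluxLoRAFormat | None) -> None:
    if flux_format == FluxLoRAFormat.Control:
        raise NotAMatchError("model looks like Control LoRA")
    if flux_format is not None:
        return  # recognized FLUX LoRA format (e.g. XLabs): accept without the heuristic
    ...  # otherwise fall through to the key prefix/suffix heuristic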
6 changes: 6 additions & 0 deletions invokeai/backend/model_manager/load/model_loaders/lora.py
@@ -41,6 +41,10 @@
     is_state_dict_likely_in_flux_onetrainer_format,
     lora_model_from_flux_onetrainer_state_dict,
 )
+from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
+    is_state_dict_likely_in_flux_xlabs_format,
+    lora_model_from_flux_xlabs_state_dict,
+)
 from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
 from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format

@@ -117,6 +121,8 @@ def _load_model(
                 model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
             elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict):
                 model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
+            elif is_state_dict_likely_in_flux_xlabs_format(state_dict=state_dict):
+                model = lora_model_from_flux_xlabs_state_dict(state_dict=state_dict)
             else:
                 raise ValueError("LoRA model is in unsupported FLUX format")
         else:
1 change: 1 addition & 0 deletions invokeai/backend/model_manager/taxonomy.py
@@ -167,6 +167,7 @@ class FluxLoRAFormat(str, Enum):
     OneTrainer = "flux.onetrainer"
     Control = "flux.control"
     AIToolkit = "flux.aitoolkit"
+    XLabs = "flux.xlabs"


 AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, FluxVariantType]
92 changes: 92 additions & 0 deletions invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py
@@ -0,0 +1,92 @@
import re
from typing import Any, Dict

import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw

# A regex pattern that matches all of the transformer keys in the xlabs FLUX LoRA format.
# Example keys:
# double_blocks.0.processor.qkv_lora1.down.weight
# double_blocks.0.processor.qkv_lora1.up.weight
# double_blocks.0.processor.proj_lora1.down.weight
# double_blocks.0.processor.proj_lora1.up.weight
# double_blocks.0.processor.qkv_lora2.down.weight
# double_blocks.0.processor.proj_lora2.up.weight
FLUX_XLABS_KEY_REGEX = r"double_blocks\.(\d+)\.processor\.(qkv|proj)_lora([12])\.(down|up)\.weight"
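# For example, re.match() on "double_blocks.0.processor.qkv_lora1.down.weight"
# captures the groups ("0", "qkv", "1", "down").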


def is_state_dict_likely_in_flux_xlabs_format(state_dict: dict[str | int, Any]) -> bool:
    """Checks if the provided state dict is likely in the xlabs FLUX LoRA format.

    The xlabs format is characterized by keys matching the pattern:
    double_blocks.{block_idx}.processor.{qkv|proj}_lora{1|2}.{down|up}.weight

    Where:
    - lora1 corresponds to the image attention stream (img_attn)
    - lora2 corresponds to the text attention stream (txt_attn)
    """
    if not state_dict:
        return False

    # Check that all string keys match the xlabs pattern.
    for key in state_dict.keys():
        if not isinstance(key, str):
            continue
        if not re.match(FLUX_XLABS_KEY_REGEX, key):
            return False

    # Ensure that at least one valid xlabs key is present.
    return any(isinstance(k, str) and re.match(FLUX_XLABS_KEY_REGEX, k) for k in state_dict.keys())
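# e.g. {"double_blocks.0.processor.qkv_lora1.down.weight": ...} -> True, while a
# Kohya-style key such as "lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight" -> False.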


def lora_model_from_flux_xlabs_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
    """Converts an xlabs FLUX LoRA state dict to the InvokeAI ModelPatchRaw format.

    The xlabs format uses:
    - lora1 for image attention stream (img_attn)
    - lora2 for text attention stream (txt_attn)
    - qkv for query/key/value projection
    - proj for output projection

    Key mapping:
    - double_blocks.X.processor.qkv_lora1 -> double_blocks.X.img_attn.qkv
    - double_blocks.X.processor.proj_lora1 -> double_blocks.X.img_attn.proj
    - double_blocks.X.processor.qkv_lora2 -> double_blocks.X.txt_attn.qkv
    - double_blocks.X.processor.proj_lora2 -> double_blocks.X.txt_attn.proj
    """
    # Group keys by layer (without the .down.weight/.up.weight suffix).
    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}

    for key, value in state_dict.items():
        match = re.match(FLUX_XLABS_KEY_REGEX, key)
        if not match:
            raise ValueError(f"Key '{key}' does not match the expected pattern for xlabs FLUX LoRA weights.")

        block_idx = match.group(1)
        component = match.group(2)  # qkv or proj
        lora_stream = match.group(3)  # 1 or 2
        direction = match.group(4)  # down or up

        # Map lora1 -> img_attn, lora2 -> txt_attn.
        attn_type = "img_attn" if lora_stream == "1" else "txt_attn"

        # Create the InvokeAI-style layer key.
        layer_key = f"double_blocks.{block_idx}.{attn_type}.{component}"

        if layer_key not in grouped_state_dict:
            grouped_state_dict[layer_key] = {}

        # Map down/up to lora_down/lora_up.
        param_name = f"lora_{direction}.weight"
        grouped_state_dict[layer_key][param_name] = value

    # Create a LoRA layer from each grouped sub-dict.
    layers: dict[str, BaseLayerPatch] = {}
    for layer_key, layer_state_dict in grouped_state_dict.items():
        layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)

    return ModelPatchRaw(layers=layers)
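For reference, a minimal usage sketch of the two new functions. The rank-16 shapes mirror the sample state dict in the test fixture below, and "lora_transformer-" is the value of FLUX_LORA_TRANSFORMER_PREFIX per the test comments; zeros stand in for real weights:

import torch

from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
    is_state_dict_likely_in_flux_xlabs_format,
    lora_model_from_flux_xlabs_state_dict,
)

# A single-layer xlabs-style state dict.
sd = {
    "double_blocks.0.processor.qkv_lora1.down.weight": torch.zeros(16, 3072),
    "double_blocks.0.processor.qkv_lora1.up.weight": torch.zeros(9216, 16),
}

assert is_state_dict_likely_in_flux_xlabs_format(sd)
patch = lora_model_from_flux_xlabs_state_dict(sd)
print(list(patch.layers.keys()))  # ["lora_transformer-double_blocks.0.img_attn.qkv"]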
5 changes: 5 additions & 0 deletions invokeai/backend/patches/lora_conversions/formats.py
@@ -14,6 +14,9 @@
 from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
     is_state_dict_likely_in_flux_onetrainer_format,
 )
+from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
+    is_state_dict_likely_in_flux_xlabs_format,
+)


 def flux_format_from_state_dict(
@@ -30,5 +33,7 @@ def flux_format_from_state_dict(
         return FluxLoRAFormat.Control
     elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict, metadata):
         return FluxLoRAFormat.AIToolkit
+    elif is_state_dict_likely_in_flux_xlabs_format(state_dict):
+        return FluxLoRAFormat.XLabs
     else:
         return None
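A quick check that format detection now routes xlabs dicts correctly (a sketch; `sd` as in the earlier usage example, and the second argument is assumed to be the optional metadata dict used by the aitoolkit check above):

from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict

assert flux_format_from_state_dict(sd, None) is FluxLoRAFormat.XLabs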
32 changes: 32 additions & 0 deletions tests/backend/patches/lora_conversions/lora_state_dicts/flux_lora_xlabs_format.py
@@ -0,0 +1,32 @@
# A sample state dict in the xlabs FLUX LoRA format.
# The xlabs format uses:
# - lora1 for image attention stream (img_attn)
# - lora2 for text attention stream (txt_attn)
# - qkv for query/key/value projection
# - proj for output projection
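# Shape note: 3072 is the FLUX hidden dim and 16 is the LoRA rank; the fused qkv
# up-projection has 3 * 3072 = 9216 output rows (q, k, and v stacked).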
state_dict_keys = {
    "double_blocks.0.processor.proj_lora1.down.weight": [16, 3072],
    "double_blocks.0.processor.proj_lora1.up.weight": [3072, 16],
    "double_blocks.0.processor.proj_lora2.down.weight": [16, 3072],
    "double_blocks.0.processor.proj_lora2.up.weight": [3072, 16],
    "double_blocks.0.processor.qkv_lora1.down.weight": [16, 3072],
    "double_blocks.0.processor.qkv_lora1.up.weight": [9216, 16],
    "double_blocks.0.processor.qkv_lora2.down.weight": [16, 3072],
    "double_blocks.0.processor.qkv_lora2.up.weight": [9216, 16],
    "double_blocks.1.processor.proj_lora1.down.weight": [16, 3072],
    "double_blocks.1.processor.proj_lora1.up.weight": [3072, 16],
    "double_blocks.1.processor.proj_lora2.down.weight": [16, 3072],
    "double_blocks.1.processor.proj_lora2.up.weight": [3072, 16],
    "double_blocks.1.processor.qkv_lora1.down.weight": [16, 3072],
    "double_blocks.1.processor.qkv_lora1.up.weight": [9216, 16],
    "double_blocks.1.processor.qkv_lora2.down.weight": [16, 3072],
    "double_blocks.1.processor.qkv_lora2.up.weight": [9216, 16],
    "double_blocks.10.processor.proj_lora1.down.weight": [16, 3072],
    "double_blocks.10.processor.proj_lora1.up.weight": [3072, 16],
    "double_blocks.10.processor.proj_lora2.down.weight": [16, 3072],
    "double_blocks.10.processor.proj_lora2.up.weight": [3072, 16],
    "double_blocks.10.processor.qkv_lora1.down.weight": [16, 3072],
    "double_blocks.10.processor.qkv_lora1.up.weight": [9216, 16],
    "double_blocks.10.processor.qkv_lora2.down.weight": [16, 3072],
    "double_blocks.10.processor.qkv_lora2.up.weight": [9216, 16],
}
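The tests below expand these key/shape pairs into tensors via keys_to_mock_state_dict(); a minimal equivalent (assumed behavior, with zeros standing in for weights):

import torch
state_dict = {key: torch.zeros(shape) for key, shape in state_dict_keys.items()}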
99 changes: 99 additions & 0 deletions tests/backend/patches/lora_conversions/test_flux_xlabs_lora_conversion_utils.py
@@ -0,0 +1,99 @@
import accelerate
import pytest
import torch

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import get_flux_transformers_params
from invokeai.backend.model_manager.taxonomy import FluxVariantType
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils import (
    is_state_dict_likely_in_flux_xlabs_format,
    lora_model_from_flux_xlabs_state_dict,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import (
    state_dict_keys as flux_diffusers_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_kohya_format import (
    state_dict_keys as flux_kohya_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_xlabs_format import (
    state_dict_keys as flux_xlabs_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.utils import keys_to_mock_state_dict


def test_is_state_dict_likely_in_flux_xlabs_format_true():
    """Test that is_state_dict_likely_in_flux_xlabs_format() can identify a state dict in the xlabs FLUX LoRA format."""
    state_dict = keys_to_mock_state_dict(flux_xlabs_state_dict_keys)
    assert is_state_dict_likely_in_flux_xlabs_format(state_dict)


@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys, flux_kohya_state_dict_keys])
def test_is_state_dict_likely_in_flux_xlabs_format_false(sd_keys: dict[str, list[int]]):
    """Test that is_state_dict_likely_in_flux_xlabs_format() returns False for state dicts in other formats."""
    state_dict = keys_to_mock_state_dict(sd_keys)
    assert not is_state_dict_likely_in_flux_xlabs_format(state_dict)


def test_lora_model_from_flux_xlabs_state_dict():
    """Test that a ModelPatchRaw can be created from a state dict in the xlabs FLUX LoRA format."""
    state_dict = keys_to_mock_state_dict(flux_xlabs_state_dict_keys)

    lora_model = lora_model_from_flux_xlabs_state_dict(state_dict)

    # Verify the expected layer keys are created.
    expected_layer_keys = {
        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.0.img_attn.proj",
        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.0.img_attn.qkv",
        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.0.txt_attn.proj",
        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.0.txt_attn.qkv",
        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.1.img_attn.proj",
        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.1.img_attn.qkv",
        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.1.txt_attn.proj",
        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.1.txt_attn.qkv",
        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.10.img_attn.proj",
        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.10.img_attn.qkv",
        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.10.txt_attn.proj",
        f"{FLUX_LORA_TRANSFORMER_PREFIX}double_blocks.10.txt_attn.qkv",
    }

    assert set(lora_model.layers.keys()) == expected_layer_keys


def test_lora_model_from_flux_xlabs_state_dict_matches_model_keys():
    """Test that the converted xlabs LoRA keys match the actual FLUX model keys."""
    state_dict = keys_to_mock_state_dict(flux_xlabs_state_dict_keys)

    lora_model = lora_model_from_flux_xlabs_state_dict(state_dict)

    # Extract the layer prefixes (without the lora_transformer- prefix).
    converted_key_prefixes: list[str] = []
    for k in lora_model.layers.keys():
        # Remove the transformer prefix.
        k = k.replace(FLUX_LORA_TRANSFORMER_PREFIX, "")
        converted_key_prefixes.append(k)

    # Initialize a FLUX model on the meta device.
    with accelerate.init_empty_weights():
        model = Flux(get_flux_transformers_params(FluxVariantType.Schnell))
    model_keys = set(model.state_dict().keys())

    # Assert that the converted keys match prefixes in the actual model.
    for converted_key_prefix in converted_key_prefixes:
        found_match = False
        for model_key in model_keys:
            if model_key.startswith(converted_key_prefix):
                found_match = True
                break
        if not found_match:
            raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}")


def test_lora_model_from_flux_xlabs_state_dict_error():
    """Test that an error is raised if the input state_dict contains unexpected keys."""
    state_dict = {
        "unexpected_key.down.weight": torch.empty(1),
    }

    with pytest.raises(ValueError):
        lora_model_from_flux_xlabs_state_dict(state_dict)