Skip to content

Commit 203942c

Browse files
Fix flux doras with diffusers keys.
1 parent 3c72c89 commit 203942c

File tree

1 file changed: +6 −6 lines changed

comfy/lora.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -343,10 +343,10 @@ def model_lora_keys_unet(model, key_map={}):
343343
return key_map
344344

345345

346-
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
346+
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
347347
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
348348
lora_diff *= alpha
349-
weight_calc = weight + lora_diff.type(weight.dtype)
349+
weight_calc = weight + function(lora_diff).type(weight.dtype)
350350
weight_norm = (
351351
weight_calc.transpose(0, 1)
352352
.reshape(weight_calc.shape[1], -1)
@@ -453,7 +453,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
453453
try:
454454
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
455455
if dora_scale is not None:
456-
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
456+
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
457457
else:
458458
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
459459
except Exception as e:
@@ -499,7 +499,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
499499
try:
500500
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
501501
if dora_scale is not None:
502-
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
502+
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
503503
else:
504504
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
505505
except Exception as e:
@@ -536,7 +536,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
536536
try:
537537
lora_diff = (m1 * m2).reshape(weight.shape)
538538
if dora_scale is not None:
539-
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
539+
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
540540
else:
541541
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
542542
except Exception as e:
@@ -577,7 +577,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
577577
lora_diff += torch.mm(b1, b2).reshape(weight.shape)
578578

579579
if dora_scale is not None:
580-
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
580+
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
581581
else:
582582
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
583583
except Exception as e:

0 commit comments

Comments (0)