@@ -343,10 +343,10 @@ def model_lora_keys_unet(model, key_map={}):
     return key_map


-def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
+def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
     dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
     lora_diff *= alpha
-    weight_calc = weight + lora_diff.type(weight.dtype)
+    weight_calc = weight + function(lora_diff).type(weight.dtype)
     weight_norm = (
         weight_calc.transpose(0, 1)
         .reshape(weight_calc.shape[1], -1)
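For context on what this hunk changes: below is a minimal runnable sketch of the decomposition as it reads after the patch. The hunk ends at the `.reshape(...)` line, so everything past that point (the norm and the rescale against `dora_scale`) is an assumption about the rest of the function, and the device/dtype casting is elided; `weight_decompose_sketch` and the identity `function` are illustrative, not the project's API.

```python
import torch

def weight_decompose_sketch(dora_scale, weight, lora_diff, alpha, strength, function):
    # As in the diff: the per-patch transform is applied to the LoRA delta
    # *before* the merged weight's norm is taken.
    lora_diff = lora_diff * alpha
    weight_calc = weight + function(lora_diff).type(weight.dtype)
    # Per-column L2 norm, reshaped so it broadcasts against weight_calc
    # (assumed continuation of the function; not shown in the hunk).
    weight_norm = (
        weight_calc.transpose(0, 1)
        .reshape(weight_calc.shape[1], -1)
        .norm(dim=1, keepdim=True)
        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
        .transpose(0, 1)
    )
    # DoRA-style rescale of each column to the learned magnitude,
    # blended back into the original weight by strength.
    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
    return weight + strength * (weight_calc - weight)

# Identity stands in for whatever transform the caller supplies.
w = torch.randn(4, 3)
merged = weight_decompose_sketch(torch.ones(1, 3), w, torch.randn(4, 3) * 0.01,
                                 alpha=1.0, strength=1.0, function=lambda t: t)
```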
@@ -453,7 +453,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             try:
                 lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
@@ -499,7 +499,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             try:
                 lora_diff = torch.kron(w1, w2).reshape(weight.shape)
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
@@ -536,7 +536,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             try:
                 lora_diff = (m1 * m2).reshape(weight.shape)
                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
@@ -577,7 +577,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
                 lora_diff += torch.mm(b1, b2).reshape(weight.shape)

                 if dora_scale is not None:
-                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                    weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
                 else:
                     weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
             except Exception as e:
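Net effect of the call-site edit repeated across the four branches above: `function` no longer wraps the fully decomposed result but is threaded into `weight_decompose`, which applies it to `lora_diff` alone, so the DoRA norm is computed over the transformed delta. A hedged sketch of the before/after at one call site; the cast shown for `function` is only an example of a caller-supplied transform, not the value the project actually passes:

```python
# function = lambda t: t.to(torch.float32)  # example transform; caller-supplied in practice

# before: transform applied after decomposition, outside the norm
# weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))

# after: transform threaded through, applied to lora_diff inside the norm
# weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
```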