From a3864b2ee4028bf55605fec92dbf0500dc4998fa Mon Sep 17 00:00:00 2001 From: ljleb Date: Wed, 18 Oct 2023 17:49:22 -0400 Subject: [PATCH] update presets --- lib_free_u/global_state.py | 10 +++++----- lib_free_u/unet.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/lib_free_u/global_state.py b/lib_free_u/global_state.py index 6bef786..bac0134 100644 --- a/lib_free_u/global_state.py +++ b/lib_free_u/global_state.py @@ -143,22 +143,22 @@ def apply_xyz(): default_presets = { "SD1.4 Recommendations": State( stage_infos=[ - StageInfo(1.2, 0.9), + StageInfo(1.3, 0.9), StageInfo(1.4, 0.2), StageInfo(1, 1), ], ), "SD2.1 Recommendations": State( stage_infos=[ - StageInfo(1.1, 0.9), - StageInfo(1.2, 0.2), + StageInfo(1.4, 0.9), + StageInfo(1.6, 0.2), StageInfo(1, 1), ], ), "SDXL Recommendations": State( stage_infos=[ - StageInfo(1.1, 0.6), - StageInfo(1.2, 0.4), + StageInfo(1.3, 0.9), + StageInfo(1.4, 0.2), StageInfo(1, 1), ], ), diff --git a/lib_free_u/unet.py b/lib_free_u/unet.py index dd5589d..6494429 100644 --- a/lib_free_u/unet.py +++ b/lib_free_u/unet.py @@ -62,7 +62,7 @@ def free_u_cat_hijack(hs, *args, original_function, **kwargs): h[:, mask] *= get_backbone_scale( h, - backbone_factor=lerp(1, stage_info.backbone_factor, schedule_ratio), + base_scale=lerp(1, stage_info.backbone_factor, schedule_ratio), ) h_skip = filter_skip( h_skip, @@ -74,9 +74,9 @@ def free_u_cat_hijack(hs, *args, original_function, **kwargs): return original_function([h, h_skip], *args, **kwargs) -def get_backbone_scale(h, backbone_factor): +def get_backbone_scale(h, base_scale): if global_state.instance.version == "1": - return backbone_factor + return base_scale #if global_state.instance.version == "2": features_mean = h.mean(1, keepdim=True) @@ -84,7 +84,7 @@ def get_backbone_scale(h, backbone_factor): features_max, _ = torch.max(features_mean.view(batch_dims, -1), dim=-1, keepdim=True) features_min, _ = torch.min(features_mean.view(batch_dims, -1), dim=-1, keepdim=True) hidden_mean = (features_mean - features_min.unsqueeze(2).unsqueeze(3)) / (features_max - features_min).unsqueeze(2).unsqueeze(3) - return 1 + (backbone_factor - 1) * hidden_mean + return 1 + (base_scale - 1) * hidden_mean def filter_skip(x, threshold, scale, scale_high):