From 04a0a5c46cfe9eef9718ca7043463fd0e819e0c0 Mon Sep 17 00:00:00 2001 From: Tran Xen <137925069+glucauze@users.noreply.github.com> Date: Thu, 17 Aug 2023 10:41:04 +0200 Subject: [PATCH] fix bug in improved mask --- .../upscaled_inswapper.py | 36 ++++++++++++------- scripts/faceswaplab_utils/imgutils.py | 3 +- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/scripts/faceswaplab_swapping/upscaled_inswapper.py b/scripts/faceswaplab_swapping/upscaled_inswapper.py index 30f4907..e381208 100644 --- a/scripts/faceswaplab_swapping/upscaled_inswapper.py +++ b/scripts/faceswaplab_swapping/upscaled_inswapper.py @@ -15,7 +15,7 @@ from scripts.faceswaplab_utils.imgutils import cv2_to_pil, pil_to_cv2 from scripts.faceswaplab_utils.sd_utils import get_sd_option from scripts.faceswaplab_utils.typing import CV2ImgU8, Face -from scripts.faceswaplab_utils.faceswaplab_logging import logger +from scripts.faceswaplab_utils.faceswaplab_logging import logger, save_img_debug def get_upscaler() -> Optional[UpscalerData]: @@ -216,18 +216,11 @@ def compute_diff(bgr_fake: CV2ImgU8, aimg: CV2ImgU8) -> CV2ImgU8: bgr_fake, inswapper_options=options, k=k ) - if options.improved_mask: - if k == 1: - logger.warning( - "Please note that improved mask does not work well without upscaling. Set upscaling to Lanczos at least if you want speed and want to use improved mask." - ) - - logger.info("improved_mask") - mask = get_face_mask(aimg, bgr_fake) - bgr_fake = merge_images_with_mask(aimg, bgr_fake, mask) + fake_diff: CV2ImgU8 = None # type: ignore - # compute fake_diff before sharpen and color correction (better result) - fake_diff = compute_diff(bgr_fake, aimg) + if not options.improved_mask: + # If improved mask is not used, we should compute before sharpen and color correction (better diff) + fake_diff = compute_diff(bgr_fake, aimg=aimg) if options.sharpen: logger.info("sharpen") @@ -244,6 +237,24 @@ def compute_diff(bgr_fake: CV2ImgU8, aimg: CV2ImgU8) -> CV2ImgU8: ) bgr_fake = pil_to_cv2(bgr_fake_pil) + if options.improved_mask: + if k == 1: + logger.warning( + "Please note that improved mask does not work well without upscaling. Set upscaling to Lanczos at least if you want speed and want to use improved mask." + ) + + logger.info("improved_mask") + mask = get_face_mask(aimg, bgr_fake) + # save_img_debug(cv2_to_pil(bgr_fake), "Before Mask") + bgr_fake = merge_images_with_mask(aimg, bgr_fake, mask) + # save_img_debug(cv2_to_pil(bgr_fake), "After Mask") + + fake_diff = compute_diff(bgr_fake, aimg=aimg) + + assert ( + fake_diff is not None + ), "fake diff is None, this should not happen" + logger.info("*" * 80) else: @@ -266,6 +277,7 @@ def compute_diff(bgr_fake: CV2ImgU8, aimg: CV2ImgU8) -> CV2ImgU8: (target_img.shape[1], target_img.shape[0]), borderValue=0.0, ) + fake_diff = cv2.warpAffine( fake_diff, IM, diff --git a/scripts/faceswaplab_utils/imgutils.py b/scripts/faceswaplab_utils/imgutils.py index 2306253..0112ece 100644 --- a/scripts/faceswaplab_utils/imgutils.py +++ b/scripts/faceswaplab_utils/imgutils.py @@ -5,7 +5,6 @@ import numpy as np from math import isqrt, ceil import torch -from ifnude import detect from modules import processing import base64 from collections import Counter @@ -31,6 +30,8 @@ def check_against_nsfw(img: PILImage) -> bool: if NSFW_SCORE_THRESHOLD >= 1: return False + from ifnude import detect + shapes: List[bool] = [] chunks: List[Dict[str, Union[int, float]]] = detect(img)