Skip to content

Commit

Permalink
fix preview in build
Browse files Browse the repository at this point in the history
  • Loading branch information
glucauze committed Aug 16, 2023
1 parent afcfc7d commit 0499581
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 42 deletions.
3 changes: 2 additions & 1 deletion scripts/faceswaplab_api/faceswaplab_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from scripts.faceswaplab_swapping.face_checkpoints import (
build_face_checkpoint_and_save,
)
from scripts.faceswaplab_utils.typing import PILImage


def encode_to_base64(image: Union[str, Image.Image, np.ndarray]) -> str: # type: ignore
Expand Down Expand Up @@ -99,7 +100,7 @@ async def swap_face(
pp_options = None
units = get_faceswap_units_settings(request.units)

swapped_images = swapper.batch_process(
swapped_images: Optional[List[PILImage]] = swapper.batch_process(
[src_image], None, units=units, postprocess_options=pp_options
)

Expand Down
76 changes: 47 additions & 29 deletions scripts/faceswaplab_swapping/face_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,11 @@ def sanitize_name(name: str) -> str:


def build_face_checkpoint_and_save(
images: List[PILImage], name: str, overwrite: bool = False, path: str = None
) -> PILImage:
images: List[PILImage],
name: str,
overwrite: bool = False,
path: Optional[str] = None,
) -> Optional[PILImage]:
"""
Builds a face checkpoint using the provided image files, performs face swapping,
and saves the result to a file. If a blended face is successfully obtained and the face swapping
Expand All @@ -57,8 +60,12 @@ def build_face_checkpoint_and_save(
name = sanitize_name(name)
images = images or []
logger.info("Build %s with %s images", name, len(images))
faces = swapper.get_faces_from_img_files(images)
blended_face = swapper.blend_faces(faces)
faces: List[Face] = swapper.get_faces_from_img_files(images=images)
if faces is None or len(faces) == 0:
logger.error("No source faces found")
return None

blended_face: Optional[Face] = swapper.blend_faces(faces)
preview_path = os.path.join(
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
)
Expand All @@ -85,40 +92,51 @@ def build_face_checkpoint_and_save(
"Failed to open reference image, cannot create preview : That should not happen unless you deleted the references folder or change the detection threshold."
)
else:
result = swapper.swap_face(
result: swapper.ImageResult = swapper.swap_face(
target_faces=[target_face],
source_face=blended_face,
target_img=reference_preview_img,
model=get_swap_models()[0],
swapping_options=InswappperOptions(face_restorer_name="Codeformer"),
swapping_options=InswappperOptions(
face_restorer_name="CodeFormer",
restorer_visibility=1,
upscaler_name="Lanczos",
codeformer_weight=1,
improved_mask=True,
color_corrections=False,
sharpen=True,
),
)
preview_image = result.image

if path:
file_path = path
else:
file_path = os.path.join(get_checkpoint_path(), f"{name}.safetensors")
if not overwrite:
file_number = 1
while os.path.exists(file_path):
file_path = os.path.join(
get_checkpoint_path(), f"{name}_{file_number}.safetensors"
)
file_number += 1
save_face(filename=file_path, face=blended_face)
preview_image.save(file_path + ".png")
try:
data = load_face(file_path)
logger.debug(data)
except Exception as e:
logger.error("Error loading checkpoint, after creation %s", e)
traceback.print_exc()

return preview_image
if path:
file_path = path
else:
file_path = os.path.join(
get_checkpoint_path(), f"{name}.safetensors"
)
if not overwrite:
file_number = 1
while os.path.exists(file_path):
file_path = os.path.join(
get_checkpoint_path(),
f"{name}_{file_number}.safetensors",
)
file_number += 1
save_face(filename=file_path, face=blended_face)
preview_image.save(file_path + ".png")
try:
data = load_face(file_path)
logger.debug(data)
except Exception as e:
logger.error("Error loading checkpoint, after creation %s", e)
traceback.print_exc()

return preview_image

else:
logger.error("No face found")
return None
return None # type: ignore
except Exception as e:
logger.error("Failed to build checkpoint %s", e)
traceback.print_exc()
Expand All @@ -139,7 +157,7 @@ def save_face(face: Face, filename: str) -> None:
raise e


def load_face(name: str) -> Face:
def load_face(name: str) -> Optional[Face]:
if name.startswith("data:application/face;base64,"):
with tempfile.NamedTemporaryFile(delete=True) as temp_file:
api_utils.base64_to_safetensors(name, temp_file.name)
Expand Down
16 changes: 8 additions & 8 deletions scripts/faceswaplab_swapping/swapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def batch_process(
src_images: List[Union[PILImage, str]], # image or filename
save_path: Optional[str],
units: List[FaceSwapUnitSettings],
postprocess_options: PostProcessingOptions,
postprocess_options: Optional[PostProcessingOptions],
) -> Optional[List[PILImage]]:
"""
Process a batch of images, apply face swapping according to the given settings, and optionally save the resulting images to a specified path.
Expand Down Expand Up @@ -527,7 +527,7 @@ def get_or_default(l: List[Any], index: int, default: Any) -> Any:
return l[index] if index < len(l) else default


def get_faces_from_img_files(images: List[PILImage]) -> List[Optional[CV2ImgU8]]:
def get_faces_from_img_files(images: List[PILImage]) -> List[Face]:
"""
Extracts faces from a list of image files.
Expand All @@ -539,7 +539,7 @@ def get_faces_from_img_files(images: List[PILImage]) -> List[Optional[CV2ImgU8]]
"""

faces = []
faces: List[Face] = []

if len(images) > 0:
for img in images:
Expand Down Expand Up @@ -598,7 +598,6 @@ def swap_face(
target_faces: List[Face],
model: str,
swapping_options: Optional[InswappperOptions],
compute_similarity: bool = True,
) -> ImageResult:
"""
Swaps faces in the target image with the source face.
Expand Down Expand Up @@ -680,9 +679,9 @@ def process_image_unit(
model: str,
unit: FaceSwapUnitSettings,
image: PILImage,
info: str = None,
info: Optional[str] = None,
force_blend: bool = False,
) -> List[Tuple[PILImage, str]]:
) -> List[Tuple[PILImage, Optional[str]]]:
"""Process one image and return a List of (image, info) (one if blended, many if not).
Args:
Expand Down Expand Up @@ -723,7 +722,9 @@ def process_image_unit(
sort_by_face_size=unit.sort_by_size,
)

target_faces = filter_faces(faces, filtering_options=face_filtering_options)
target_faces: List[Face] = filter_faces(
all_faces=faces, filtering_options=face_filtering_options
)

# Apply pre-inpainting to image
if unit.pre_inpainting.inpainting_denoising_strengh > 0:
Expand All @@ -738,7 +739,6 @@ def process_image_unit(
target_faces=target_faces,
model=model,
swapping_options=unit.swapping_options,
compute_similarity=unit.compute_similarity,
)
# Apply post-inpainting to image
if unit.post_inpainting.inpainting_denoising_strengh > 0:
Expand Down
10 changes: 6 additions & 4 deletions scripts/faceswaplab_ui/faceswaplab_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,18 @@ def build_face_checkpoint_and_save(
if not batch_files:
logger.error("No face found")
return None # type: ignore (Optional not really supported by old gradio)
images = [Image.open(file.name) for file in batch_files] # type: ignore
preview_image = face_checkpoints.build_face_checkpoint_and_save(
images, name, overwrite=overwrite
images: list[PILImage] = [Image.open(file.name) for file in batch_files] # type: ignore
preview_image: PILImage | None = (
face_checkpoints.build_face_checkpoint_and_save(
images=images, name=name, overwrite=overwrite
)
)
except Exception as e:
logger.error("Failed to build checkpoint %s", e)

traceback.print_exc()
return None # type: ignore
return preview_image
return preview_image # type: ignore


def explore_onnx_faceswap_model(model_path: str) -> pd.DataFrame:
Expand Down

0 comments on commit 0499581

Please sign in to comment.