Skip to content

Commit

Permalink
improve test, add extract to api
Browse files Browse the repository at this point in the history
  • Loading branch information
glucauze committed Jul 29, 2023
1 parent b6add28 commit be505f4
Show file tree
Hide file tree
Showing 10 changed files with 228 additions and 56 deletions.
37 changes: 36 additions & 1 deletion client_api/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from io import BytesIO
from typing import List, Tuple, Optional
import numpy as np
import requests


class InpaintingWhen(Enum):
Expand Down Expand Up @@ -151,7 +152,7 @@ class FaceSwapRequest(BaseModel):

class FaceSwapResponse(BaseModel):
images: List[str] = Field(description="base64 swapped image", default=None)
infos: List[str]
infos: Optional[List[str]] # not really used atm

@property
def pil_images(self) -> Image.Image:
Expand All @@ -171,6 +172,23 @@ class FaceSwapCompareRequest(BaseModel):
)


class FaceSwapExtractRequest(BaseModel):
images: List[str] = Field(
description="base64 reference image",
examples=["data:image/jpeg;base64,/9j/4AAQSkZJRgABAQECWAJYAAD...."],
default=None,
)
postprocessing: Optional[PostProcessingOptions]


class FaceSwapExtractResponse(BaseModel):
images: List[str] = Field(description="base64 face images", default=None)

@property
def pil_images(self) -> Image.Image:
return [base64_to_pil(img) for img in self.images]


def pil_to_base64(img: Image.Image) -> np.array: # type:ignore
if isinstance(img, str):
img = Image.open(img)
Expand All @@ -192,3 +210,20 @@ def base64_to_pil(base64str: Optional[str]) -> Optional[Image.Image]:
# if no data URL scheme, just decode
img_bytes = base64.b64decode(base64str)
return Image.open(io.BytesIO(img_bytes))


def compare_faces(
image1: Image.Image, image2: Image.Image, base_url: str = "http://localhost:7860"
) -> float:
request = FaceSwapCompareRequest(
image1=pil_to_base64(image1),
image2=pil_to_base64(image2),
)

result = requests.post(
url=f"{base_url}/faceswaplab/compare",
data=request.json(),
headers={"Content-Type": "application/json; charset=utf-8"},
)

return float(result.text)
28 changes: 27 additions & 1 deletion client_api/faceswaplab_api_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,16 @@
pil_to_base64,
InpaintingWhen,
FaceSwapCompareRequest,
FaceSwapExtractRequest,
FaceSwapExtractResponse,
)

address = "http://127.0.0.1:7860"


#############################
# FaceSwap

# First face unit :
unit1 = FaceSwapUnit(
source_img=pil_to_base64("../references/man.png"), # The face you want to use
Expand Down Expand Up @@ -41,7 +47,7 @@
image=pil_to_base64("test_image.png"), units=[unit1, unit2], postprocessing=pp
)


# Face Swap
result = requests.post(
url=f"{address}/faceswaplab/swap_face",
data=request.json(),
Expand All @@ -52,6 +58,8 @@
for img in response.pil_images:
img.show()

#############################
# Comparison

request = FaceSwapCompareRequest(
image1=pil_to_base64("../references/man.png"),
Expand All @@ -65,3 +73,21 @@
)

print("similarity", result.text)

#############################
# Extraction

# Prepare the request
request = FaceSwapExtractRequest(
images=[pil_to_base64(response.pil_images[0])], postprocessing=pp
)

result = requests.post(
url=f"{address}/faceswaplab/extract",
data=request.json(),
headers={"Content-Type": "application/json; charset=utf-8"},
)
response = FaceSwapExtractResponse.parse_obj(result.json())

for img in response.pil_images:
img.show()
2 changes: 1 addition & 1 deletion scripts/faceswaplab.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def ui(self, is_img2img: bool) -> List[gr.components.Component]:
components = []
for i in range(1, self.units_count + 1):
components += faceswaplab_unit_ui.faceswap_unit_ui(is_img2img, i)
upscaler = faceswaplab_tab.upscaler_ui()
upscaler = faceswaplab_tab.postprocessing_ui()
# If the order is modified, the before_process should be changed accordingly.
return components + upscaler

Expand Down
19 changes: 19 additions & 0 deletions scripts/faceswaplab_api/faceswaplab_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,22 @@ async def compare(
return swapper.compare_faces(
base64_to_pil(request.image1), base64_to_pil(request.image2)
)

@app.post(
"/faceswaplab/extract",
tags=["faceswaplab"],
description="Extract faces of each images",
)
async def extract(
request: api_utils.FaceSwapExtractRequest,
) -> api_utils.FaceSwapExtractResponse:
pp_options = None
if request.postprocessing:
pp_options = get_postprocessing_options(request.postprocessing)
images = [base64_to_pil(img) for img in request.images]
faces = swapper.extract_faces(
images, extract_path=None, postprocess_options=pp_options
)
result_images = [encode_to_base64(img) for img in faces]
response = api_utils.FaceSwapExtractResponse(images=result_images)
return response
28 changes: 27 additions & 1 deletion scripts/faceswaplab_settings/faceswaplab_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,15 @@ def on_ui_settings() -> None:
"faceswaplab_detection_threshold",
shared.OptionInfo(
0.5,
"Detection threshold ",
"Face Detection threshold",
gr.Slider,
{"minimum": 0.1, "maximum": 0.99, "step": 0.001},
section=section,
),
)

# DEFAULT UI SETTINGS

shared.opts.add_option(
"faceswaplab_pp_default_face_restorer",
shared.OptionInfo(
Expand Down Expand Up @@ -105,6 +107,30 @@ def on_ui_settings() -> None:
),
)

shared.opts.add_option(
"faceswaplab_pp_default_inpainting_prompt",
shared.OptionInfo(
"Portrait of a [gender]",
"UI Default inpainting prompt [gender] is replaced by man or woman (requires restart)",
gr.Textbox,
{},
section=section,
),
)

shared.opts.add_option(
"faceswaplab_pp_default_inpainting_negative_prompt",
shared.OptionInfo(
"blurry",
"UI Default inpainting negative prompt [gender] (requires restart)",
gr.Textbox,
{},
section=section,
),
)

# UPSCALED SWAPPER

shared.opts.add_option(
"faceswaplab_upscaled_swapper",
shared.OptionInfo(
Expand Down
65 changes: 65 additions & 0 deletions scripts/faceswaplab_swapping/swapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,71 @@ def batch_process(
return None


def extract_faces(
images: List[Image.Image],
extract_path: Optional[str],
postprocess_options: PostProcessingOptions,
) -> Optional[List[str]]:
"""
Extracts faces from a list of image files.
Given a list of image file paths, this function opens each image, extracts the faces,
and saves them in a specified directory. Post-processing is applied to each extracted face,
and the processed faces are saved as separate PNG files.
Parameters:
files (Optional[List[Image]]): List of file paths to the images to extract faces from.
extract_path (Optional[str]): Path where the extracted faces will be saved.
If no path is provided, a temporary directory will be created.
postprocess_options (PostProcessingOptions): Post-processing settings to be applied to the images.
Returns:
Optional[List[img]]: List of face images
"""

try:
if extract_path:
os.makedirs(extract_path, exist_ok=True)

if images:
result_images = []
for img in images:
faces = get_faces(pil_to_cv2(img))

if faces:
face_images = []
for face in faces:
bbox = face.bbox.astype(int)
x_min, y_min, x_max, y_max = bbox
face_image = img.crop((x_min, y_min, x_max, y_max))

if postprocess_options and (
postprocess_options.face_restorer_name
or postprocess_options.restorer_visibility
):
postprocess_options.scale = (
1 if face_image.width > 512 else 512 // face_image.width
)
face_image = enhance_image(face_image, postprocess_options)

if extract_path:
path = tempfile.NamedTemporaryFile(
delete=False, suffix=".png", dir=extract_path
).name
face_image.save(path)
face_images.append(face_image)

result_images += face_images

return result_images
except Exception as e:
logger.info("Failed to extract : %s", e)
import traceback

traceback.print_exc()
return None


class FaceModelException(Exception):
"""Exception raised when an error is encountered in the face model."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from scripts.faceswaplab_postprocessing.postprocessing_options import InpaintingWhen


def upscaler_ui() -> List[gr.components.Component]:
def postprocessing_ui() -> List[gr.components.Component]:
with gr.Tab(f"Post-Processing"):
gr.Markdown(
"""Upscaling is performed on the whole image. Upscaling happens before face restoration."""
Expand Down Expand Up @@ -87,12 +87,16 @@ def upscaler_ui() -> List[gr.components.Component]:
)

inpainting_denoising_prompt = gr.Textbox(
"Portrait of a [gender]",
opts.data.get(
"faceswaplab_pp_default_inpainting_prompt", "Portrait of a [gender]"
),
elem_id="faceswaplab_pp_inpainting_denoising_prompt",
label="Inpainting prompt use [gender] instead of men or woman",
)
inpainting_denoising_negative_prompt = gr.Textbox(
"",
opts.data.get(
"faceswaplab_pp_default_inpainting_negative_prompt", "blurry"
),
elem_id="faceswaplab_pp_inpainting_denoising_neg_prompt",
label="Inpainting negative prompt use [gender] instead of men or woman",
)
Expand Down
58 changes: 10 additions & 48 deletions scripts/faceswaplab_ui/faceswaplab_tab.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import tempfile
from pprint import pformat, pprint

import dill as pickle
Expand All @@ -8,14 +7,13 @@
import onnx
import pandas as pd
from scripts.faceswaplab_ui.faceswaplab_unit_ui import faceswap_unit_ui
from scripts.faceswaplab_ui.faceswaplab_upscaler_ui import upscaler_ui
from scripts.faceswaplab_ui.faceswaplab_postprocessing_ui import postprocessing_ui
from insightface.app.common import Face
from modules import scripts
from PIL import Image
from modules.shared import opts

from scripts.faceswaplab_utils import imgutils
from scripts.faceswaplab_utils.imgutils import pil_to_cv2
from scripts.faceswaplab_utils.models_utils import get_models
from scripts.faceswaplab_utils.faceswaplab_logging import logger
import scripts.faceswaplab_swapping.swapper as swapper
Expand Down Expand Up @@ -54,7 +52,7 @@ def extract_faces(
files: List[gr.File],
extract_path: Optional[str],
*components: List[gr.components.Component],
) -> Optional[List[str]]:
) -> Optional[List[Image.Image]]:
"""
Extracts faces from a list of image files.
Expand All @@ -73,49 +71,13 @@ def extract_faces(
If no faces are found, None is returned.
"""

try:
postprocess_options = PostProcessingOptions(*components) # type: ignore

if not extract_path:
extract_path = tempfile.mkdtemp()

if files:
images = []
for file in files:
img = Image.open(file.name)
faces = swapper.get_faces(pil_to_cv2(img))

if faces:
face_images = []
for face in faces:
bbox = face.bbox.astype(int)
x_min, y_min, x_max, y_max = bbox
face_image = img.crop((x_min, y_min, x_max, y_max))

if (
postprocess_options.face_restorer_name
or postprocess_options.restorer_visibility
):
postprocess_options.scale = (
1 if face_image.width > 512 else 512 // face_image.width
)
face_image = enhance_image(face_image, postprocess_options)

path = tempfile.NamedTemporaryFile(
delete=False, suffix=".png", dir=extract_path
).name
face_image.save(path)
face_images.append(path)

images += face_images

return images
except Exception as e:
logger.info("Failed to extract : %s", e)
import traceback

traceback.print_exc()
return None
postprocess_options = PostProcessingOptions(*components) # type: ignore
images = [
Image.open(file.name) for file in files
] # potentially greedy but Image.open is supposed to be lazy
return swapper.extract_faces(
images, extract_path=extract_path, postprocess_options=postprocess_options
)


def analyse_faces(image: Image.Image, det_threshold: float = 0.5) -> Optional[str]:
Expand Down Expand Up @@ -459,7 +421,7 @@ def tools_ui() -> None:
for i in range(1, opts.data.get("faceswaplab_units_count", 3) + 1):
unit_components += faceswap_unit_ui(False, i, id_prefix="faceswaplab_tab")

upscale_options = upscaler_ui()
upscale_options = postprocessing_ui()

explore_btn.click(
explore_onnx_faceswap_model, inputs=[model], outputs=[explore_result_text]
Expand Down
Loading

0 comments on commit be505f4

Please sign in to comment.