From dce049d4dd659bf012472737227541d28d8dc965 Mon Sep 17 00:00:00 2001 From: Elvis Claros Castro Date: Fri, 13 Sep 2024 09:16:01 -0300 Subject: [PATCH] feat(core): add reference face path argument - Added `--reference-face-path` argument to `parse_args` function. - Updated `roop.globals` to include `reference_face_path`. - Modified `process_video` to use `reference_face_path` if provided. - Iterates over `temp_frame_paths` to find a valid reference face. --- roop/core.py | 3 ++- roop/globals.py | 1 + roop/processors/frame/face_swapper.py | 15 +++++++++++++-- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/roop/core.py b/roop/core.py index 7e5a46fbd..71c29f752 100755 --- a/roop/core.py +++ b/roop/core.py @@ -38,6 +38,7 @@ def parse_args() -> None: program.add_argument('--skip-audio', help='skip target audio', dest='skip_audio', action='store_true') program.add_argument('--many-faces', help='process every face', dest='many_faces', action='store_true') program.add_argument('--reference-face-position', help='position of the reference face', dest='reference_face_position', type=int, default=0) + program.add_argument('--reference-face-path', help='reference face path', dest='face_reference_path') program.add_argument('--reference-frame-number', help='number of the reference frame', dest='reference_frame_number', type=int, default=0) program.add_argument('--similar-face-distance', help='face distance used for recognition', dest='similar_face_distance', type=float, default=0.85) program.add_argument('--temp-frame-format', help='image format used for frame extraction', dest='temp_frame_format', default='png', choices=['jpg', 'png']) @@ -62,6 +63,7 @@ def parse_args() -> None: roop.globals.many_faces = args.many_faces roop.globals.reference_face_position = args.reference_face_position roop.globals.reference_frame_number = args.reference_frame_number + roop.globals.reference_face_path = args.face_reference_path roop.globals.similar_face_distance = args.similar_face_distance roop.globals.temp_frame_format = args.temp_frame_format roop.globals.temp_frame_quality = args.temp_frame_quality @@ -71,7 +73,6 @@ def parse_args() -> None: roop.globals.execution_providers = decode_execution_providers(args.execution_provider) roop.globals.execution_threads = args.execution_threads - def encode_execution_providers(execution_providers: List[str]) -> List[str]: return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers] diff --git a/roop/globals.py b/roop/globals.py index 3eca8d0d0..2f614e214 100644 --- a/roop/globals.py +++ b/roop/globals.py @@ -11,6 +11,7 @@ many_faces: Optional[bool] = None reference_face_position: Optional[int] = None reference_frame_number: Optional[int] = None +reference_face_path: Optional[str] = None similar_face_distance: Optional[float] = None temp_frame_format: Optional[str] = None temp_frame_quality: Optional[int] = None diff --git a/roop/processors/frame/face_swapper.py b/roop/processors/frame/face_swapper.py index da68956ef..02d5d2036 100644 --- a/roop/processors/frame/face_swapper.py +++ b/roop/processors/frame/face_swapper.py @@ -93,8 +93,19 @@ def process_image(source_path: str, target_path: str, output_path: str) -> None: def process_video(source_path: str, temp_frame_paths: List[str]) -> None: + + if roop.globals.reference_face_path: + reference_face = get_one_face(cv2.imread(roop.globals.reference_face_path)) + set_face_reference(reference_face) + + if not roop.globals.many_faces and not get_face_reference(): - reference_frame = cv2.imread(temp_frame_paths[roop.globals.reference_frame_number]) - reference_face = get_one_face(reference_frame, roop.globals.reference_face_position) + # recorro los temp_frame_paths hasta que reference_face sea distinto de nulo. + for temp_frame_path in temp_frame_paths: + temp_frame = cv2.imread(temp_frame_path) + reference_face = get_one_face(temp_frame, roop.globals.reference_face_position) + if reference_face: + break + set_face_reference(reference_face) roop.processors.frame.core.process_video(source_path, temp_frame_paths, process_frames)