diff --git a/scripts/controlnet_ui/controlnet_ui_group.py b/scripts/controlnet_ui/controlnet_ui_group.py index e8351a60a..72f3f3e45 100644 --- a/scripts/controlnet_ui/controlnet_ui_group.py +++ b/scripts/controlnet_ui/controlnet_ui_group.py @@ -2,7 +2,7 @@ import gradio as gr import functools from copy import copy -from typing import List, Optional, Union, Callable +from typing import List, Optional, Union, Callable, Dict import numpy as np from scripts.utils import svg_preprocess @@ -39,11 +39,12 @@ def __init__( loopback: bool = False, use_preview_as_input: bool = False, generated_image: Optional[np.ndarray] = None, + mask_image: Optional[np.ndarray] = None, enabled: bool = True, module: Optional[str] = None, model: Optional[str] = None, weight: float = 1.0, - image: Optional[np.ndarray] = None, + image: Optional[Dict[str, np.ndarray]] = None, *args, **kwargs, ): @@ -53,6 +54,11 @@ def __init__( else: input_image = image + # Prefer uploaded mask_image over hand-drawn mask. + if input_image is not None and mask_image is not None: + assert isinstance(input_image, dict) + input_image["mask"] = mask_image + super().__init__(enabled, module, model, weight, input_image, *args, **kwargs) self.is_ui = True self.input_mode = input_mode @@ -120,6 +126,8 @@ def __init__( self.image = None self.generated_image_group = None self.generated_image = None + self.mask_image_group = None + self.mask_image = None self.batch_tab = None self.batch_image_dir = None self.create_canvas = None @@ -135,6 +143,7 @@ def __init__( self.low_vram = None self.pixel_perfect = None self.preprocessor_preview = None + self.mask_upload = None self.type_filter = None self.module = None self.trigger_preprocessor = None @@ -229,6 +238,15 @@ def render(self, tabname: str, elem_id_tabname: str, is_img2img: bool) -> None: elem_classes=["cnet-close-preview"], ) + with gr.Group(visible=False, elem_classes=["cnet-mask-image-group"]) as self.mask_image_group: + self.mask_image = gr.Image( + value=None, + label="Upload Mask", + elem_id=f"{elem_id_tabname}_{tabname}_mask_image", + elem_classes=["cnet-mask-image"], + interactive=True, + ) + with gr.Tab(label="Batch") as self.batch_tab: self.batch_image_dir = gr.Textbox( label="Input Directory", @@ -318,6 +336,13 @@ def render(self, tabname: str, elem_id_tabname: str, is_img2img: bool) -> None: elem_id=preview_check_elem_id, visible=not is_img2img, ) + self.mask_upload = gr.Checkbox( + label="Mask Upload", + value=False, + elem_classes=["cnet-mask-upload"], + elem_id=f"{elem_id_tabname}_{tabname}_controlnet_mask_upload_checkbox", + visible=not is_img2img, + ) self.use_preview_as_input = gr.Checkbox( label="Preview as Input", value=False, @@ -890,6 +915,33 @@ def register_shift_hr_options(self): show_progress=False ) + def register_shift_upload_mask(self): + """ Controls whether the upload mask input should be visible. """ + self.mask_upload.change( + fn=lambda checked: ( + # Clear mask_image if unchecked. + (gr.update(visible=False), gr.update(value=None)) + if not checked else + (gr.update(visible=True), gr.update()) + ), + inputs=[self.mask_upload], + outputs=[self.mask_image_group, self.mask_image], + show_progress=False, + ) + + if self.upload_independent_img_in_img2img is not None: + self.upload_independent_img_in_img2img.change( + fn=lambda checked: ( + # Uncheck `upload_mask` when not using independent input. + gr.update(visible=False, value=False) + if not checked else + gr.update(visible=True) + ), + inputs=[self.upload_independent_img_in_img2img], + outputs=[self.mask_upload], + show_progress=False, + ) + def register_callbacks(self, is_img2img: bool): """Register callbacks on the UI elements. @@ -906,6 +958,7 @@ def register_callbacks(self, is_img2img: bool): self.register_build_sliders() self.register_run_annotator(is_img2img) self.register_shift_preview() + self.register_shift_upload_mask() self.register_create_canvas() self.openpose_editor.register_callbacks( self.generated_image, self.use_preview_as_input, @@ -950,6 +1003,7 @@ def render_and_register_unit(self, tabname: str, is_img2img: bool): # They are only used during object construction. self.use_preview_as_input, self.generated_image, + self.mask_image, # End of Non-persistent fields. self.enabled, self.module,