Skip to content

Commit

Permalink
✨ Allow mask upload
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei committed Jan 13, 2024
1 parent d56f1b8 commit 7079906
Showing 1 changed file with 56 additions and 2 deletions.
58 changes: 56 additions & 2 deletions scripts/controlnet_ui/controlnet_ui_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import gradio as gr
import functools
from copy import copy
from typing import List, Optional, Union, Callable
from typing import List, Optional, Union, Callable, Dict
import numpy as np

from scripts.utils import svg_preprocess
Expand Down Expand Up @@ -39,11 +39,12 @@ def __init__(
loopback: bool = False,
use_preview_as_input: bool = False,
generated_image: Optional[np.ndarray] = None,
mask_image: Optional[np.ndarray] = None,
enabled: bool = True,
module: Optional[str] = None,
model: Optional[str] = None,
weight: float = 1.0,
image: Optional[np.ndarray] = None,
image: Optional[Dict[str, np.ndarray]] = None,
*args,
**kwargs,
):
Expand All @@ -53,6 +54,11 @@ def __init__(
else:
input_image = image

# Prefer uploaded mask_image over hand-drawn mask.
if input_image is not None and mask_image is not None:
assert isinstance(input_image, dict)
input_image["mask"] = mask_image

super().__init__(enabled, module, model, weight, input_image, *args, **kwargs)
self.is_ui = True
self.input_mode = input_mode
Expand Down Expand Up @@ -120,6 +126,8 @@ def __init__(
self.image = None
self.generated_image_group = None
self.generated_image = None
self.mask_image_group = None
self.mask_image = None
self.batch_tab = None
self.batch_image_dir = None
self.create_canvas = None
Expand All @@ -135,6 +143,7 @@ def __init__(
self.low_vram = None
self.pixel_perfect = None
self.preprocessor_preview = None
self.mask_upload = None
self.type_filter = None
self.module = None
self.trigger_preprocessor = None
Expand Down Expand Up @@ -229,6 +238,15 @@ def render(self, tabname: str, elem_id_tabname: str, is_img2img: bool) -> None:
elem_classes=["cnet-close-preview"],
)

with gr.Group(visible=False, elem_classes=["cnet-mask-image-group"]) as self.mask_image_group:
self.mask_image = gr.Image(
value=None,
label="Upload Mask",
elem_id=f"{elem_id_tabname}_{tabname}_mask_image",
elem_classes=["cnet-mask-image"],
interactive=True,
)

with gr.Tab(label="Batch") as self.batch_tab:
self.batch_image_dir = gr.Textbox(
label="Input Directory",
Expand Down Expand Up @@ -318,6 +336,13 @@ def render(self, tabname: str, elem_id_tabname: str, is_img2img: bool) -> None:
elem_id=preview_check_elem_id,
visible=not is_img2img,
)
self.mask_upload = gr.Checkbox(
label="Mask Upload",
value=False,
elem_classes=["cnet-mask-upload"],
elem_id=f"{elem_id_tabname}_{tabname}_controlnet_mask_upload_checkbox",
visible=not is_img2img,
)
self.use_preview_as_input = gr.Checkbox(
label="Preview as Input",
value=False,
Expand Down Expand Up @@ -890,6 +915,33 @@ def register_shift_hr_options(self):
show_progress=False
)

def register_shift_upload_mask(self):
""" Controls whether the upload mask input should be visible. """
self.mask_upload.change(
fn=lambda checked: (
# Clear mask_image if unchecked.
(gr.update(visible=False), gr.update(value=None))
if not checked else
(gr.update(visible=True), gr.update())
),
inputs=[self.mask_upload],
outputs=[self.mask_image_group, self.mask_image],
show_progress=False,
)

if self.upload_independent_img_in_img2img is not None:
self.upload_independent_img_in_img2img.change(
fn=lambda checked: (
# Uncheck `upload_mask` when not using independent input.
gr.update(visible=False, value=False)
if not checked else
gr.update(visible=True)
),
inputs=[self.upload_independent_img_in_img2img],
outputs=[self.mask_upload],
show_progress=False,
)

def register_callbacks(self, is_img2img: bool):
"""Register callbacks on the UI elements.
Expand All @@ -906,6 +958,7 @@ def register_callbacks(self, is_img2img: bool):
self.register_build_sliders()
self.register_run_annotator(is_img2img)
self.register_shift_preview()
self.register_shift_upload_mask()
self.register_create_canvas()
self.openpose_editor.register_callbacks(
self.generated_image, self.use_preview_as_input,
Expand Down Expand Up @@ -950,6 +1003,7 @@ def render_and_register_unit(self, tabname: str, is_img2img: bool):
# They are only used during object construction.
self.use_preview_as_input,
self.generated_image,
self.mask_image,
# End of Non-persistent fields.
self.enabled,
self.module,
Expand Down

0 comments on commit 7079906

Please sign in to comment.