Skip to content

Commit

Permalink
🔨 Refactor choose input image logic
Browse files Browse the repository at this point in the history
  • Loading branch information
huchenlei committed Jan 14, 2024
1 parent dad60de commit 26595c0
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 95 deletions.
145 changes: 57 additions & 88 deletions scripts/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,50 +105,26 @@ def image_dict_from_any(image) -> Optional[Dict[str, np.ndarray]]:
elif image['image']:
image['image'] = external_code.to_base64_nparray(image['image'])
else:
image['image'] = None
image['image'] = None

# If there is no image, return image with None image and None mask
if image['image'] is None:
image['mask'] = None
return image

if isinstance(image['mask'], str):
if 'mask' not in image or image['mask'] is None:
image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)
elif isinstance(image['mask'], str):
if os.path.exists(image['mask']):
image['mask'] = np.array(Image.open(image['mask'])).astype('uint8')
elif image['mask']:
image['mask'] = external_code.to_base64_nparray(image['mask'])
else:
image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)
elif image['mask'] is None:
image['mask'] = np.zeros_like(image['image'], dtype=np.uint8)

return image


def image_has_mask(input_image: np.ndarray) -> bool:
    """
    Determine if an image has an alpha channel (mask) that is not empty.

    The image must have three dimensions (height, width, channels) with
    exactly four channels; the fourth channel (index 3) is treated as the
    alpha channel. The mask counts as non-empty when at least one alpha
    value is strictly greater than 127.

    Args:
        input_image (np.ndarray): Array representing an image. The dimensions
            are expected to be [height, width, channels].

    Returns:
        bool: True if the image has a 4th channel containing at least one
            value above 127, False otherwise (including for images with
            zero height or width).
    """
    # Anything that is not H x W x 4 cannot carry an alpha mask.
    if input_image.ndim != 3 or input_image.shape[2] != 4:
        return False

    alpha = input_image[:, :, 3]
    # `np.any` yields False for a zero-size array, whereas `np.max` would
    # raise ValueError; `bool(...)` converts np.bool_ to the annotated type.
    return bool(np.any(alpha > 127))


def prepare_mask(
mask: Image.Image, p: processing.StableDiffusionProcessing
) -> Image.Image:
Expand Down Expand Up @@ -587,30 +563,31 @@ def get_enabled_units(p):

@staticmethod
def choose_input_image(
p: processing.StableDiffusionProcessing,
p: processing.StableDiffusionProcessing,
unit: external_code.ControlNetUnit,
idx: int
) -> Tuple[np.ndarray, bool]:
) -> Tuple[np.ndarray, external_code.ResizeMode]:
""" Choose input image from following sources with descending priority:
- p.image_control: [Deprecated] Legacy way to pass image to controlnet.
- p.control_net_input_image: [Deprecated] Legacy way to pass image to controlnet.
- unit.image:
- ControlNet tab input image.
- Input image from API call.
- unit.image: ControlNet tab input image.
- p.init_images: A1111 img2img tab input image.
Returns:
- The input image in ndarray form.
- Whether input image is from A1111.
- The resize mode.
"""
image_from_a1111 = False

# 4 input image sources.
p_image_control = getattr(p, "image_control", None)
p_input_image = Script.get_remote_call(p, "control_net_input_image", None, idx)
image = image_dict_from_any(unit.image)
a1111_image = getattr(p, "init_images", [None])[0]

resize_mode = external_code.resize_mode_from_value(unit.resize_mode)

if batch_hijack.instance.is_batch and getattr(p, "image_control", None) is not None:
if batch_hijack.instance.is_batch and p_image_control is not None:
logger.warning("Warn: Using legacy field 'p.image_control'.")
input_image = HWC3(np.asarray(p.image_control))
input_image = HWC3(np.asarray(p_image_control))
elif p_input_image is not None:
logger.warning("Warn: Using legacy field 'p.controlnet_input_image'")
if isinstance(p_input_image, dict) and "mask" in p_input_image and "image" in p_input_image:
Expand All @@ -620,9 +597,6 @@ def choose_input_image(
else:
input_image = HWC3(np.asarray(p_input_image))
elif image is not None:
while len(image['mask'].shape) < 3:
image['mask'] = image['mask'][..., np.newaxis]

# Need to check the image for API compatibility
if isinstance(image['image'], str):
from modules.api.api import decode_base64_to_image
Expand All @@ -631,44 +605,54 @@ def choose_input_image(
input_image = HWC3(image['image'])

have_mask = 'mask' in image and not (
(image['mask'][:, :, 0] <= 5).all() or
(image['mask'][:, :, 0] <= 5).all() or
(image['mask'][:, :, 0] >= 250).all()
)
if have_mask:
while len(image['mask'].shape) < 3:
image['mask'] = image['mask'][..., np.newaxis]

if 'inpaint' in unit.module:
logger.info("using inpaint as input")
color = HWC3(image['image'])
if have_mask:
if 'inpaint' in unit.module:
logger.info("using inpaint as input")
color = HWC3(image['image'])
alpha = image['mask'][:, :, 0:1]
input_image = np.concatenate([color, alpha], axis=2)
else:
alpha = np.zeros_like(color)[:, :, 0:1]
input_image = np.concatenate([color, alpha], axis=2)
else:
if have_mask and not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False):
logger.info("using mask as input")
input_image = HWC3(image['mask'][:, :, 0])
unit.module = 'none' # Always use black bg and white line
if not shared.opts.data.get("controlnet_ignore_noninpaint_mask", False):
logger.info("using mask as input")
input_image = HWC3(image['mask'][:, :, 0])
unit.module = 'none' # Always use black bg and white line
elif a1111_image is not None:
input_image = HWC3(np.asarray(a1111_image))
a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
assert a1111_i2i_resize_mode is not None
resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)

a1111_mask_image : Optional[Image.Image] = getattr(p, "image_mask", None)
if 'inpaint' in unit.module and a1111_mask_image is not None:
a1111_mask = np.array(prepare_mask(a1111_mask_image, p))
assert a1111_mask.ndim == 2
assert a1111_mask.shape[0] == input_image.shape[0]
assert a1111_mask.shape[1] == input_image.shape[1]
input_image = np.concatenate([input_image[:, :, 0:3], a1111_mask[:, :, None]], axis=2)
else:
# use img2img init_image as default
input_image = getattr(p, "init_images", [None])[0]
if input_image is None:
if batch_hijack.instance.is_batch:
shared.state.interrupted = True
raise ValueError('controlnet is enabled but no input image is given')

input_image = HWC3(np.asarray(input_image))
image_from_a1111 = True

# No input image detected.
if batch_hijack.instance.is_batch:
shared.state.interrupted = True
raise ValueError("controlnet is enabled but no input image is given")

assert isinstance(input_image, np.ndarray)
return input_image, image_from_a1111

if 'inpaint' in unit.module and input_image.shape[2] != 4:
raise ValueError("No mask detected for ControlNet inpaint")
return input_image, resize_mode

@staticmethod
def bound_check_params(unit: external_code.ControlNetUnit) -> None:
"""
Checks and corrects negative parameters in ControlNetUnit 'unit'.
Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to
Parameters 'processor_res', 'threshold_a', 'threshold_b' are reset to
their default values if negative.
Args:
unit (external_code.ControlNetUnit): The ControlNetUnit instance to check.
"""
Expand Down Expand Up @@ -760,7 +744,6 @@ def controlnet_main_entry(self, p):
Script.bound_check_params(unit)
Script.check_sd_version_compatible(unit)

resize_mode = external_code.resize_mode_from_value(unit.resize_mode)
control_mode = external_code.control_mode_from_value(unit.control_mode)

if unit.module in model_free_preprocessors:
Expand All @@ -774,22 +757,7 @@ def controlnet_main_entry(self, p):
bind_control_lora(unet, control_lora)
p.controlnet_control_loras.append(control_lora)

input_image, image_from_a1111 = Script.choose_input_image(p, unit, idx)
if image_from_a1111:
a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
if a1111_i2i_resize_mode is not None:
resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)

a1111_mask_image : Optional[Image.Image] = getattr(p, "image_mask", None)
if 'inpaint' in unit.module and not image_has_mask(input_image) and a1111_mask_image is not None:
a1111_mask = np.array(prepare_mask(a1111_mask_image, p))
assert a1111_mask.ndim == 2
assert a1111_mask.shape[0] == input_image.shape[0]
assert a1111_mask.shape[1] == input_image.shape[1]
input_image = np.concatenate([input_image[:, :, 0:3], a1111_mask[:, :, None]], axis=2)
a1111_i2i_resize_mode = getattr(p, "resize_mode", None)
if a1111_i2i_resize_mode is not None:
resize_mode = external_code.resize_mode_from_value(a1111_i2i_resize_mode)
input_image, resize_mode = Script.choose_input_image(p, unit, idx)

# Note: The method determining whether the active script is an upscale script is purely
# based on `extra_generation_params` these scripts attach on `p`, and subject to change
Expand All @@ -799,13 +767,14 @@ def controlnet_main_entry(self, p):
logger.debug(f"is_upscale_script={is_upscale_script}")
# Note: `inpaint_full_res` is "inpaint area" on UI. The flag is `True` when "Only masked"
# option is selected.
a1111_mask_image : Optional[Image.Image] = getattr(p, "image_mask", None)
is_only_masked_inpaint = (
issubclass(type(p), StableDiffusionProcessingImg2Img) and
p.inpaint_full_res and
p.inpaint_full_res and
a1111_mask_image is not None
)
# Crop ControlNet input image based on A1111 inpaint mask given.
# This logic is crucial in upscale scripts, as they use A1111 mask + inpaint_full_res
# This logic is crucial in upscale scripts, as they use A1111 mask + inpaint_full_res
# to crop tiles.
if (
'reference' not in unit.module
Expand All @@ -822,13 +791,13 @@ def controlnet_main_entry(self, p):
crop_region = masking.expand_crop_region(crop_region, p.width, p.height, mask.width, mask.height)

input_image = [
images.resize_image(resize_mode.int_value(), i, mask.width, mask.height)
images.resize_image(resize_mode.int_value(), i, mask.width, mask.height)
for i in input_image
]

input_image = [x.crop(crop_region) for x in input_image]
input_image = [
images.resize_image(external_code.ResizeMode.OUTER_FIT.int_value(), x, p.width, p.height)
images.resize_image(external_code.ResizeMode.OUTER_FIT.int_value(), x, p.width, p.height)
for x in input_image
]

Expand Down
5 changes: 3 additions & 2 deletions tests/web_api/full_coverage/inpaint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def test_inpaint(self):
)

def test_inpaint_no_mask(self):
"""Inpaint should fail if no mask is provided."""
"""Inpaint should fail if no mask is provided. Output should not contain
ControlNet detected map."""
for gen_type in ("img2img", "txt2img"):
if gen_type == "img2img":
payload = {
Expand All @@ -91,7 +92,7 @@ def test_inpaint_no_mask(self):
unit["model"] = "control_v11p_sd15_inpaint [ebff9138]"
unit["module"] = "inpaint_only"
with self.subTest(gen_type=gen_type):
self.assertFalse(
self.assertTrue(
APITestTemplate(
f"{gen_type}_no_mask_fail",
gen_type,
Expand Down
15 changes: 10 additions & 5 deletions tests/web_api/full_coverage/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
test_result_dir = Path(__file__).parent / "results" / f"test_result_{timestamp}"
test_expectation_dir = Path(__file__).parent / "expectations"
os.makedirs(test_result_dir, exist_ok=True)
os.makedirs(test_expectation_dir, exist_ok=True)
resource_dir = Path(__file__).parents[2] / "images"

Expand Down Expand Up @@ -48,11 +47,10 @@ class StableDiffusionVersion(Enum):
)

is_full_coverage = os.environ.get("CONTROLNET_TEST_FULL_COVERAGE", None) is not None
is_full_coverage = True


class APITestTemplate:
is_set_expectation_run = True
is_set_expectation_run = False

def __init__(
self,
Expand Down Expand Up @@ -81,6 +79,9 @@ def __init__(
]

def exec(self) -> bool:
if not APITestTemplate.is_set_expectation_run:
os.makedirs(test_result_dir, exist_ok=True)

failed = False

response = requests.post(url=self.url, json=self.payload).json()
Expand Down Expand Up @@ -108,11 +109,15 @@ def exec(self) -> bool:
failed = True
continue

if img1 is None:
print(f"Warn: No expectation file found {img_file_name}.")
continue

if not expect_same_image(
img1,
img2,
diff_img_path=test_result_dir
/ img_file_name.replace(".png", "_diff.png"),
diff_img_path=str(test_result_dir
/ img_file_name.replace(".png", "_diff.png")),
):
failed = True
return not failed
Expand Down

0 comments on commit 26595c0

Please sign in to comment.