From 79dcbb16a3884d0b65ad27781a49d5686d8f44f9 Mon Sep 17 00:00:00 2001 From: bhack Date: Tue, 27 Jun 2023 23:23:24 +0200 Subject: [PATCH 1/6] Add masks to boundaries --- torchvision/ops/boxes.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index a541f8d880a..faaefd460a1 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -382,7 +382,39 @@ def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Te # distance between boxes' centers squared. return iou - (centers_distance_squared / diagonal_distance_squared), iou +def masks_to_boundaries(masks: torch.Tensor, dilation_ratio: float = 0.02) -> torch.Tensor: + """ + Compute the boundaries around the provided masks using morphological operations. + + Returns a tensor of the same shape as the input masks containing the boundaries of each mask. + + Args: + masks (Tensor[N, H, W]): masks to transform where N is the number of masks + and (H, W) are the spatial dimensions. + dilation_ratio (float, optional): ratio used for the dilation operation. Default: 0.02 + + Returns: + Tensor[N, H, W]: boundaries + """ + # If no masks are provided, return an empty tensor + if masks.numel() == 0: + return torch.zeros_like(masks) + + n, h, w = masks.shape + img_diag = math.sqrt(h ** 2 + w ** 2) + dilation = int(round(dilation_ratio * img_diag)) + selem_size = dilation * 2 + 1 + selem = torch.ones((n, 1, selem_size, selem_size), device=masks.device) + # Compute the boundaries for each mask + masks = masks.float().unsqueeze(1) + eroded_masks = F.conv2d(masks, selem, padding=dilation, groups=n) + eroded_masks = (eroded_masks == selem.view(n, -1).sum(1, keepdim=True)).byte() # Make the output binary + + contours = masks.byte() - eroded_masks + + return contours.squeeze(1) + def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: """ Compute the bounding boxes around the provided masks. @@ -415,3 +447,4 @@ def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: bounding_boxes[index, 3] = torch.max(y) return bounding_boxes + From d171ffdb94b744c7eb2a0afadea9770329f71d76 Mon Sep 17 00:00:00 2001 From: bhack Date: Wed, 28 Jun 2023 13:24:01 +0200 Subject: [PATCH 2/6] Doesn't expose directly the def --- docs/source/ops.rst | 1 + torchvision/ops/__init__.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/docs/source/ops.rst b/docs/source/ops.rst index 7124c85bb79..3a1851108a9 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -22,6 +22,7 @@ The below operators perform pre-processing as well as post-processing required i batched_nms masks_to_boxes + masks_to_boudnaries nms roi_align roi_pool diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 827505b842d..80aed924779 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -9,6 +9,7 @@ distance_box_iou, generalized_box_iou, masks_to_boxes, + masks_to_boundaries, nms, remove_small_boxes, ) @@ -32,6 +33,7 @@ __all__ = [ "masks_to_boxes", + "masks_to_boundaries", "deform_conv2d", "DeformConv2d", "nms", From 9d41c0aaae48ef3dcebf5201b8abad652f63fed6 Mon Sep 17 00:00:00 2001 From: bhack Date: Sat, 1 Jul 2023 14:33:48 +0200 Subject: [PATCH 3/6] change erosion --- torchvision/ops/boxes.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index faaefd460a1..5842ff01519 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -408,8 +408,9 @@ def masks_to_boundaries(masks: torch.Tensor, dilation_ratio: float = 0.02) -> to # Compute the boundaries for each mask masks = masks.float().unsqueeze(1) - eroded_masks = F.conv2d(masks, selem, padding=dilation, groups=n) - eroded_masks = (eroded_masks == selem.view(n, -1).sum(1, keepdim=True)).byte() # Make the output binary + eroded_masks = F.conv2d(masks, selem, padding=dilation) + # Make the output binary + eroded_masks = (eroded_masks == selem.view(n, -1).sum(-1).view(n, 1, 1, 1)).byte() contours = masks.byte() - eroded_masks From e277308d6125a98a4b05d32cdbce237d567fa1ec Mon Sep 17 00:00:00 2001 From: bhack Date: Sat, 1 Jul 2023 14:55:40 +0200 Subject: [PATCH 4/6] Add dummy test --- test/test_ops.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/test/test_ops.py b/test/test_ops.py index b993bce65a2..e679fc93204 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -11,7 +11,7 @@ import torch.fx import torch.nn.functional as F from common_utils import assert_equal, cpu_and_cuda, needs_cuda -from PIL import Image +from PIL import Image, ImageDraw from torch import nn, Tensor from torch.autograd import gradcheck from torch.nn.modules.utils import _pair @@ -621,6 +621,32 @@ def test_is_leaf_node(self, device): assert len(graph_node_names[0]) == len(graph_node_names[1]) assert len(graph_node_names[0]) == 1 + op_obj.n_inputs +class TestMasksToBoundaries(ABC): + + @pytest.mark.parametrize("device", ['cpu', 'cuda']) + def test_masks_to_boundaries(self, device): + # Create masks + mask = torch.zeros(4, 32, 32, dtype=torch.bool) + mask[0, 1:10, 1:10] = True + mask[0, 12:20, 12:20] = True + mask[0, 15:18, 20:32] = True + mask[1, 15:23, 15:23] = True + mask[1, 22:33, 22:33] = True + mask[2, 1:5, 22:30] = True + mask[2, 5:14, 25:27] = True + pil_img = Image.new("L", (32, 32)) + draw = ImageDraw.Draw(pil_img) + draw.ellipse([2, 7, 26, 26], fill=1, outline=1, width=1) + mask[3, ...] = torch.from_numpy(np.asarray(pil_img)) + mask = mask.to(device) + dilation_ratio = 0.02 + boundaries = ops.masks_to_boundaries(mask, dilation_ratio) + # Generate expected output + # TODO: How we generate handle the expected output? + # replace with actual code to generate expected output + expected_boundaries = torch.zeros_like(mask) + torch.testing.assert_close(expected_boundaries, boundaries) + class TestNMS: def _reference_nms(self, boxes, scores, iou_threshold): From a8bd95c5ddcec8d9aff73cd2e08b173599c470ff Mon Sep 17 00:00:00 2001 From: bhack Date: Tue, 11 Jul 2023 12:43:35 +0200 Subject: [PATCH 5/6] Update ops.rst --- docs/source/ops.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/ops.rst b/docs/source/ops.rst index 3a1851108a9..c73aadf4cd8 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -22,7 +22,7 @@ The below operators perform pre-processing as well as post-processing required i batched_nms masks_to_boxes - masks_to_boudnaries + masks_to_boundaries nms roi_align roi_pool From 59fb72c9d4545ae6715ad5fed2c1ff03a889a0bc Mon Sep 17 00:00:00 2001 From: bhack Date: Sat, 17 Feb 2024 01:05:50 +0000 Subject: [PATCH 6/6] Add debug image option Refactor test and add debug image util Refactor implementation --- test/conftest.py | 4 ++ test/test_ops.py | 127 ++++++++++++++++++++++++++++++++------- torchvision/ops/boxes.py | 55 +++++++++-------- 3 files changed, 138 insertions(+), 48 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index a9768598ded..53f2e8a60b0 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -82,6 +82,10 @@ def pytest_collection_modifyitems(items): items[:] = out_items +def pytest_addoption(parser): + parser.addoption("--debug-images", action="store_true", help="Enable debug mode for saving images.") + + def pytest_sessionfinish(session, exitstatus): # This hook is called after all tests have run, and just before returning an exit status. # We here change exit code 5 into 0. diff --git a/test/test_ops.py b/test/test_ops.py index 1daaadb5db9..eade686773e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,3 +1,4 @@ +import logging import math import os from abc import ABC, abstractmethod @@ -7,12 +8,13 @@ import numpy as np import pytest +import scipy.ndimage import torch import torch.fx import torch.nn.functional as F import torch.testing._internal.optests as optests from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps -from PIL import Image +from PIL import Image, ImageDraw from torch import nn, Tensor from torch.autograd import gradcheck from torch.nn.modules.utils import _pair @@ -734,31 +736,110 @@ def test_is_leaf_node(self, device): assert len(graph_node_names[0]) == len(graph_node_names[1]) assert len(graph_node_names[0]) == 1 + op_obj.n_inputs + +import matplotlib.pyplot as plt + + class TestMasksToBoundaries(ABC): + def save_and_images( + self, original_masks, expected_boundaries, actual_boundaries, diff, filename_prefix, visualize=True + ): + """ + Saves images separately for original masks, expected boundaries, actual boundaries, and their difference. + + Parameters: + - original_masks: The starting binary masks tensor. + - expected_boundaries: The expected boundaries tensor. + - actual_boundaries: The actual boundaries tensor calculated by the function. + - diff: The absolute difference between expected and actual boundaries. + - filename_prefix: Prefix for the saved filename. + - visualize: Flag to enable or disable visualization. + """ + # Ensure directory exists + output_dir = "test_outputs" + os.makedirs(output_dir, exist_ok=True) + filepath_prefix = os.path.join(output_dir, filename_prefix) - @pytest.mark.parametrize("device", ['cpu', 'cuda']) - def test_masks_to_boundaries(self, device): - # Create masks - mask = torch.zeros(4, 32, 32, dtype=torch.bool) - mask[0, 1:10, 1:10] = True - mask[0, 12:20, 12:20] = True - mask[0, 15:18, 20:32] = True - mask[1, 15:23, 15:23] = True - mask[1, 22:33, 22:33] = True - mask[2, 1:5, 22:30] = True - mask[2, 5:14, 25:27] = True - pil_img = Image.new("L", (32, 32)) - draw = ImageDraw.Draw(pil_img) - draw.ellipse([2, 7, 26, 26], fill=1, outline=1, width=1) - mask[3, ...] = torch.from_numpy(np.asarray(pil_img)) + num_images = original_masks.shape[0] + + original_masks = original_masks.cpu().numpy() if original_masks.is_cuda else original_masks.numpy() + expected_boundaries = ( + expected_boundaries.cpu().numpy() if expected_boundaries.is_cuda else expected_boundaries.numpy() + ) + actual_boundaries = actual_boundaries.cpu().numpy() if actual_boundaries.is_cuda else actual_boundaries.numpy() + diff = diff.cpu().numpy() if diff.is_cuda else diff.numpy() + + # Plot and save each image separately + for i in range(num_images): + original = original_masks[i].squeeze() + expected = expected_boundaries[i].squeeze() + actual = actual_boundaries[i].squeeze() + difference = diff[i].squeeze() + + if visualize: + # Plotting + fig, axes = plt.subplots(1, 4, figsize=(20, 5)) + titles = ["Original Mask", "Expected Boundaries", "Actual Boundaries", "Absolute Difference"] + images = [original, expected, actual, difference] + + for ax, img, title in zip(axes, images, titles): + ax.imshow(img, cmap="gray", interpolation="nearest") + ax.axis("off") + ax.set_title(title) + + plt.subplots_adjust(top=0.85) + + # Save the figure + fig.tight_layout() + plt.savefig(f"{filepath_prefix}_image_{i}.png", bbox_inches="tight") + plt.close(fig) + + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + @pytest.mark.parametrize("kernel_size", [3, 5]) # Example kernel sizes + @pytest.mark.parametrize("canvas_size", [32, 64]) # Example canvas sizes + @pytest.mark.parametrize("batch_size", [1, 4]) # Parametrizing over batch sizes, e.g., 1 and 4 + def test_masks_to_boundaries(self, request, tmpdir, device, kernel_size, canvas_size, batch_size): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA is not available on this system.") + debug_mode = request.config.getoption("--debug-images") + # Create masks with the specified canvas size and batch size + mask = torch.zeros(batch_size, canvas_size, canvas_size, dtype=torch.bool) + + for b in range(batch_size): + if b % 4 == 0: + mask[b, 1:10, 1:10] = True + elif b % 4 == 1: + mask[b, 15:23, 15:23] = True + elif b % 4 == 2: + mask[b, 1:5, 22:30] = True + elif b % 4 == 3: + pil_img = Image.new("L", (canvas_size, canvas_size)) + draw = ImageDraw.Draw(pil_img) + draw.ellipse([2, 7, min(26, canvas_size - 6), min(26, canvas_size - 6)], fill=1, outline=1, width=1) + ellipse_mask = torch.from_numpy(np.array(pil_img, dtype=np.uint8)).bool() + mask[b, ...] = ellipse_mask mask = mask.to(device) - dilation_ratio = 0.02 - boundaries = ops.masks_to_boundaries(mask, dilation_ratio) - # Generate expected output - # TODO: How we generate handle the expected output? - # replace with actual code to generate expected output - expected_boundaries = torch.zeros_like(mask) - torch.testing.assert_close(expected_boundaries, boundaries) + actual_boundaries = ops.masks_to_boundaries(mask, kernel_size) + expected_boundaries = torch.zeros_like(mask) + struct = np.ones((kernel_size, kernel_size), dtype=np.uint8) + + # Calculate expected boundaries using scipy's binary_erosion + for i in range(batch_size): + single_mask = mask[i].cpu().numpy() + eroded_mask = scipy.ndimage.binary_erosion(single_mask, structure=struct, border_value=0) + single_expected_boundary = single_mask ^ eroded_mask + expected_boundaries[i] = torch.from_numpy(single_expected_boundary).to(device) + + if debug_mode: + diff = torch.abs(expected_boundaries.float() - actual_boundaries.float()) + filename_prefix = f"kernel_{kernel_size}_canvas_{canvas_size}_batch_{batch_size}" + output_file_path = tmpdir.join(f"{filename_prefix}.png") + # Log the path where the debug image will be saved + logging.info(f"Debug image saved at: {output_file_path}") + + self.save_and_images(mask, expected_boundaries, actual_boundaries, diff, str(output_file_path)) + + torch.testing.assert_close(actual_boundaries, expected_boundaries) class TestNMS: diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 50af34cb2a4..e1f9c630939 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -1,6 +1,7 @@ from typing import Tuple import torch +import torch.nn.functional as F import torchvision from torch import Tensor from torchvision.extension import _assert_has_ops @@ -379,7 +380,6 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Tensor, Tensor]: - iou = box_iou(boxes1, boxes2) lti = torch.min(boxes1[:, None, :2], boxes2[:, :2]) rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) @@ -398,40 +398,46 @@ def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Te # distance between boxes' centers squared. return iou - (centers_distance_squared / diagonal_distance_squared), iou -def masks_to_boundaries(masks: torch.Tensor, dilation_ratio: float = 0.02) -> torch.Tensor: + +def masks_to_boundaries(masks: torch.Tensor, kernel_size: int) -> torch.Tensor: """ - Compute the boundaries around the provided masks using morphological operations. + Compute the boundaries around the provided binary masks using morphological operations with a custom structuring element. + Enforces the use of an odd-sized kernel for the structuring element. - Returns a tensor of the same shape as the input masks containing the boundaries of each mask. - - Args: - masks (Tensor[N, H, W]): masks to transform where N is the number of masks - and (H, W) are the spatial dimensions. - dilation_ratio (float, optional): ratio used for the dilation operation. Default: 0.02 + Parameters: + - masks: Input binary masks tensor of shape [N, H, W]. + - kernel_size: Size of the kernel for the structuring element, must be odd. Returns: - Tensor[N, H, W]: boundaries + - Tensor representing the boundaries of the masks with shape [N, H, W]. """ - # If no masks are provided, return an empty tensor if masks.numel() == 0: return torch.zeros_like(masks) - n, h, w = masks.shape - img_diag = math.sqrt(h ** 2 + w ** 2) - dilation = int(round(dilation_ratio * img_diag)) - selem_size = dilation * 2 + 1 - selem = torch.ones((n, 1, selem_size, selem_size), device=masks.device) + # Ensure kernel_size is odd + if kernel_size % 2 == 0: + raise ValueError("kernel_size must be odd.") + + # Define the structuring element based on kernel_size + selem = torch.ones((1, 1, kernel_size, kernel_size), dtype=torch.float32, device=masks.device) - # Compute the boundaries for each mask - masks = masks.float().unsqueeze(1) - eroded_masks = F.conv2d(masks, selem, padding=dilation) - # Make the output binary - eroded_masks = (eroded_masks == selem.view(n, -1).sum(-1).view(n, 1, 1, 1)).byte() + masks_float = masks.float().unsqueeze(1) + + # Apply convolution with the structuring element + padding = (kernel_size - 1) // 2 + eroded_masks = F.conv2d(masks_float, selem, padding=padding, stride=1) + eroded_masks = eroded_masks.squeeze(1) # Remove channel dimension after convolution + + # Thresholding: a pixel in the eroded mask should be set if the convolution result + # is equal to the sum of the structuring element (i.e., all ones in the kernel) + threshold = torch.sum(selem).item() + eroded_masks = (eroded_masks == threshold).float() + + contours = torch.logical_xor(masks, eroded_masks.bool()) + + return contours - contours = masks.byte() - eroded_masks - return contours.squeeze(1) - def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: """ Compute the bounding boxes around the provided masks. @@ -464,4 +470,3 @@ def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: bounding_boxes[index, 3] = torch.max(y) return bounding_boxes -