Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add masks to boundaries #7704

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The below operators perform pre-processing as well as post-processing required i

batched_nms
masks_to_boxes
masks_to_boundaries
nms
roi_align
roi_pool
Expand Down
26 changes: 26 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,32 @@ def test_is_leaf_node(self, device):
assert len(graph_node_names[0]) == len(graph_node_names[1])
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs

class TestMasksToBoundaries(ABC):

@pytest.mark.parametrize("device", ['cpu', 'cuda'])
def test_masks_to_boundaries(self, device):
# Create masks
mask = torch.zeros(4, 32, 32, dtype=torch.bool)
mask[0, 1:10, 1:10] = True
mask[0, 12:20, 12:20] = True
mask[0, 15:18, 20:32] = True
mask[1, 15:23, 15:23] = True
mask[1, 22:33, 22:33] = True
mask[2, 1:5, 22:30] = True
mask[2, 5:14, 25:27] = True
pil_img = Image.new("L", (32, 32))
draw = ImageDraw.Draw(pil_img)
draw.ellipse([2, 7, 26, 26], fill=1, outline=1, width=1)
mask[3, ...] = torch.from_numpy(np.asarray(pil_img))
mask = mask.to(device)
dilation_ratio = 0.02
boundaries = ops.masks_to_boundaries(mask, dilation_ratio)
# Generate expected output
# TODO: How we generate handle the expected output?
bhack marked this conversation as resolved.
Show resolved Hide resolved
# replace with actual code to generate expected output
expected_boundaries = torch.zeros_like(mask)
torch.testing.assert_close(expected_boundaries, boundaries)


class TestNMS:
def _reference_nms(self, boxes, scores, iou_threshold):
Expand Down
2 changes: 2 additions & 0 deletions torchvision/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
distance_box_iou,
generalized_box_iou,
masks_to_boxes,
masks_to_boundaries,
nms,
remove_small_boxes,
)
Expand All @@ -32,6 +33,7 @@

__all__ = [
"masks_to_boxes",
"masks_to_boundaries",
"deform_conv2d",
"DeformConv2d",
"nms",
Expand Down
34 changes: 34 additions & 0 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,40 @@ def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Te
# distance between boxes' centers squared.
return iou - (centers_distance_squared / diagonal_distance_squared), iou

def masks_to_boundaries(masks: torch.Tensor, dilation_ratio: float = 0.02) -> torch.Tensor:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it's OK to have the implementation in this file even though this isn't related to boxed. However, I don't think we should expose it here. I think we should just expose it in from the torchvision.ops namespace (otherwise the implementation will always have to stay in this file for BC, and that may lock us).

We probably just need to rename this to _masks_to_boundaries and the expose it in torchvision.ops.__init__.py like

from .boxes import import _masks_to_boundaries as masks_to_boundaries

Any other suggestion @pmeier @vfdev-5 @oke-aditya ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it's OK to have the implementation in this file even though this isn't related to boxed.

No strong opinion, but could we maybe also have a new _masks.py module or move it into the misc.py one?

👍 for only exposing it in the torchvision.ops namespace.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tbh there is demand for mask_utils. Several of them, #4415 . Candidate utils like convert_masks_format, paste_masks_in_images, etc. Maybe it's time to create new files mask_utils.py and make future extensions possible?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can always create an ops.mask* namespace at any time. We should only do that when we know for sure we need it, i.e. when we start having 2+ mask utils. Alls ops are exposed in the ops. namespace anyway so there's no need to rush and create a file which will only have one single util in it ATM.

I'm OK with creating _mask.py as well (and we can rename it into mask.py later if we want to).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm OK with creating _mask.py as well (and we can rename it into mask.py later if we want to).

This sounds best solution! We can avoid the bloat inside this file as well as keep them private 😄

"""
Compute the boundaries around the provided masks using morphological operations.

Returns a tensor of the same shape as the input masks containing the boundaries of each mask.

Args:
masks (Tensor[N, H, W]): masks to transform where N is the number of masks
and (H, W) are the spatial dimensions.
dilation_ratio (float, optional): ratio used for the dilation operation. Default: 0.02

Returns:
Tensor[N, H, W]: boundaries
bhack marked this conversation as resolved.
Show resolved Hide resolved
"""
# If no masks are provided, return an empty tensor
if masks.numel() == 0:
return torch.zeros_like(masks)

n, h, w = masks.shape
img_diag = math.sqrt(h ** 2 + w ** 2)
dilation = int(round(dilation_ratio * img_diag))
selem_size = dilation * 2 + 1
bhack marked this conversation as resolved.
Show resolved Hide resolved
bhack marked this conversation as resolved.
Show resolved Hide resolved
selem = torch.ones((n, 1, selem_size, selem_size), device=masks.device)

# Compute the boundaries for each mask
masks = masks.float().unsqueeze(1)
eroded_masks = F.conv2d(masks, selem, padding=dilation)
# Make the output binary
eroded_masks = (eroded_masks == selem.view(n, -1).sum(-1).view(n, 1, 1, 1)).byte()

contours = masks.byte() - eroded_masks

return contours.squeeze(1)

def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
"""
Compute the bounding boxes around the provided masks.
Expand Down Expand Up @@ -431,3 +464,4 @@ def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
bounding_boxes[index, 3] = torch.max(y)

return bounding_boxes