Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mosaic Transform #6534

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions references/detection/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,3 +591,104 @@ def forward(
def __repr__(self) -> str:
s = f"{self.__class__.__name__}(blending={self.blending}, resize_interpolation={self.resize_interpolation})"
return s


class Mosaic(nn.Module):
    """
    Mosaic Transform.

    Stitches four images into a single image of the same spatial size. A random
    centre point ``(xc, yc)`` splits the output canvas into four rectangular
    tiles; each tile is filled with a crop taken from one of the four input
    images, and the bounding boxes of every input image are shifted by the same
    amount as its crop, clipped to the canvas, and filtered by size.

    Args:
        min_frac (float): lower bound of the centre point, as a fraction of the
            image width/height.
        max_frac (float): upper bound (exclusive, after integer truncation) of
            the centre point, as a fraction of the image width/height.
        size_limit (int): boxes whose clipped width or height is not strictly
            greater than this value are dropped.

    Raises:
        ValueError: if the fractions do not satisfy
            ``0 <= min_frac < max_frac <= 1``.
    """

    def __init__(self, min_frac: float = 0.25, max_frac: float = 0.75, size_limit: int = 10) -> None:
        super().__init__()
        # Both fractions must lie in [0, 1] so the centre stays inside the image,
        # and min_frac must be strictly smaller because torch.randint requires
        # low < high.
        if not 0.0 <= min_frac < max_frac <= 1.0:
            raise ValueError(
                f"min_frac and max_frac must satisfy 0 <= min_frac < max_frac <= 1, got {min_frac} and {max_frac}"
            )
        self.min_frac = min_frac
        self.max_frac = max_frac
        self.size_limit = size_limit

    def forward(self, images: torch.Tensor, boxes: List[List[torch.Tensor]], labels: List[List[torch.Tensor]]):
        """
        Build one mosaic per batch element.

        Args:
            images (torch.Tensor): Image Tensor of B*4*C*H*W.
            boxes (List[List[torch.Tensor]]): for every batch element, four
                tensors of bounding boxes in xyxy format (one per input image).
            labels (List[List[torch.Tensor]]): labels corresponding to the
                bounding boxes, laid out like ``boxes``.

        Returns:
            A tuple ``(mosaic_image, mosaic_boxes, mosaic_labels)`` where
            ``mosaic_image`` is a B*C*H*W tensor and the other two are lists
            with one tensor per batch element holding the surviving boxes and
            their labels.

        Raises:
            ValueError: if ``images`` is not 5 dimensional, if its second axis
                is not of size 4, or if a batch element does not provide
                exactly four box tensors and four label tensors.
        """

        # implementation is heavily inspired from this colab notebook : https://colab.research.google.com/drive/1YWb7a_3bHqG30SoIxU5S4lHKktkRseyY?usp=sharing#scrollTo=yHsYf3Z3bujO

        if images.ndim != 5:
            raise ValueError(f"Image tensor should be 5 dimensional, got {images.ndim}")
        if images.shape[1] != 4:
            raise ValueError(f"First dimension of image tensor should be of size 4, got {images.shape[1]}")

        height, width = images.shape[-2:]
        batch_size = images.shape[0]
        num_channels = images.shape[-3]

        # Random mosaic centre, converted to plain ints so that slicing and the
        # box offsets below do not depend on 1-element tensors.
        xc = int(torch.randint(int(width * self.min_frac), int(width * self.max_frac), size=(1,)))
        yc = int(torch.randint(int(height * self.min_frac), int(height * self.max_frac), size=(1,)))

        # One entry per input image: (destination on the canvas, source region
        # in the input image), each given as (x0, y0, x1, y1). Every source crop
        # is anchored at the image corner that ends up touching the centre.
        tiles = [
            ((0, 0, xc, yc), (width - xc, height - yc, width, height)),  # top-left
            ((0, yc, xc, height), (width - xc, 0, width, height - yc)),  # bottom-left
            ((xc, 0, width, yc), (0, height - yc, width - xc, height)),  # top-right
            ((xc, yc, width, height), (0, 0, width - xc, height - yc)),  # bottom-right
        ]

        # new_zeros keeps both the dtype and the device of the input images.
        mosaic_image = images.new_zeros((batch_size, num_channels, height, width))
        for tile_idx, (dst, src) in enumerate(tiles):
            mosaic_image[..., dst[1] : dst[3], dst[0] : dst[2]] = images[:, tile_idx][
                ..., src[1] : src[3], src[0] : src[2]
            ]

        # Translation that maps coordinates of each input image onto the canvas.
        offsets = [(dst[0] - src[0], dst[1] - src[1]) for dst, src in tiles]

        mosaic_boxes = []
        mosaic_labels = []

        for i in range(batch_size):
            sample_boxes = boxes[i]
            sample_labels = labels[i]
            if len(sample_boxes) != 4 or len(sample_labels) != 4:
                raise ValueError(
                    "Each batch element should provide 4 box tensors and 4 label tensors, "
                    f"got {len(sample_boxes)} and {len(sample_labels)}"
                )

            shifted = []
            for tile_boxes, (offset_x, offset_y) in zip(sample_boxes, offsets):
                # Work on a copy: the caller's tensors must not be modified.
                moved = tile_boxes.clone()
                moved[:, 0:4:2] += offset_x
                moved[:, 1:4:2] += offset_y
                shifted.append(moved)

            # torch.cat (rather than vstack) also accepts 1-D label tensors of
            # different lengths.
            merged_boxes = torch.cat(shifted)
            merged_labels = torch.cat(list(sample_labels))

            # Parts of a box that fall outside the canvas are cut away.
            merged_boxes[..., 0::2] = torch.clamp(merged_boxes[..., 0::2], 0, width)
            merged_boxes[..., 1::2] = torch.clamp(merged_boxes[..., 1::2], 0, height)

            box_w = merged_boxes[..., 2] - merged_boxes[..., 0]
            box_h = merged_boxes[..., 3] - merged_boxes[..., 1]
            keep = (box_w > self.size_limit) & (box_h > self.size_limit)

            mosaic_boxes.append(merged_boxes[keep])
            mosaic_labels.append(merged_labels[keep])

        return mosaic_image, mosaic_boxes, mosaic_labels