From 83a4069d1134d2d0e9d8af847651be7943c323e5 Mon Sep 17 00:00:00 2001
From: abhijit_linux
Date: Sat, 3 Sep 2022 02:05:51 +0530
Subject: [PATCH 1/8] initial commit

---
 references/detection/transforms.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/references/detection/transforms.py b/references/detection/transforms.py
index d26bf6eac85..36b5267d182 100644
--- a/references/detection/transforms.py
+++ b/references/detection/transforms.py
@@ -591,3 +591,7 @@ def forward(
     def __repr__(self) -> str:
         s = f"{self.__class__.__name__}(blending={self.blending}, resize_interpolation={self.resize_interpolation})"
         return s
+
+
+class Mosaic:
+    pass

From 8b630f0751cbbb20cfa87482ba3a43ed255206f0 Mon Sep 17 00:00:00 2001
From: abhijit_linux
Date: Sun, 18 Sep 2022 22:44:39 +0530
Subject: [PATCH 2/8] update: basic Mosaic Implemented

---
 references/detection/transforms.py | 77 +++++++++++++++++++++++++++++-
 1 file changed, 75 insertions(+), 2 deletions(-)

diff --git a/references/detection/transforms.py b/references/detection/transforms.py
index 36b5267d182..84119f7ebde 100644
--- a/references/detection/transforms.py
+++ b/references/detection/transforms.py
@@ -593,5 +593,78 @@ def __repr__(self) -> str:
         return s
 
 
-class Mosaic:
-    pass
+class Mosaic(nn.Module):
+    def __init__(self, min_frac: float = 0.25, max_frac: float = 0.75, size_limit=10) -> None:
+        super().__init__()
+        self.min_frac = min_frac
+        self.max_frac = max_frac
+        self.size_limit = size_limit
+
+    def forward(self, images, targets):
+        """
+        images : torch.Tensor of NWHC type
+        targets: list[torch.Tensor]; bounding boxes in xyxy format.
+        """
+
+        # implementation is heavily inspired from this colab notebook
+        sx, sy = images.shape[-2], images.shape[-3]
+
+        num_channels = images.shape[1]
+
+        xc = torch.random.randint(sx * self.min_frac, sx * self.max_frac)
+        yc = torch.random.randint(sy * self.min_frac, sy * self.min_frac)
+
+        mosaic_image = torch.zeros((sy, sx, num_channels), dtype=torch.float32)
+
+        x1a1, y1a1, x2a1, y2a1 = 0, 0, xc, yc
+        x1b1, y1b1, x2b1, y2b1 = sx - xc, sy - yc, sx, sy
+
+        x1a2, y1a2, x2a2, y2a2 = xc, 0, sx, yc
+        x1b2, y1b2, x2b2, y2b2 = 0, sy - yc, sx - xc, sy
+
+        x1a3, y1a3, x2a3, y2a3 = 0, yc, xc, sy
+        x1b3, y1b3, x2b3, y2b3 = sx - xc, 0, sx, sy - yc
+
+        x1a4, y1a4, x2a4, y2a4 = xc, yc, sx, sy
+        x1b4, y1b4, x2b4, y2b4 = 0, 0, sx - xc, sy - yc
+
+        # calculate and apply box offsets due to replacement
+        offset_x1 = x1a1 - x1b1
+        offset_y1 = y1a1 - y1b1
+
+        offset_y2 = y1a2 - y1b2
+        offset_x2 = x1a2 - x1b2
+
+        offset_y3 = y1a3 - y1b3
+        offset_x3 = x1a3 - x1b3
+
+        offset_y4 = y1a4 - y1b4
+        offset_x4 = x1a4 - x1b4
+
+        targets[0][:, 0::2] += offset_x1
+        targets[0][:, 1::2] += offset_y1
+
+        targets[1][:, 0::2] += offset_x2
+        targets[1][:, 1::2] += offset_y2
+
+        targets[2][:, 0::2] += offset_x3
+        targets[2][:, 1::2] += offset_y3
+        targets[2][:, 0::2] += offset_x4
+        targets[2][:, 1::2] += offset_y4
+
+        mosaic_image[y1a1:y2a1, x1a1:x2a1] = images[0][y1b1:y2b1, x1b1:x2b1]
+        mosaic_image[y1a2:y2a2, x1a2:x2a2] = images[1][y1b2:y2b2, x1b2:x2b2]
+        mosaic_image[y1a3:y2a3, x1a3:x2a3] = images[2][y1b3:y2b3, x1b3:x2b3]
+        mosaic_image[y1a4:y2a4, x1a4:x2a4] = images[3][y1b4:y2b4, x1b4:x2b4]
+
+        mosaic_target = torch.vstack(targets)
+
+        mosaic_target[:, 0::2] = torch.clip(mosaic_target[:, 0::2], 0, sx)
+        mosaic_target[:, 1::2] = torch.clip(mosaic_target[:, 1::2], 0, sy)
+
+        w = mosaic_target[:, 2] - mosaic_target[:, 0]
+        h = mosaic_target[:, 3] - mosaic_target[:, 1]
+
+        final_boxes = final_boxes[(w >= self.size_limit) & (h >= self.size_limit)]
+
+        return mosaic_image, final_boxes

From 2b7ed01d11cd2dae76ab39fd6a0ec88468b4a316 Mon Sep 17 00:00:00 2001
From: abhijit_linux
Date: Sun, 18 Sep 2022 22:55:22 +0530
Subject: [PATCH 3/8] typo

---
 references/detection/transforms.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/references/detection/transforms.py b/references/detection/transforms.py
index 84119f7ebde..398155e1b13 100644
--- a/references/detection/transforms.py
+++ b/references/detection/transforms.py
@@ -607,14 +607,14 @@ def forward(self, images, targets):
         """
 
         # implementation is heavily inspired from this colab notebook
-        sx, sy = images.shape[-2], images.shape[-3]
+        sx, sy = images.shape[-3], images.shape[-2]
 
         num_channels = images.shape[1]
 
         xc = torch.random.randint(sx * self.min_frac, sx * self.max_frac)
         yc = torch.random.randint(sy * self.min_frac, sy * self.min_frac)
 
-        mosaic_image = torch.zeros((sy, sx, num_channels), dtype=torch.float32)
+        mosaic_image = torch.zeros((sx, sy, num_channels), dtype=torch.float32)
 
         x1a1, y1a1, x2a1, y2a1 = 0, 0, xc, yc
         x1b1, y1b1, x2b1, y2b1 = sx - xc, sy - yc, sx, sy

From ced3ac456bb516fba2320028026ff82cf265fa21 Mon Sep 17 00:00:00 2001
From: abhijit_linux
Date: Wed, 21 Sep 2022 00:10:41 +0530
Subject: [PATCH 4/8] channel last -> channel first

---
 references/detection/transforms.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/references/detection/transforms.py b/references/detection/transforms.py
index 398155e1b13..27d6fd7b0bf 100644
--- a/references/detection/transforms.py
+++ b/references/detection/transforms.py
@@ -600,21 +600,21 @@ def __init__(self, min_frac: float = 0.25, max_frac: float = 0.75, size_limit=10
         self.max_frac = max_frac
         self.size_limit = size_limit
 
-    def forward(self, images, targets):
+    def forward(self, images, targets=None):
         """
         images : torch.Tensor of NWHC type
         targets: list[torch.Tensor]; bounding boxes in xyxy format.
""" # implementation is heavily inspired from this colab notebook - sx, sy = images.shape[-3], images.shape[-2] + sx, sy = images.shape[-2:] - num_channels = images.shape[1] + num_channels = images.shape[-3] - xc = torch.random.randint(sx * self.min_frac, sx * self.max_frac) - yc = torch.random.randint(sy * self.min_frac, sy * self.min_frac) + xc = torch.randint(int(sx * self.min_frac), int(sx * self.max_frac), size=(1,)) + yc = torch.randint(int(sy * self.min_frac), int(sy * self.max_frac), size=(1,)) - mosaic_image = torch.zeros((sx, sy, num_channels), dtype=torch.float32) + mosaic_image = torch.zeros((num_channels, sx, sy), dtype=torch.float32) x1a1, y1a1, x2a1, y2a1 = 0, 0, xc, yc x1b1, y1b1, x2b1, y2b1 = sx - xc, sy - yc, sx, sy @@ -652,10 +652,10 @@ def forward(self, images, targets): targets[2][:, 0::2] += offset_x4 targets[2][:, 1::2] += offset_y4 - mosaic_image[y1a1:y2a1, x1a1:x2a1] = images[0][y1b1:y2b1, x1b1:x2b1] - mosaic_image[y1a2:y2a2, x1a2:x2a2] = images[1][y1b2:y2b2, x1b2:x2b2] - mosaic_image[y1a3:y2a3, x1a3:x2a3] = images[2][y1b3:y2b3, x1b3:x2b3] - mosaic_image[y1a4:y2a4, x1a4:x2a4] = images[3][y1b4:y2b4, x1b4:x2b4] + mosaic_image[:, x1a1:x2a1, y1a1:y2a1] = images[0][:, x1b1:x2b1, y1b1:y2b1] + mosaic_image[:, x1a2:x2a2, y1a2:y2a2] = images[1][:, x1b2:x2b2, y1b2:y2b2] + mosaic_image[:, x1a3:x2a3, y1a3:y2a3] = images[2][:, x1b3:x2b3, y1b3:y2b3] + mosaic_image[:, x1a4:x2a4, y1a4:y2a4] = images[3][:, x1b4:x2b4, y1b4:y2b4] mosaic_target = torch.vstack(targets) @@ -665,6 +665,6 @@ def forward(self, images, targets): w = mosaic_target[:, 2] - mosaic_target[:, 0] h = mosaic_target[:, 3] - mosaic_target[:, 1] - final_boxes = final_boxes[(w >= self.size_limit) & (h >= self.size_limit)] + mosaic_target = mosaic_target[(w >= self.size_limit) & (h >= self.size_limit)] - return mosaic_image, final_boxes + return mosaic_image, mosaic_target From eabac3ce18112d28388f7523c5916234da93db5e Mon Sep 17 00:00:00 2001 From: abhijit_linux Date: Wed, 21 Sep 2022 00:20:20 +0530 Subject: [PATCH 5/8] [skip-ci] --- references/detection/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 27d6fd7b0bf..4096758bd58 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -600,7 +600,7 @@ def __init__(self, min_frac: float = 0.25, max_frac: float = 0.75, size_limit=10 self.max_frac = max_frac self.size_limit = size_limit - def forward(self, images, targets=None): + def forward(self, images, targets): """ images : torch.Tensor of NWHC type targets: list[torch.Tensor]; bounding boxes in xyxy format. 
From c36ab7f11e77c583ed3446f3967c86dd067ad09c Mon Sep 17 00:00:00 2001
From: abhijit_linux
Date: Wed, 21 Sep 2022 20:12:14 +0530
Subject: [PATCH 6/8] SKIP CI ;fix typos typos typos

---
 references/detection/transforms.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/references/detection/transforms.py b/references/detection/transforms.py
index 4096758bd58..ea4e9bafa93 100644
--- a/references/detection/transforms.py
+++ b/references/detection/transforms.py
@@ -649,8 +649,8 @@ def forward(self, images, targets):
 
         targets[2][:, 0::2] += offset_x3
         targets[2][:, 1::2] += offset_y3
-        targets[2][:, 0::2] += offset_x4
-        targets[2][:, 1::2] += offset_y4
+        targets[3][:, 0::2] += offset_x4
+        targets[3][:, 1::2] += offset_y4
 
         mosaic_image[:, x1a1:x2a1, y1a1:y2a1] = images[0][:, x1b1:x2b1, y1b1:y2b1]
         mosaic_image[:, x1a2:x2a2, y1a2:y2a2] = images[1][:, x1b2:x2b2, y1b2:y2b2]

From 6effa13a2ea378783777764210fa289db93a9b63 Mon Sep 17 00:00:00 2001
From: abhijit_linux
Date: Mon, 26 Dec 2022 16:00:39 +0530
Subject: [PATCH 7/8] cleanup and some progress after long time;

---
 references/detection/transforms.py | 92 ++++++++++++++++--------------
 1 file changed, 50 insertions(+), 42 deletions(-)

diff --git a/references/detection/transforms.py b/references/detection/transforms.py
index ea4e9bafa93..ef229c79fd6 100644
--- a/references/detection/transforms.py
+++ b/references/detection/transforms.py
@@ -594,77 +594,85 @@ def __repr__(self) -> str:
 
 
 class Mosaic(nn.Module):
+    """
+    Mosaic Transform
+    """
+
     def __init__(self, min_frac: float = 0.25, max_frac: float = 0.75, size_limit=10) -> None:
         super().__init__()
         self.min_frac = min_frac
         self.max_frac = max_frac
         self.size_limit = size_limit
 
-    def forward(self, images, targets):
+    def forward(self, images, boxes, labels):
         """
-        images : torch.Tensor of NWHC type
-        targets: list[torch.Tensor]; bounding boxes in xyxy format.
+        images : torch.Tensor channels first image tensor;
+        boxes: bounding boxes in xyxy format.
+        labels: labels corresponding to the bounding boxes.
+
         """
 
-        # implementation is heavily inspired from this colab notebook
-        sx, sy = images.shape[-2:]
+        # implementation is heavily inspired from this colab notebook : https://colab.research.google.com/drive/1YWb7a_3bHqG30SoIxU5S4lHKktkRseyY?usp=sharing#scrollTo=yHsYf3Z3bujO
+        sy, sx = images.shape[-2:]
 
         num_channels = images.shape[-3]
 
         xc = torch.randint(int(sx * self.min_frac), int(sx * self.max_frac), size=(1,))
         yc = torch.randint(int(sy * self.min_frac), int(sy * self.max_frac), size=(1,))
 
-        mosaic_image = torch.zeros((num_channels, sx, sy), dtype=torch.float32)
+        mosaic_image = torch.zeros((num_channels, sy, sx), dtype=images.dtype)
 
-        x1a1, y1a1, x2a1, y2a1 = 0, 0, xc, yc
-        x1b1, y1b1, x2b1, y2b1 = sx - xc, sy - yc, sx, sy
+        x0a0, y0a0, x1a0, y1a0 = 0, 0, xc, yc
+        x0b0, y0b0, x1b0, y1b0 = sx - xc, sy - yc, sx, sy
 
-        x1a2, y1a2, x2a2, y2a2 = xc, 0, sx, yc
-        x1b2, y1b2, x2b2, y2b2 = 0, sy - yc, sx - xc, sy
+        x0a1, y0a1, x1a1, y1a1 = 0, yc, xc, sy
+        x0b1, y0b1, x1b1, y1b1 = sx - xc, 0, sx, sy - yc
 
-        x1a3, y1a3, x2a3, y2a3 = 0, yc, xc, sy
-        x1b3, y1b3, x2b3, y2b3 = sx - xc, 0, sx, sy - yc
+        x0a2, y0a2, x1a2, y1a2 = xc, 0, sx, yc
+        x0b2, y0b2, x1b2, y1b2 = 0, sy - yc, sx - xc, sy
 
-        x1a4, y1a4, x2a4, y2a4 = xc, yc, sx, sy
-        x1b4, y1b4, x2b4, y2b4 = 0, 0, sx - xc, sy - yc
+        x0a3, y0a3, x1a3, y1a3 = xc, yc, sx, sy
+        x0b3, y0b3, x1b3, y1b3 = 0, 0, sx - xc, sy - yc
 
-        # calculate and apply box offsets due to replacement
-        offset_x1 = x1a1 - x1b1
-        offset_y1 = y1a1 - y1b1
+        mosaic_image[..., y0a0:y1a0, x0a0:x1a0] = images[0][..., y0b0:y1b0, x0b0:x1b0]
+        mosaic_image[..., y0a1:y1a1, x0a1:x1a1] = images[1][..., y0b1:y1b1, x0b1:x1b1]
+        mosaic_image[..., y0a2:y1a2, x0a2:x1a2] = images[2][..., y0b2:y1b2, x0b2:x1b2]
+        mosaic_image[..., y0a3:y1a3, x0a3:x1a3] = images[3][..., y0b3:y1b3, x0b3:x1b3]
 
-        offset_y2 = y1a2 - y1b2
-        offset_x2 = x1a2 - x1b2
+        offset_y0 = y0a0 - y0b0
+        offset_x0 = x0a0 - x0b0
 
-        offset_y3 = y1a3 - y1b3
-        offset_x3 = x1a3 - x1b3
+        offset_y1 = y0a1 - y0b1
+        offset_x1 = x0a1 - x0b1
 
-        offset_y4 = y1a4 - y1b4
-        offset_x4 = x1a4 - x1b4
+        offset_y2 = y0a2 - y0b2
+        offset_x2 = x0a2 - x0b2
 
-        targets[0][:, 0::2] += offset_x1
-        targets[0][:, 1::2] += offset_y1
+        offset_y3 = y0a3 - y0b3
+        offset_x3 = x0a3 - x0b3
 
-        targets[1][:, 0::2] += offset_x2
-        targets[1][:, 1::2] += offset_y2
+        boxes[0][..., 0:4:2] += offset_x0
+        boxes[0][..., 1:4:2] += offset_y0
 
-        targets[2][:, 0::2] += offset_x3
-        targets[2][:, 1::2] += offset_y3
-        targets[3][:, 0::2] += offset_x4
-        targets[3][:, 1::2] += offset_y4
+        boxes[1][..., 0:4:2] += offset_x1
+        boxes[1][..., 1:4:2] += offset_y1
 
-        mosaic_image[:, x1a1:x2a1, y1a1:y2a1] = images[0][:, x1b1:x2b1, y1b1:y2b1]
-        mosaic_image[:, x1a2:x2a2, y1a2:y2a2] = images[1][:, x1b2:x2b2, y1b2:y2b2]
-        mosaic_image[:, x1a3:x2a3, y1a3:y2a3] = images[2][:, x1b3:x2b3, y1b3:y2b3]
-        mosaic_image[:, x1a4:x2a4, y1a4:y2a4] = images[3][:, x1b4:x2b4, y1b4:y2b4]
+        boxes[2][..., 0:4:2] += offset_x2
+        boxes[2][..., 1:4:2] += offset_y2
 
-        mosaic_target = torch.vstack(targets)
+        boxes[3][..., 0:4:2] += offset_x3
+        boxes[3][..., 1:4:2] += offset_y3
 
-        mosaic_target[:, 0::2] = torch.clip(mosaic_target[:, 0::2], 0, sx)
-        mosaic_target[:, 1::2] = torch.clip(mosaic_target[:, 1::2], 0, sy)
+        mosaic_boxes = torch.vstack(boxes)
+        mosaic_labels = torch.vstack(labels)
 
-        w = mosaic_target[:, 2] - mosaic_target[:, 0]
-        h = mosaic_target[:, 3] - mosaic_target[:, 1]
+        mosaic_boxes[..., 0::2] = torch.clip(mosaic_boxes[..., 0::2], 0, sx)
+        mosaic_boxes[..., 1::2] = torch.clip(mosaic_boxes[..., 1::2], 0, sy)
 
-        mosaic_target = mosaic_target[(w >= self.size_limit) & (h >= self.size_limit)]
+        w = mosaic_boxes[..., 2] - mosaic_boxes[..., 0]
+        h = mosaic_boxes[..., 3] - mosaic_boxes[..., 1]
 
-        return mosaic_image, mosaic_target
+        mask = (w >= self.size_limit) & (h >= self.size_limit)
+        mosaic_boxes = mosaic_boxes[mask]
+        mosaic_labels = mosaic_labels[mask]
+        return mosaic_image, mosaic_boxes, mosaic_labels

From f0c2b42b80f7061a93ddf980519588c183da1108 Mon Sep 17 00:00:00 2001
From: abhijit_linux
Date: Tue, 11 Oct 2022 21:31:54 +0530
Subject: [PATCH 8/8] add batch support

---
 references/detection/transforms.py | 76 ++++++++++++++++++------------
 1 file changed, 46 insertions(+), 30 deletions(-)

diff --git a/references/detection/transforms.py b/references/detection/transforms.py
index ef229c79fd6..a80941b91a4 100644
--- a/references/detection/transforms.py
+++ b/references/detection/transforms.py
@@ -598,30 +598,35 @@ class Mosaic(nn.Module):
     Mosaic Transform
     """
 
-    def __init__(self, min_frac: float = 0.25, max_frac: float = 0.75, size_limit=10) -> None:
+    def __init__(self, min_frac: float = 0.25, max_frac: float = 0.75, size_limit: int = 10) -> None:
         super().__init__()
         self.min_frac = min_frac
         self.max_frac = max_frac
         self.size_limit = size_limit
 
-    def forward(self, images, boxes, labels):
+    def forward(
+        self, images: torch.Tensor, boxes: List[List[List[torch.Tensor]]], labels: List[List[List[torch.Tensor]]]
+    ):
         """
-        images : torch.Tensor channels first image tensor;
-        boxes: bounding boxes in xyxy format.
-        labels: labels corresponding to the bounding boxes.
-
+        images (torch.Tensor) : Image Tensor of B*4*C*H*W
+        boxes (List[List[List[torch.Tensor]]]) : Bounding boxes in xyxy format.
+        labels (List[List[torch.Tensor]]) : labels corresponding to the bounding boxes.
         """
 
         # implementation is heavily inspired from this colab notebook : https://colab.research.google.com/drive/1YWb7a_3bHqG30SoIxU5S4lHKktkRseyY?usp=sharing#scrollTo=yHsYf3Z3bujO
-        sy, sx = images.shape[-2:]
+        if images.ndim != 5:
+            raise ValueError(f"Image tensor should be 5 dimensional, got {images.ndim}")
+        if images.shape[1] != 4:
+            raise ValueError(f"First dimension of image tensor should be of size 4, got {images.shape[1]}")
+
+        sy, sx = images.shape[-2:]
+        B = images.shape[0]
 
         num_channels = images.shape[-3]
 
         xc = torch.randint(int(sx * self.min_frac), int(sx * self.max_frac), size=(1,))
         yc = torch.randint(int(sy * self.min_frac), int(sy * self.max_frac), size=(1,))
 
-        mosaic_image = torch.zeros((num_channels, sy, sx), dtype=images.dtype)
-
         x0a0, y0a0, x1a0, y1a0 = 0, 0, xc, yc
         x0b0, y0b0, x1b0, y1b0 = sx - xc, sy - yc, sx, sy
 
@@ -634,11 +639,14 @@ def forward(self, images, boxes, labels):
         x0a3, y0a3, x1a3, y1a3 = xc, yc, sx, sy
         x0b3, y0b3, x1b3, y1b3 = 0, 0, sx - xc, sy - yc
 
-        mosaic_image[..., y0a0:y1a0, x0a0:x1a0] = images[0][..., y0b0:y1b0, x0b0:x1b0]
-        mosaic_image[..., y0a1:y1a1, x0a1:x1a1] = images[1][..., y0b1:y1b1, x0b1:x1b1]
-        mosaic_image[..., y0a2:y1a2, x0a2:x1a2] = images[2][..., y0b2:y1b2, x0b2:x1b2]
-        mosaic_image[..., y0a3:y1a3, x0a3:x1a3] = images[3][..., y0b3:y1b3, x0b3:x1b3]
+        mosaic_image = torch.zeros((B, num_channels, sy, sx), dtype=images.dtype)
+
+        mosaic_image[..., y0a0:y1a0, x0a0:x1a0] = images[:, 0][..., y0b0:y1b0, x0b0:x1b0]
+        mosaic_image[..., y0a1:y1a1, x0a1:x1a1] = images[:, 1][..., y0b1:y1b1, x0b1:x1b1]
+        mosaic_image[..., y0a2:y1a2, x0a2:x1a2] = images[:, 2][..., y0b2:y1b2, x0b2:x1b2]
+        mosaic_image[..., y0a3:y1a3, x0a3:x1a3] = images[:, 3][..., y0b3:y1b3, x0b3:x1b3]
 
+        # calculating offsets for the bounding boxes;
         offset_y0 = y0a0 - y0b0
         offset_x0 = x0a0 - x0b0
 
@@ -650,29 +658,37 @@ def forward(self, images, boxes, labels):
 
         offset_y3 = y0a3 - y0b3
         offset_x3 = x0a3 - x0b3
+        mosaic_boxes = []
+        mosaic_labels = []
+
+        for i in range(B):
+            boxes[i][0][:, 0:4:2] += offset_x0
+            boxes[i][0][:, 1:4:2] += offset_y0
+
+            boxes[i][1][:, 0:4:2] += offset_x1
+            boxes[i][1][:, 1:4:2] += offset_y1
+
+            boxes[i][2][:, 0:4:2] += offset_x2
+            boxes[i][2][:, 1:4:2] += offset_y2
 
-        boxes[0][..., 0:4:2] += offset_x0
-        boxes[0][..., 1:4:2] += offset_y0
+            boxes[i][3][:, 0:4:2] += offset_x3
+            boxes[i][3][:, 1:4:2] += offset_y3
 
-        boxes[1][..., 0:4:2] += offset_x1
-        boxes[1][..., 1:4:2] += offset_y1
+            temp_box = torch.vstack(boxes[i])
+            temp_label = torch.vstack(labels[i])
 
-        boxes[2][..., 0:4:2] += offset_x2
-        boxes[2][..., 1:4:2] += offset_y2
+            temp_box[..., 0::2] = torch.clip(temp_box[..., 0::2], 0, sx)
+            temp_box[..., 1::2] = torch.clip(temp_box[..., 1::2], 0, sy)
 
-        boxes[3][..., 0:4:2] += offset_x3
-        boxes[3][..., 1:4:2] += offset_y3
+            w_ = temp_box[..., 2] - temp_box[..., 0]
+            h_ = temp_box[..., 3] - temp_box[..., 1]
 
-        mosaic_boxes = torch.vstack(boxes)
-        mosaic_labels = torch.vstack(labels)
+            mask_ = (w_ > self.size_limit) & (h_ > self.size_limit)
 
-        mosaic_boxes[..., 0::2] = torch.clip(mosaic_boxes[..., 0::2], 0, sx)
-        mosaic_boxes[..., 1::2] = torch.clip(mosaic_boxes[..., 1::2], 0, sy)
+            temp_box = temp_box[mask_]
+            temp_label = temp_label[mask_]
 
-        w = mosaic_boxes[..., 2] - mosaic_boxes[..., 0]
-        h = mosaic_boxes[..., 3] - mosaic_boxes[..., 1]
+            mosaic_boxes.append(temp_box)
+            mosaic_labels.append(temp_label)
 
-        mask = (w >= self.size_limit) & (h >= self.size_limit)
-        mosaic_boxes = mosaic_boxes[mask]
-        mosaic_labels = mosaic_labels[mask]
         return mosaic_image, mosaic_boxes, mosaic_labels
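With the final patch applied, the transform works on batches. The following is a hypothetical usage sketch; the shapes and the nesting of boxes and labels are inferred from the indexing inside forward rather than from a documented API, and the sample values are made up. It assumes the Mosaic class from references/detection/transforms.py is in scope.

import torch

mosaic = Mosaic(min_frac=0.25, max_frac=0.75, size_limit=10)

B, C, H, W = 2, 3, 480, 640
images = torch.rand(B, 4, C, H, W)  # four source images per output mosaic
# boxes[i][j]: (n_ij, 4) xyxy float tensor for source image j of sample i
boxes = [[torch.tensor([[10.0, 20.0, 320.0, 400.0]]) for _ in range(4)] for _ in range(B)]
# labels[i][j]: (n_ij, 1) tensor, so torch.vstack and the row mask inside forward
# line up with the stacked boxes
labels = [[torch.tensor([[1]]) for _ in range(4)] for _ in range(B)]

mosaic_images, mosaic_boxes, mosaic_labels = mosaic(images, boxes, labels)
# mosaic_images: (B, C, H, W); mosaic_boxes[i] / mosaic_labels[i]: the boxes and labels
# for sample i that survive clipping and the size_limit filter

Note that forward adds the quadrant offsets to the passed box tensors in place, so callers that still need the original coordinates should hand in copies.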