import os
import sys

# Make the repository root importable when the script is run from its own directory.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))

import argparse

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.cuda.amp import autocast
from PIL import Image, ImageDraw, ImageFont

import src.misc.dist as dist
from src.core import YAMLConfig
from src.solver import TASKS

def postprocess(labels, boxes, scores, iou_threshold=0.55):
    """Greedily merge same-class boxes whose IoU exceeds `iou_threshold`.

    Returns single-element lists of numpy arrays to mirror the batched
    output format of the detector's postprocessor.
    """
    def calculate_iou(box1, box2):
        # Boxes are (x1, y1, x2, y2) in absolute pixel coordinates.
        x1, y1, x2, y2 = box1
        x3, y3, x4, y4 = box2
        xi1 = max(x1, x3)
        yi1 = max(y1, y3)
        xi2 = min(x2, x4)
        yi2 = min(y2, y4)
        inter_width = max(0, xi2 - xi1)
        inter_height = max(0, yi2 - yi1)
        inter_area = inter_width * inter_height
        box1_area = (x2 - x1) * (y2 - y1)
        box2_area = (x4 - x3) * (y4 - y3)
        union_area = box1_area + box2_area - inter_area
        return inter_area / union_area if union_area != 0 else 0

    merged_labels = []
    merged_boxes = []
    merged_scores = []
    used_indices = set()
    for i in range(len(boxes)):
        if i in used_indices:
            continue
        current_box = boxes[i]
        current_label = labels[i]
        current_score = scores[i]
        boxes_to_merge = [current_box]
        scores_to_merge = [current_score]
        used_indices.add(i)
        # Collect every later box of the same class that overlaps the current one enough.
        for j in range(i + 1, len(boxes)):
            if j in used_indices:
                continue
            if labels[j] != current_label:
                continue
            if calculate_iou(current_box, boxes[j]) >= iou_threshold:
                boxes_to_merge.append(boxes[j])
                scores_to_merge.append(scores[j])
                used_indices.add(j)
        # The merged box is the tight enclosure of all collected boxes; keep the best score.
        xs = np.concatenate([[box[0], box[2]] for box in boxes_to_merge])
        ys = np.concatenate([[box[1], box[3]] for box in boxes_to_merge])
        merged_boxes.append([np.min(xs), np.min(ys), np.max(xs), np.max(ys)])
        merged_labels.append(current_label)
        merged_scores.append(max(scores_to_merge))
    return [np.array(merged_labels)], [np.array(merged_boxes)], [np.array(merged_scores)]
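
# Sketch of how postprocess behaves (values below are illustrative only):
#   labels, boxes, scores = postprocess(np.array([1, 1]),
#                                        np.array([[0, 0, 10, 10], [1, 1, 11, 11]]),
#                                        np.array([0.9, 0.8]))
#   The two class-1 boxes overlap above the default IoU threshold, so they merge
#   into a single [0, 0, 11, 11] box that keeps the higher score (0.9).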

def slice_image(image, slice_height, slice_width, overlap_ratio):
    """Split a PIL image into overlapping tiles of roughly slice_width x slice_height."""
    img_width, img_height = image.size

    slices = []
    coordinates = []
    # Step so that consecutive tiles overlap by `overlap_ratio`; never step by less than 1 pixel.
    step_x = max(1, int(slice_width * (1 - overlap_ratio)))
    step_y = max(1, int(slice_height * (1 - overlap_ratio)))

    for y in range(0, img_height, step_y):
        for x in range(0, img_width, step_x):
            # Tiles at the right/bottom edges are clipped to the image bounds.
            box = (x, y, min(x + slice_width, img_width), min(y + slice_height, img_height))
            slices.append(image.crop(box))
            coordinates.append((x, y))
    return slices, coordinates

def merge_predictions(predictions, slice_coordinates, orig_image_size, slice_width, slice_height, threshold=0.30):
    """Shift per-slice detections back into original-image coordinates, keeping those above `threshold`."""
    merged_labels = []
    merged_boxes = []
    merged_scores = []
    orig_height, orig_width = orig_image_size
    for i, (label, boxes, scores) in enumerate(predictions):
        x_shift, y_shift = slice_coordinates[i]
        scores = np.array(scores).reshape(-1)
        valid_indices = scores > threshold
        valid_labels = np.array(label).reshape(-1)[valid_indices]
        valid_boxes = np.array(boxes).reshape(-1, 4)[valid_indices]
        valid_scores = scores[valid_indices]
        # Translate each box by its slice's offset and clip it to the original image bounds.
        for j, box in enumerate(valid_boxes):
            box[0] = np.clip(box[0] + x_shift, 0, orig_width)
            box[1] = np.clip(box[1] + y_shift, 0, orig_height)
            box[2] = np.clip(box[2] + x_shift, 0, orig_width)
            box[3] = np.clip(box[3] + y_shift, 0, orig_height)
            valid_boxes[j] = box
        merged_labels.extend(valid_labels)
        merged_boxes.extend(valid_boxes)
        merged_scores.extend(valid_scores)
    return np.array(merged_labels), np.array(merged_boxes), np.array(merged_scores)

def draw(images, labels, boxes, scores, thrh=0.6, path=""):
    """Draw boxes and class/score text for detections above `thrh`, then save each image."""
    for i, im in enumerate(images):
        drawer = ImageDraw.Draw(im)
        scr = scores[i]
        lab = labels[i][scr > thrh]
        box = boxes[i][scr > thrh]
        scrs = scores[i][scr > thrh]
        for j, b in enumerate(box):
            drawer.rectangle(list(b), outline='red')
            drawer.text((b[0], b[1]),
                        text=f"label: {lab[j].item()} {round(scrs[j].item(), 2)}",
                        font=ImageFont.load_default(), fill='blue')
        im.save(f'results_{i}.jpg' if path == "" else path)

def main(args):
    """Run inference on a single image, optionally with sliced (tiled) inference."""
    cfg = YAMLConfig(args.config, resume=args.resume)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        # Prefer the EMA weights when they are present in the checkpoint.
        if 'ema' in checkpoint:
            state = checkpoint['ema']['module']
        else:
            state = checkpoint['model']
    else:
        raise AttributeError('Only loading model.state_dict via --resume is supported for now.')
    # NOTE: load the train-mode state, then convert the modules to deploy mode.
    cfg.model.load_state_dict(state)

    class Model(nn.Module):
        """Wraps the deployed model and postprocessor into a single forward pass."""
        def __init__(self) -> None:
            super().__init__()
            self.model = cfg.model.deploy()
            self.postprocessor = cfg.postprocessor.deploy()

        def forward(self, images, orig_target_sizes):
            outputs = self.model(images)
            outputs = self.postprocessor(outputs, orig_target_sizes)
            return outputs

    model = Model().to(args.device)
    im_pil = Image.open(args.im_file).convert('RGB')
    w, h = im_pil.size
    orig_size = torch.tensor([w, h])[None].to(args.device)

    transforms = T.Compose([
        T.Resize((640, 640)),
        T.ToTensor(),
    ])
    im_data = transforms(im_pil)[None].to(args.device)
    if args.sliced:
        # Derive a grid of roughly `numberofboxes` tiles that preserves the image's aspect ratio.
        num_boxes = args.numberofboxes
        aspect_ratio = w / h
        num_cols = max(1, int(np.sqrt(num_boxes * aspect_ratio)))
        num_rows = max(1, int(num_boxes / num_cols))
        slice_height = h // num_rows
        slice_width = w // num_cols
        overlap_ratio = 0.2
        slices, coordinates = slice_image(im_pil, slice_height, slice_width, overlap_ratio)

        predictions = []
        for i, slice_img in enumerate(slices):
            slice_tensor = transforms(slice_img)[None].to(args.device)
            with autocast():  # mixed precision for each slice
                output = model(slice_tensor,
                               torch.tensor([[slice_img.size[0], slice_img.size[1]]]).to(args.device))
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            labels, boxes, scores = output
            labels = labels.cpu().detach().numpy()
            boxes = boxes.cpu().detach().numpy()
            scores = scores.cpu().detach().numpy()
            predictions.append((labels, boxes, scores))

        # Map per-slice detections back to the full image, then merge duplicates from overlapping tiles.
        merged_labels, merged_boxes, merged_scores = merge_predictions(predictions, coordinates, (h, w), slice_width, slice_height)
        labels, boxes, scores = postprocess(merged_labels, merged_boxes, merged_scores)
    else:
        output = model(im_data, orig_size)
        labels, boxes, scores = output

    draw([im_pil], labels, boxes, scores, 0.6)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str)
    parser.add_argument('-r', '--resume', type=str)
    parser.add_argument('-f', '--im-file', type=str)
    # `type=bool` would treat any non-empty string as True, so expose this as a flag instead.
    parser.add_argument('-s', '--sliced', action='store_true')
    parser.add_argument('-d', '--device', type=str, default='cpu')
    parser.add_argument('-nc', '--numberofboxes', type=int, default=25)
    args = parser.parse_args()
    main(args)
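
    # Example invocation (file names and paths are illustrative, not shipped with the repo):
    #   python infer.py -c configs/model.yml -r checkpoint.pth -f image.jpg --sliced -d cuda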