Commit 74f82d6

mali-afridi and khadijairfan2345 authored Oct 11, 2024

Sliced Inference Support for Improved Results (#468)

* sliced inference support
* Update README.md

Co-authored-by: khadijairfan2345 <khadijairfan2345@gmail.com>

1 parent 1497151 commit 74f82d6

File tree: 2 files changed, +204 / -0 lines
 

README.md (+1)
```diff
@@ -46,6 +46,7 @@ This is the official implementation of papers
 
 
 ## 🚀 Updates
+- \[2024.10.10\] Added Sliced Inference Support for Better results on Distant Objects.
 - \[2024.09.23\] Added Backbone Support for Regnet and DLA34.
 - \[2024.08.27\] Add hubconf.py file to support torch hub.
 - \[2024.08.22\] Improve the performance of ✅ [RT-DETRv2-S](./rtdetrv2_pytorch/) to 48.1 mAP (<font color=green>+1.6</font> compared to RT-DETR-R18).
```

rtdetr_pytorch/tools/infer.py (new file, +203 lines)
1+
import torch
2+
import torch.nn as nn
3+
import torchvision.transforms as T
4+
from torch.cuda.amp import autocast
5+
import numpy as np
6+
from PIL import Image, ImageDraw, ImageFont
7+
import os
8+
import sys
9+
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
10+
import argparse
11+
import src.misc.dist as dist
12+
from src.core import YAMLConfig
13+
from src.solver import TASKS
14+
import numpy as np
15+
16+
`postprocess` merges overlapping same-class detections via a simple IoU-based grouping, collapsing each group into its enclosing box:

```python
def postprocess(labels, boxes, scores, iou_threshold=0.55):
    """Merge overlapping boxes of the same class, keeping the highest score."""

    def calculate_iou(box1, box2):
        x1, y1, x2, y2 = box1
        x3, y3, x4, y4 = box2
        xi1 = max(x1, x3)
        yi1 = max(y1, y3)
        xi2 = min(x2, x4)
        yi2 = min(y2, y4)
        inter_area = max(0, xi2 - xi1) * max(0, yi2 - yi1)
        box1_area = (x2 - x1) * (y2 - y1)
        box2_area = (x4 - x3) * (y4 - y3)
        union_area = box1_area + box2_area - inter_area
        return inter_area / union_area if union_area != 0 else 0

    merged_labels = []
    merged_boxes = []
    merged_scores = []
    used_indices = set()
    for i in range(len(boxes)):
        if i in used_indices:
            continue
        current_box = boxes[i]
        current_label = labels[i]
        boxes_to_merge = [current_box]
        scores_to_merge = [scores[i]]
        used_indices.add(i)
        for j in range(i + 1, len(boxes)):
            if j in used_indices or labels[j] != current_label:
                continue
            if calculate_iou(current_box, boxes[j]) >= iou_threshold:
                boxes_to_merge.append(boxes[j])
                scores_to_merge.append(scores[j])
                used_indices.add(j)
        # The merged box is the tight enclosure of the whole group.
        xs = np.concatenate([[box[0], box[2]] for box in boxes_to_merge])
        ys = np.concatenate([[box[1], box[3]] for box in boxes_to_merge])
        merged_boxes.append([np.min(xs), np.min(ys), np.max(xs), np.max(ys)])
        merged_labels.append(current_label)
        merged_scores.append(max(scores_to_merge))
    # Wrapped in single-element lists so draw() can index per image.
    return [np.array(merged_labels)], [np.array(merged_boxes)], [np.array(merged_scores)]
```
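As a quick sanity check of the merge behaviour (hypothetical values, not from the commit): two same-class boxes with IoU above the 0.55 default collapse into their enclosing box and keep the higher score.

```python
labels = np.array([1, 1])
boxes = np.array([[10., 10., 50., 50.], [12., 12., 52., 52.]])  # IoU ≈ 0.82
scores = np.array([0.9, 0.7])
l, b, s = postprocess(labels, boxes, scores)
# b[0] -> [[10., 10., 52., 52.]], s[0] -> [0.9]
```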
`slice_image` tiles the input into overlapping crops; the overlap gives an object cut by one tile border a chance to appear whole in a neighbouring tile:

```python
def slice_image(image, slice_height, slice_width, overlap_ratio):
    """Split a PIL image into overlapping tiles; return tiles and their (x, y) offsets."""
    img_width, img_height = image.size
    slices = []
    coordinates = []
    step_x = int(slice_width * (1 - overlap_ratio))
    step_y = int(slice_height * (1 - overlap_ratio))
    for y in range(0, img_height, step_y):
        for x in range(0, img_width, step_x):
            # Tiles at the right/bottom edges are cropped to the image boundary.
            box = (x, y, min(x + slice_width, img_width), min(y + slice_height, img_height))
            slices.append(image.crop(box))
            coordinates.append((x, y))
    return slices, coordinates
```
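For example (hypothetical sizes): a 1920×1080 image tiled into 640×540 slices with 20% overlap steps by 512 and 432 pixels, producing a 4×3 grid of 12 tiles.

```python
img = Image.new('RGB', (1920, 1080))
slices, coords = slice_image(img, slice_height=540, slice_width=640, overlap_ratio=0.2)
assert len(slices) == 12 and coords[1] == (512, 0)
```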
`merge_predictions` shifts slice-local boxes back into full-image coordinates and drops low-confidence detections (the `slice_width`/`slice_height` parameters are accepted but unused):

```python
def merge_predictions(predictions, slice_coordinates, orig_image_size, slice_width, slice_height, threshold=0.30):
    """Shift per-slice boxes by their slice offset and clip them to the original image."""
    merged_labels = []
    merged_boxes = []
    merged_scores = []
    orig_height, orig_width = orig_image_size
    for i, (label, boxes, scores) in enumerate(predictions):
        x_shift, y_shift = slice_coordinates[i]
        scores = np.array(scores).reshape(-1)
        valid_indices = scores > threshold
        valid_labels = np.array(label).reshape(-1)[valid_indices]
        valid_boxes = np.array(boxes).reshape(-1, 4)[valid_indices]
        valid_scores = scores[valid_indices]
        for j, box in enumerate(valid_boxes):
            box[0] = np.clip(box[0] + x_shift, 0, orig_width)
            box[1] = np.clip(box[1] + y_shift, 0, orig_height)
            box[2] = np.clip(box[2] + x_shift, 0, orig_width)
            box[3] = np.clip(box[3] + y_shift, 0, orig_height)
            valid_boxes[j] = box
        merged_labels.extend(valid_labels)
        merged_boxes.extend(valid_boxes)
        merged_scores.extend(valid_scores)
    return np.array(merged_labels), np.array(merged_boxes), np.array(merged_scores)
```
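A one-slice example (hypothetical values): a box predicted at [10, 20, 60, 80] inside the tile at offset (512, 0) lands at [522, 20, 572, 80] in the full image.

```python
preds = [(np.array([3]), np.array([[10., 20., 60., 80.]]), np.array([0.8]))]
l, b, s = merge_predictions(preds, [(512, 0)], (1080, 1920), 640, 540)
# b -> [[522., 20., 572., 80.]]
```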
`draw` renders the detections that survive the score threshold onto the image:

```python
def draw(images, labels, boxes, scores, thrh=0.6, path=""):
    """Draw boxes scoring above `thrh` and save the annotated image."""
    for i, im in enumerate(images):
        drawer = ImageDraw.Draw(im)
        scr = scores[i]
        lab = labels[i][scr > thrh]
        box = boxes[i][scr > thrh]
        scrs = scores[i][scr > thrh]
        for j, b in enumerate(box):
            drawer.rectangle(list(b), outline='red')
            drawer.text((b[0], b[1]),
                        text=f"label: {lab[j].item()} {round(scrs[j].item(), 2)}",
                        font=ImageFont.load_default(), fill='blue')
        im.save(f'results_{i}.jpg' if path == "" else path)
```
`main` loads the checkpoint, wraps the deployed model and postprocessor, and runs either plain or sliced inference:

```python
def main(args):
    """Run RT-DETR inference on one image, optionally with sliced inference."""
    cfg = YAMLConfig(args.config, resume=args.resume)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        # Prefer the EMA weights when the checkpoint carries them.
        if 'ema' in checkpoint:
            state = checkpoint['ema']['module']
        else:
            state = checkpoint['model']
    else:
        raise AttributeError('Only support resume to load model.state_dict by now.')

    # NOTE: load train-mode state, then convert to deploy mode.
    cfg.model.load_state_dict(state)

    class Model(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.model = cfg.model.deploy()
            self.postprocessor = cfg.postprocessor.deploy()

        def forward(self, images, orig_target_sizes):
            outputs = self.model(images)
            return self.postprocessor(outputs, orig_target_sizes)

    model = Model().to(args.device)
    im_pil = Image.open(args.im_file).convert('RGB')
    w, h = im_pil.size
    orig_size = torch.tensor([w, h])[None].to(args.device)

    transforms = T.Compose([
        T.Resize((640, 640)),
        T.ToTensor(),
    ])
    im_data = transforms(im_pil)[None].to(args.device)

    if args.sliced:
        # Derive a tiling grid whose cell count approximates --numberofboxes,
        # with tiles roughly proportional to the image's aspect ratio.
        num_boxes = args.numberofboxes
        aspect_ratio = w / h
        num_cols = max(1, int(np.sqrt(num_boxes * aspect_ratio)))  # guard against a zero-sized grid
        num_rows = max(1, int(num_boxes / num_cols))
        slice_height = h // num_rows
        slice_width = w // num_cols
        overlap_ratio = 0.2

        slices, coordinates = slice_image(im_pil, slice_height, slice_width, overlap_ratio)
        predictions = []
        for slice_img in slices:
            slice_tensor = transforms(slice_img)[None].to(args.device)
            with autocast():  # mixed precision per slice to reduce memory use
                output = model(slice_tensor,
                               torch.tensor([[slice_img.size[0], slice_img.size[1]]]).to(args.device))
            torch.cuda.empty_cache()
            labels, boxes, scores = output
            predictions.append((labels.cpu().detach().numpy(),
                                boxes.cpu().detach().numpy(),
                                scores.cpu().detach().numpy()))

        # Stitch per-slice detections back together, then merge duplicates.
        merged_labels, merged_boxes, merged_scores = merge_predictions(
            predictions, coordinates, (h, w), slice_width, slice_height)
        labels, boxes, scores = postprocess(merged_labels, merged_boxes, merged_scores)
    else:
        output = model(im_data, orig_size)
        labels, boxes, scores = output

    draw([im_pil], labels, boxes, scores, 0.6)
```
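To make the grid arithmetic concrete (hypothetical input, not from the commit): a 1920×1080 image with the default `--numberofboxes 25` gives six columns and four rows, i.e. 24 tiles of 320×270 pixels before the 20% overlap is applied.

```python
w, h = 1920, 1080
num_cols = max(1, int(np.sqrt(25 * (w / h))))  # int(sqrt(44.4)) = 6
num_rows = max(1, int(25 / num_cols))          # int(4.17) = 4
print(w // num_cols, h // num_rows)            # 320 270
```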
```python
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str)
    parser.add_argument('-r', '--resume', type=str)
    parser.add_argument('-f', '--im-file', type=str)
    # store_true instead of type=bool: argparse treats any non-empty
    # string (including "False") as truthy when type=bool is used.
    parser.add_argument('-s', '--sliced', action='store_true')
    parser.add_argument('-d', '--device', type=str, default='cpu')
    parser.add_argument('-nc', '--numberofboxes', type=int, default=25)
    args = parser.parse_args()
    main(args)
```
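A typical invocation might look like this (the config and checkpoint paths are placeholders, not part of the commit):

```
python tools/infer.py -c configs/rtdetr/rtdetr_r50vd_6x_coco.yml -r checkpoint.pth \
    -f image.jpg -d cuda:0 --sliced -nc 25
```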
