-
-
Notifications
You must be signed in to change notification settings - Fork 399
/
Copy pathdataset.py
102 lines (79 loc) · 2.92 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
'''by lyuwenyu
'''
import os
import glob
from PIL import Image
import torch
import torch.utils.data as data
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as F
class ToTensor(T.ToTensor):
    """ToTensor variant that passes already-converted tensors through unchanged."""

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, pic):
        """Convert `pic` to a tensor; if it is already a torch.Tensor, return it as-is."""
        return pic if isinstance(pic, torch.Tensor) else super().__call__(pic)
class PadToSize(T.Pad):
    """Pad an image on its right and bottom edges so it reaches a fixed (w, h) size."""

    def __init__(self, size, fill=0, padding_mode='constant'):
        # Parent padding of 0 is a placeholder; the real padding is computed
        # per image in __call__ from the target size.
        super().__init__(0, fill, padding_mode)
        self.size = size
        self.fill = fill

    def __call__(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be padded.
        Returns:
            PIL Image or Tensor: Padded image.
        """
        w, h = F.get_image_size(img)
        pad_right = self.size[0] - w
        pad_bottom = self.size[1] - h
        return F.pad(img, (0, 0, pad_right, pad_bottom), self.fill, self.padding_mode)
class Dataset(data.Dataset):
    """Image-folder dataset for detection inference.

    Globs image files from a directory, decodes them on `device`, and
    preprocesses each one into a fixed 640x640 blob dict.
    """

    def __init__(self, img_dir: str='', preprocess: T.Compose=None, device='cuda:0', pattern: str='*.jpg') -> None:
        """
        Args:
            img_dir: directory to search for images.
            preprocess: optional transform pipeline; when None, defaults to
                resize (longest side capped at 640) + pad to 640x640 with
                fill value 114 + tensor/float conversion.
            device: device passed to `torchvision.io.decode_jpeg`.
            pattern: glob pattern for image files (default '*.jpg').
        """
        super().__init__()
        self.device = device
        self.size = 640
        self.im_path_list = list(glob.glob(os.path.join(img_dir, pattern)))
        if preprocess is None:
            # Letterbox-style default: keep aspect ratio, then pad the
            # right/bottom with gray (114) to reach exactly 640x640.
            self.preprocess = T.Compose([
                T.Resize(size=639, max_size=640),
                PadToSize(size=(640, 640), fill=114),
                ToTensor(),
                T.ConvertImageDtype(torch.float),
            ])
        else:
            self.preprocess = preprocess

    def __len__(self):
        return len(self.im_path_list)

    def __getitem__(self, index):
        """Decode and preprocess one image; return a blob dict with size metadata."""
        im = torchvision.io.read_file(self.im_path_list[index])
        # Decode straight onto the target device (GPU jpeg decode on CUDA).
        im = torchvision.io.decode_jpeg(im, mode=torchvision.io.ImageReadMode.RGB, device=self.device)
        _, h, w = im.shape  # c, h, w
        im = self.preprocess(im)
        blob = {
            'image': im,
            'im_shape': torch.tensor([self.size, self.size]).to(im.device),
            # NOTE(review): this assumes the image was stretched to
            # size x size; the default letterbox preprocess preserves aspect
            # ratio, so one factor over-states the scale — confirm how the
            # consumer interprets 'scale_factor'.
            'scale_factor': torch.tensor([self.size / h, self.size / w]).to(im.device),
            'orig_size': torch.tensor([w, h]).to(im.device),
        }
        return blob

    @staticmethod
    def post_process():
        # Placeholder: post-processing handled elsewhere.
        pass

    @staticmethod
    def collate_fn():
        # Placeholder: default DataLoader collation is used.
        pass
def draw_nms_result(blob, outputs, draw_score_threshold=0.25, name=''):
    '''show result
    Keys:
        'num_dets', 'det_boxes', 'det_scores', 'det_classes'
    '''
    batch_size = blob['image'].shape[0]
    for idx in range(batch_size):
        scores = outputs['det_scores'][idx]
        # Keep only boxes whose score clears the threshold.
        boxes = outputs['det_boxes'][idx][scores > draw_score_threshold]
        image_u8 = (blob['image'][idx] * 255).to(torch.uint8)
        image_u8 = torchvision.utils.draw_bounding_boxes(image_u8, boxes=boxes, width=2)
        # CHW -> HWC on CPU before handing off to PIL for saving.
        Image.fromarray(image_u8.permute(1, 2, 0).cpu().numpy()).save(f'test_{name}_{idx}.jpg')