-
Notifications
You must be signed in to change notification settings - Fork 3
/
dataset.py
185 lines (165 loc) · 7.57 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import os
from typing import Callable, Dict, Optional, Tuple
import cv2
import numpy as np
import torch
from src.datatools.line import get_extreme_points, sort_anno
from src.datatools.reader import read_annot
from torch.utils.data.dataset import Dataset
class EHMDataset(Dataset):
    """Dataset for extreme heatpoint detection model.

    Notes:
        Each sample dict includes:
            - keypoints (np.ndarray): flat float32 array of length
              ``num_keypoint_pairs * 6`` laid out per pair as
              (x1, y1, flag1, x2, y2, flag2); flags are 1 when the point
              is present, 0 when the line is missing, and unset
              coordinates stay at -1.
            - image (np.ndarray): the resized BGR image, shape (H, W, C).
            - line_para (List[Tuple[float, float]]): (slope, intercept)
              for each line; (nan, nan) for missing lines.
            - keypoint_maps (torch.Tensor): target heatmaps, one per
              line (i.e. per pair of extreme heat points). Shape is
              (num_keypoint_pairs, H // stride, W // stride).

    Args:
        dataset_folder (str): Path to the dataset folder containing
            paired ``.json`` annotation and ``.jpg`` image files.
        stride (int): Down-sampling factor between the input image and
            the output heatmaps.
        sigma (int): Standard deviation of the Gaussian kernel used to
            generate the heat points.
        input_size (Tuple[int, int]): Size passed verbatim to
            ``cv2.resize`` as ``dsize``, which OpenCV interprets as
            (width, height). NOTE(review): the original docs said
            (height, width); confirm which convention
            ``sort_anno``/``get_extreme_points`` expect for
            ``img_size``.
        num_keypoint_pairs (int): Number of keypoint pairs (lines) to
            generate.
        transform (Optional[Callable]): Transformation applied to each
            sample dict before it is returned.
    """

    def __init__(self,
                 dataset_folder: str,
                 stride: int = 4,
                 sigma: int = 7,
                 input_size: Tuple[int, int] = (960, 540),
                 num_keypoint_pairs: int = 23,
                 transform: Optional[Callable] = None):
        super().__init__()
        self._dataset_folder = dataset_folder
        self._stride = stride
        self._sigma = sigma
        self.num_keypoint_pairs = num_keypoint_pairs
        self._transform = transform
        self._labels = []
        self._img_paths = []
        self.input_size = input_size
        # sorted() makes the sample order deterministic across runs and
        # platforms (os.listdir order is arbitrary).
        for fname in sorted(os.listdir(dataset_folder)):
            # Skip auxiliary "info" files and anything that is not a
            # .json annotation.
            if 'info' in fname or not fname.endswith('.json'):
                continue
            annot_path = os.path.join(dataset_folder, fname)
            img_path = annot_path.replace('.json', '.jpg')
            # Only keep annotations whose paired image actually exists.
            if not os.path.exists(img_path):
                continue
            res = read_annot(annot_path)
            res, usable_flag = sort_anno(res, img_size=self.input_size)
            if usable_flag:
                self._labels.append(
                    get_extreme_points(res, img_size=self.input_size))
                self._img_paths.append(img_path)
        print(f'Size of {dataset_folder} is {len(self._img_paths)}')

    def __getitem__(self, idx: int) -> Dict:
        """Build and return the sample dict for sample ``idx``.

        Returns:
            Dict with keys 'keypoints', 'image', 'line_para' and
            'keypoint_maps' (see class Notes), after the optional
            transform has been applied.
        """
        kpts_dict = self._labels[idx]
        image = cv2.imread(self._img_paths[idx], cv2.IMREAD_COLOR)
        if image is None:
            # The file existed at init time but cannot be decoded now;
            # fail loudly instead of crashing inside cv2.resize.
            raise IOError(f'Failed to read image {self._img_paths[idx]}')
        image = cv2.resize(image, self.input_size)
        # -1 marks coordinates that are never filled in.
        keypoints = np.ones(self.num_keypoint_pairs * 3 * 2,
                            dtype=np.float32) * -1
        line_paras = []
        # NOTE: loop variable must not reuse ``idx`` — that is the
        # sample index (the original code shadowed it here).
        for pair_idx in range(self.num_keypoint_pairs):
            base = pair_idx * 6
            entry = kpts_dict[pair_idx]
            if entry is not None:
                # entry is (points, line_para).
                points = entry[0]
                keypoints[base] = points[0][0]
                keypoints[base + 1] = points[0][1]
                keypoints[base + 2] = 1
                keypoints[base + 3] = points[1][0]
                keypoints[base + 4] = points[1][1]
                keypoints[base + 5] = 1
                line_paras.append(entry[1])
            else:
                # Missing line: zero visibility flags, nan parameters.
                keypoints[base + 2] = 0
                keypoints[base + 5] = 0
                line_paras.append((np.nan, np.nan))
        sample = {
            'keypoints': keypoints,
            'image': image,
            'line_para': line_paras
        }
        sample['keypoint_maps'] = self._generate_keypoint_maps(sample)
        if self._transform:
            sample = self._transform(sample)
        return sample

    def __len__(self):
        """Number of usable (annotation, image) pairs found at init."""
        return len(self._labels)

    def _generate_keypoint_maps(self, sample: Dict) -> torch.Tensor:
        """Render one Gaussian heatmap per keypoint pair.

        Args:
            sample (Dict): Must contain 'image' (H, W, C array that
                defines the heatmap resolution) and 'keypoints' (flat
                array in the (x, y, flag) * 2 per-pair layout).

        Returns:
            torch.Tensor of shape
            (num_keypoint_pairs, H // stride, W // stride).
        """
        n_rows, n_cols, _ = sample['image'].shape
        keypoint_maps = np.zeros(
            shape=(self.num_keypoint_pairs, int(round(n_rows / self._stride)),
                   int(round(n_cols / self._stride))), dtype=np.float32)
        keypoints = sample['keypoints']
        for pair_id in range(len(keypoints) // 6):
            # Collect only the points whose visibility flag is set.
            points = []
            if keypoints[pair_id * 6 + 2] == 1:
                points.append((keypoints[pair_id * 6],
                               keypoints[pair_id * 6 + 1]))
            if keypoints[pair_id * 6 + 5] == 1:
                points.append((keypoints[pair_id * 6 + 3],
                               keypoints[pair_id * 6 + 4]))
            self._add_gaussian(keypoint_maps[pair_id], points,
                               self._stride, self._sigma)
        return torch.tensor(keypoint_maps)

    def _add_gaussian(self, keypoint_map: np.ndarray, points: list,
                      stride: int, sigma: float = 1) -> np.ndarray:
        """Add a normalized Gaussian peak to the map for each point.

        Modifies ``keypoint_map`` in place: for every point, a Gaussian
        centered at the point's (down-sampled, clamped-to-grid) location
        is normalized to peak value 1 and added to the map.

        Args:
            keypoint_map (np.ndarray): 2D array of shape
                (img_h // stride, img_w // stride) to accumulate into.
            points (list): Points as (x, y) pairs in input-image pixels.
            stride (int): Scale factor from image pixels to map cells.
            sigma (float, optional): Standard deviation of the Gaussian
                in map cells. Defaults to 1.

        Returns:
            np.ndarray: The same ``keypoint_map`` array, updated.
        """
        h, w = keypoint_map.shape
        if len(points) > 0:
            # Build the coordinate grids once per call, not per point.
            x = np.arange(w).astype(float)
            y = np.arange(h).astype(float)
            x_grids, y_grids = np.meshgrid(x, y)
            for point in points:
                px, py = point[0], point[1]
                # Clamp the down-sampled center onto the map grid.
                mu_x = min(w - 1, round(px / stride))
                mu_y = min(h - 1, round(py / stride))
                gauss = np.exp(-((x_grids - mu_x) ** 2 +
                                 (y_grids - mu_y) ** 2) / (2 * sigma ** 2))
                # Normalize so each peak contributes a maximum of 1.
                gauss /= np.max(gauss)
                keypoint_map += gauss
        return keypoint_map
if __name__ == "__main__":
    # Manual smoke test: build the validation split and dump the
    # first sample to stdout.
    dataset_path = '/workdir/data/dataset/valid'
    dataset = EHMDataset(dataset_path)
    sample = dataset[0]
    print(sample)