20
20
21
21
22
22
def apply_label_noise (
23
- labels : torch .Tensor ,
24
- label_noise_prob : float = 0.2 ,
23
+ labels : torch .Tensor ,
24
+ label_noise_prob : float = 0.2 ,
25
25
num_classes : int = 80 ,
26
26
):
27
27
"""
@@ -57,16 +57,14 @@ def apply_box_noise(
57
57
diff = torch .zeros_like (boxes )
58
58
diff [:, :2 ] = boxes [:, 2 :] / 2
59
59
diff [:, 2 :] = boxes [:, 2 :]
60
- boxes += (
61
- torch .mul ((torch .rand_like (boxes ) * 2 - 1.0 ), diff ) * box_noise_scale
62
- )
60
+ boxes += torch .mul ((torch .rand_like (boxes ) * 2 - 1.0 ), diff ) * box_noise_scale
63
61
boxes = boxes .clamp (min = 0.0 , max = 1.0 )
64
62
return boxes
65
63
66
64
67
65
class GenerateDNQueries (nn .Module ):
68
66
"""Generate denoising queries for DN-DETR
69
-
67
+
70
68
Args:
71
69
num_queries (int): Number of total queries in DN-DETR. Default: 300
72
70
num_classes (int): Number of total categories. Default: 80.
@@ -77,6 +75,7 @@ class GenerateDNQueries(nn.Module):
77
75
with_indicator (bool): If True, add indicator in noised label/box queries.
78
76
79
77
"""
78
+
80
79
def __init__ (
81
80
self ,
82
81
num_queries : int = 300 ,
@@ -95,7 +94,7 @@ def __init__(
95
94
self .label_noise_prob = label_noise_prob
96
95
self .box_noise_scale = box_noise_scale
97
96
self .with_indicator = with_indicator
98
-
97
+
99
98
# leave one dim for indicator mentioned in DN-DETR
100
99
if with_indicator :
101
100
self .label_encoder = nn .Embedding (num_classes , label_embed_dim - 1 )
@@ -116,15 +115,17 @@ def generate_query_masks(self, max_gt_num_per_image, device):
116
115
] = True
117
116
if i == self .denoising_groups - 1 :
118
117
attn_mask [
119
- max_gt_num_per_image * i : max_gt_num_per_image * (i + 1 ), : max_gt_num_per_image * i
118
+ max_gt_num_per_image * i : max_gt_num_per_image * (i + 1 ),
119
+ : max_gt_num_per_image * i ,
120
120
] = True
121
121
else :
122
122
attn_mask [
123
123
max_gt_num_per_image * i : max_gt_num_per_image * (i + 1 ),
124
124
max_gt_num_per_image * (i + 1 ) : noised_query_nums ,
125
125
] = True
126
126
attn_mask [
127
- max_gt_num_per_image * i : max_gt_num_per_image * (i + 1 ), : max_gt_num_per_image * i
127
+ max_gt_num_per_image * i : max_gt_num_per_image * (i + 1 ),
128
+ : max_gt_num_per_image * i ,
128
129
] = True
129
130
return attn_mask
130
131
@@ -135,7 +136,7 @@ def forward(
135
136
):
136
137
"""
137
138
Args:
138
- gt_boxes_list (list[torch.Tensor]): Ground truth bounding boxes per image
139
+ gt_boxes_list (list[torch.Tensor]): Ground truth bounding boxes per image
139
140
with normalized coordinates in format ``(x, y, w, h)`` in shape ``(num_gts, 4)``
140
141
gt_labels_list (list[torch.Tensor]): Classification labels per image in shape ``(num_gt, )``.
141
142
"""
@@ -162,7 +163,6 @@ def forward(
162
163
# means there are 2 instances in the first image and 3 instances in the second image
163
164
gt_nums_per_image = [x .numel () for x in gt_labels_list ]
164
165
165
-
166
166
# Add noise on labels and boxes
167
167
noised_labels = apply_label_noise (gt_labels , self .label_noise_prob , self .num_classes )
168
168
noised_boxes = apply_box_noise (gt_boxes , self .box_noise_scale )
@@ -175,50 +175,67 @@ def forward(
175
175
# add indicator to label encoding if with_indicator == True
176
176
if self .with_indicator :
177
177
label_embedding = torch .cat ([label_embedding , torch .ones ([query_num , 1 ]).to (device )], 1 )
178
-
178
+
179
179
# calculate the max number of ground truth in one image inside the batch.
180
- # e.g. gt_nums_per_image = [2, 3] which means the first image has 2 instances and the second image has 3 instances
180
+ # e.g. gt_nums_per_image = [2, 3] which means
181
+ # the first image has 2 instances and the second image has 3 instances
181
182
# then the max_gt_num_per_image should be 3.
182
183
max_gt_num_per_image = max (gt_nums_per_image )
183
-
184
+
184
185
# the total number of denoising queries depends on the number of denoising groups and the max number of instances.
185
186
noised_query_nums = max_gt_num_per_image * self .denoising_groups
186
187
187
188
# initialize the generated noised queries to zero.
188
189
# And the zero initialized queries will be assigned with noised embeddings later.
189
- noised_label_queries = torch .zeros (noised_query_nums , self .label_embed_dim ).to (device ).repeat (batch_size , 1 , 1 )
190
+ noised_label_queries = (
191
+ torch .zeros (noised_query_nums , self .label_embed_dim ).to (device ).repeat (batch_size , 1 , 1 )
192
+ )
190
193
noised_box_queries = torch .zeros (noised_query_nums , 4 ).to (device ).repeat (batch_size , 1 , 1 )
191
194
192
-
193
195
# batch index per image: [0, 1, 2, 3] for batch_size == 4
194
196
batch_idx = torch .arange (0 , batch_size )
195
-
197
+
196
198
# e.g. gt_nums_per_image = [2, 3]
197
199
# batch_idx = [0, 1]
198
- # then the "batch_idx_per_instance" equals to [0, 0, 1, 1, 1] which indicates which image the instance belongs to.
200
+ # then the "batch_idx_per_instance" equals to [0, 0, 1, 1, 1]
201
+ # which indicates which image the instance belongs to.
199
202
# because the instances have been flattened before.
200
- batch_idx_per_instance = torch .repeat_interleave (batch_idx , torch .tensor (gt_nums_per_image ).long ())
203
+ batch_idx_per_instance = torch .repeat_interleave (
204
+ batch_idx , torch .tensor (gt_nums_per_image ).long ()
205
+ )
201
206
202
207
# indicate which image the noised labels belong to. For example:
203
208
# noised label: tensor([0, 1, 2, 2, 3, 4, 0, 1, 2, 2, 3, 4])
204
209
# batch_idx_per_group: tensor([0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1])
205
210
# which means the first label "tensor([0])" belongs to "image_0".
206
211
batch_idx_per_group = batch_idx_per_instance .repeat (self .denoising_groups , 1 ).flatten ()
207
212
208
-
209
- # Cuz there might be different numbers of ground truth in each image of the same batch.
213
+ # Because there might be different numbers of ground truth in each image of the same batch.
210
214
# So there might be some padding part in noising queries.
211
- # Here we calculate the indexes for the valid queries and fill them with the noised embeddings.
215
+ # Here we calculate the indexes for the valid queries and
216
+ # fill them with the noised embeddings.
212
217
# And leave the padding part to zeros.
213
218
if len (gt_nums_per_image ):
214
- valid_index_per_group = torch .cat ([torch .tensor (list (range (num ))) for num in gt_nums_per_image ])
215
219
valid_index_per_group = torch .cat (
216
- [valid_index_per_group + max_gt_num_per_image * i for i in range (self .denoising_groups )]).long ()
220
+ [torch .tensor (list (range (num ))) for num in gt_nums_per_image ]
221
+ )
222
+ valid_index_per_group = torch .cat (
223
+ [
224
+ valid_index_per_group + max_gt_num_per_image * i
225
+ for i in range (self .denoising_groups )
226
+ ]
227
+ ).long ()
217
228
if len (batch_idx_per_group ):
218
229
noised_label_queries [(batch_idx_per_group , valid_index_per_group )] = label_embedding
219
230
noised_box_queries [(batch_idx_per_group , valid_index_per_group )] = noised_boxes
220
231
221
232
# generate attention masks for transformer layers
222
233
attn_mask = self .generate_query_masks (max_gt_num_per_image , device )
223
234
224
- return noised_label_queries , noised_box_queries , attn_mask , self .denoising_groups , max_gt_num_per_image
235
+ return (
236
+ noised_label_queries ,
237
+ noised_box_queries ,
238
+ attn_mask ,
239
+ self .denoising_groups ,
240
+ max_gt_num_per_image ,
241
+ )
0 commit comments