
Commit 293e9c7

Authored by rentainhe and ntianhe ren
Refine code format (IDEA-Research#80)
* refine format
* refine format
* refine docs links

Co-authored-by: ntianhe ren <[email protected]>
1 parent 4278d72 commit 293e9c7

72 files changed, +674 -635 lines changed


.flake8 (+1 -1)

@@ -3,7 +3,7 @@
 
 [flake8]
 ignore = W503, E203, E221, C901, C408, E741, C407, B017
-max-line-length = 100
+max-line-length = 120
 max-complexity = 18
 select = B,C,E,F,W,T4,B9
 exclude = build, detectron2

configs/common/coco_schedule.py (+1 -1)

@@ -84,4 +84,4 @@ def default_coco_scheduler(epochs=50, decay_epochs=40, warmup_epochs=0):
 
 # warmup scheduler for detr
 lr_multiplier_50ep_warmup = default_coco_scheduler(50, 40, 1e-3)
-lr_multiplier_12ep_warmup = default_coco_scheduler(12, 11, 1e-3)
+lr_multiplier_12ep_warmup = default_coco_scheduler(12, 11, 1e-3)

detrex/layers/__init__.py (+1 -1)

@@ -49,4 +49,4 @@
     apply_label_noise,
     GenerateDNQueries,
 )
-from .shape_spec import ShapeSpec
+from .shape_spec import ShapeSpec

detrex/layers/denoising.py (+42 -25)

@@ -20,8 +20,8 @@
 
 
 def apply_label_noise(
-    labels: torch.Tensor,
-    label_noise_prob: float = 0.2,
+    labels: torch.Tensor,
+    label_noise_prob: float = 0.2,
     num_classes: int = 80,
 ):
     """
@@ -57,16 +57,14 @@ def apply_box_noise(
     diff = torch.zeros_like(boxes)
     diff[:, :2] = boxes[:, 2:] / 2
     diff[:, 2:] = boxes[:, 2:]
-    boxes += (
-        torch.mul((torch.rand_like(boxes) * 2 - 1.0), diff) * box_noise_scale
-    )
+    boxes += torch.mul((torch.rand_like(boxes) * 2 - 1.0), diff) * box_noise_scale
     boxes = boxes.clamp(min=0.0, max=1.0)
     return boxes
 
 
 class GenerateDNQueries(nn.Module):
     """Generate denoising queries for DN-DETR
-
+
     Args:
         num_queries (int): Number of total queries in DN-DETR. Default: 300
         num_classes (int): Number of total categories. Default: 80.
@@ -77,6 +75,7 @@ class GenerateDNQueries(nn.Module):
         with_indicator (bool): If True, add indicator in noised label/box queries.
 
     """
+
     def __init__(
         self,
         num_queries: int = 300,
@@ -95,7 +94,7 @@ def __init__(
         self.label_noise_prob = label_noise_prob
         self.box_noise_scale = box_noise_scale
         self.with_indicator = with_indicator
-
+
         # leave one dim for indicator mentioned in DN-DETR
         if with_indicator:
             self.label_encoder = nn.Embedding(num_classes, label_embed_dim - 1)
@@ -116,15 +115,17 @@ def generate_query_masks(self, max_gt_num_per_image, device):
                 ] = True
             if i == self.denoising_groups - 1:
                 attn_mask[
-                    max_gt_num_per_image * i : max_gt_num_per_image * (i + 1), : max_gt_num_per_image * i
+                    max_gt_num_per_image * i : max_gt_num_per_image * (i + 1),
+                    : max_gt_num_per_image * i,
                 ] = True
             else:
                 attn_mask[
                     max_gt_num_per_image * i : max_gt_num_per_image * (i + 1),
                     max_gt_num_per_image * (i + 1) : noised_query_nums,
                 ] = True
                 attn_mask[
-                    max_gt_num_per_image * i : max_gt_num_per_image * (i + 1), : max_gt_num_per_image * i
+                    max_gt_num_per_image * i : max_gt_num_per_image * (i + 1),
+                    : max_gt_num_per_image * i,
                 ] = True
         return attn_mask
 
@@ -135,7 +136,7 @@ def forward(
     ):
         """
         Args:
-            gt_boxes_list (list[torch.Tensor]): Ground truth bounding boxes per image
+            gt_boxes_list (list[torch.Tensor]): Ground truth bounding boxes per image
                 with normalized coordinates in format ``(x, y, w, h)`` in shape ``(num_gts, 4)``
            gt_labels_list (list[torch.Tensor]): Classification labels per image in shape ``(num_gt, )``.
         """
@@ -162,7 +163,6 @@ def forward(
         # means there are 2 instances in the first image and 3 instances in the second image
         gt_nums_per_image = [x.numel() for x in gt_labels_list]
 
-
         # Add noise on labels and boxes
         noised_labels = apply_label_noise(gt_labels, self.label_noise_prob, self.num_classes)
         noised_boxes = apply_box_noise(gt_boxes, self.box_noise_scale)
@@ -175,50 +175,67 @@ def forward(
         # add indicator to label encoding if with_indicator == True
         if self.with_indicator:
             label_embedding = torch.cat([label_embedding, torch.ones([query_num, 1]).to(device)], 1)
-
+
         # calculate the max number of ground truth in one image inside the batch.
-        # e.g. gt_nums_per_image = [2, 3] which means the first image has 2 instances and the second image has 3 instances
+        # e.g. gt_nums_per_image = [2, 3] which means
+        # the first image has 2 instances and the second image has 3 instances
         # then the max_gt_num_per_image should be 3.
         max_gt_num_per_image = max(gt_nums_per_image)
-
+
         # the total denoising queries is depended on denoising groups and max number of instances.
         noised_query_nums = max_gt_num_per_image * self.denoising_groups
 
         # initialize the generated noised queries to zero.
         # And the zero initialized queries will be assigned with noised embeddings later.
-        noised_label_queries = torch.zeros(noised_query_nums, self.label_embed_dim).to(device).repeat(batch_size, 1, 1)
+        noised_label_queries = (
+            torch.zeros(noised_query_nums, self.label_embed_dim).to(device).repeat(batch_size, 1, 1)
+        )
         noised_box_queries = torch.zeros(noised_query_nums, 4).to(device).repeat(batch_size, 1, 1)
 
-
         # batch index per image: [0, 1, 2, 3] for batch_size == 4
         batch_idx = torch.arange(0, batch_size)
-
+
         # e.g. gt_nums_per_image = [2, 3]
         # batch_idx = [0, 1]
-        # then the "batch_idx_per_instance" equals to [0, 0, 1, 1, 1] which indicates which image the instance belongs to.
+        # then the "batch_idx_per_instance" equals to [0, 0, 1, 1, 1]
+        # which indicates which image the instance belongs to.
         # cuz the instances has been flattened before.
-        batch_idx_per_instance = torch.repeat_interleave(batch_idx, torch.tensor(gt_nums_per_image).long())
+        batch_idx_per_instance = torch.repeat_interleave(
+            batch_idx, torch.tensor(gt_nums_per_image).long()
+        )
 
         # indicate which image the noised labels belong to. For example:
         # noised label: tensor([0, 1, 2, 2, 3, 4, 0, 1, 2, 2, 3, 4])
         # batch_idx_per_group: tensor([0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1])
         # which means the first label "tensor([0])"" belongs to "image_0".
         batch_idx_per_group = batch_idx_per_instance.repeat(self.denoising_groups, 1).flatten()
 
-
-        # Cuz there might be different numbers of ground truth in each image of the same batch.
+        # Cuz there might be different numbers of ground truth in each image of the same batch.
         # So there might be some padding part in noising queries.
-        # Here we calculate the indexes for the valid queries and fill them with the noised embeddings.
+        # Here we calculate the indexes for the valid queries and
+        # fill them with the noised embeddings.
         # And leave the padding part to zeros.
         if len(gt_nums_per_image):
-            valid_index_per_group = torch.cat([torch.tensor(list(range(num))) for num in gt_nums_per_image])
             valid_index_per_group = torch.cat(
-                [valid_index_per_group + max_gt_num_per_image * i for i in range(self.denoising_groups)]).long()
+                [torch.tensor(list(range(num))) for num in gt_nums_per_image]
+            )
+            valid_index_per_group = torch.cat(
+                [
+                    valid_index_per_group + max_gt_num_per_image * i
+                    for i in range(self.denoising_groups)
+                ]
+            ).long()
         if len(batch_idx_per_group):
             noised_label_queries[(batch_idx_per_group, valid_index_per_group)] = label_embedding
             noised_box_queries[(batch_idx_per_group, valid_index_per_group)] = noised_boxes
 
         # generate attention masks for transformer layers
         attn_mask = self.generate_query_masks(max_gt_num_per_image, device)
 
-        return noised_label_queries, noised_box_queries, attn_mask, self.denoising_groups, max_gt_num_per_image
+        return (
+            noised_label_queries,
+            noised_box_queries,
+            attn_mask,
+            self.denoising_groups,
+            max_gt_num_per_image,
+        )
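
The index bookkeeping in this forward pass is easier to follow with the concrete numbers used in the comments above. A minimal standalone sketch of just that step (denoising_groups = 2 is an assumed setting for illustration, not a value taken from the diff):

import torch

# From the diff comments: 2 GT boxes in image 0, 3 GT boxes in image 1.
gt_nums_per_image = [2, 3]
batch_size = len(gt_nums_per_image)
denoising_groups = 2  # assumed value, for illustration only
max_gt_num_per_image = max(gt_nums_per_image)  # -> 3

# Which image each flattened instance belongs to: [0, 0, 1, 1, 1]
batch_idx = torch.arange(0, batch_size)
batch_idx_per_instance = torch.repeat_interleave(
    batch_idx, torch.tensor(gt_nums_per_image).long()
)

# Repeated once per denoising group: [0, 0, 1, 1, 1, 0, 0, 1, 1, 1]
batch_idx_per_group = batch_idx_per_instance.repeat(denoising_groups, 1).flatten()

# Valid slots inside each group's padded block of size max_gt_num_per_image:
# group 0 -> [0, 1, 0, 1, 2], group 1 -> shifted by 3 -> [3, 4, 3, 4, 5]
valid_index_per_group = torch.cat(
    [torch.tensor(list(range(num))) for num in gt_nums_per_image]
)
valid_index_per_group = torch.cat(
    [valid_index_per_group + max_gt_num_per_image * i for i in range(denoising_groups)]
).long()

print(batch_idx_per_group.tolist())    # [0, 0, 1, 1, 1, 0, 0, 1, 1, 1]
print(valid_index_per_group.tolist())  # [0, 1, 0, 1, 2, 3, 4, 3, 4, 5]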

detrex/layers/multi_scale_deform_attn.py (+3 -1)

@@ -406,4 +406,6 @@ def _dummy(*args, **kwargs):
     # TODO: register ops natively so there is no need to import _C.
     _msg = "detrex is not compiled successfully, please build following the instructions!"
     _args = ("detrex._C", _msg)
-    MultiScaleDeformableAttention = create_dummy_class("MultiScaleDeformableAttention", *_args)
+    MultiScaleDeformableAttention = create_dummy_class(  # noqa
+        "MultiScaleDeformableAttention", *_args
+    )
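
For context, create_dummy_class replaces the CUDA-backed op with a placeholder so that importing detrex without compiled extensions only fails when the class is actually instantiated. A rough sketch of that fallback pattern (my own stand-in, not detectron2's actual implementation):

def create_dummy_class(klass, dependency, message=""):
    """Return a placeholder class that raises ImportError on first use."""
    err = f"Cannot import '{dependency}', so '{klass}' is unavailable. {message}"

    class _Dummy:
        def __init__(self, *args, **kwargs):
            raise ImportError(err)

    _Dummy.__name__ = klass
    return _Dummy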

detrex/layers/position_embedding.py (+2 -2)

@@ -189,11 +189,11 @@ def get_sine_pos_embed(
         temperature (int): The temperature used for scaling
             the position embedding. Default: 10000.
         exchange_xy (bool, optional): exchange pos x and pos y. \
-            For example, input tensor is `[x, y]`, the results will
+            For example, input tensor is `[x, y]`, the results will  # noqa
             be `[pos(y), pos(x)]`. Defaults: True.
 
     Returns:
-        torch.Tensor: Returned position embedding
+        torch.Tensor: Returned position embedding  # noqa
         with shape `(None, n * num_pos_feats)`.
     """
     scale = 2 * math.pi
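
The docstring refers to the DETR-style sinusoidal embedding with temperature 10000 and a 2*pi scale. A simplified one-dimensional sketch of that formula (an illustration only, not detrex's actual get_sine_pos_embed):

import math
import torch

def sine_pos_embed_1d(pos, num_pos_feats=128, temperature=10000):
    """Map scalar positions of shape (...,) to features of shape (..., num_pos_feats)."""
    scale = 2 * math.pi
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
    embed = pos[..., None] * scale / dim_t
    # sin on even channels, cos on odd channels, then interleave back together
    return torch.stack((embed[..., 0::2].sin(), embed[..., 1::2].cos()), dim=-1).flatten(-2)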

detrex/layers/shape_spec.py (+2 -1)

@@ -19,6 +19,7 @@
 from dataclasses import dataclass
 from typing import Optional
 
+
 @dataclass
 class ShapeSpec:
     """
@@ -30,4 +31,4 @@ class ShapeSpec:
     channels: Optional[int] = None
     height: Optional[int] = None
     width: Optional[int] = None
-    stride: Optional[int] = None
+    stride: Optional[int] = None
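
Since ShapeSpec is a plain dataclass, and the __init__.py change above re-exports it from detrex.layers, describing a feature map is just:

from detrex.layers import ShapeSpec

# height/width stay None when they are unknown ahead of time
res5_spec = ShapeSpec(channels=2048, stride=32)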

detrex/layers/transformer.py (+2 -1)

@@ -67,7 +67,8 @@ def __init__(
         else:
             assert len(attn) == num_attn, (
                 f"The length of attn (nn.Module or List[nn.Module]) {num_attn}"
-                f"is not consistent with the number of attention in operation_order {operation_order}"
+                f"is not consistent with the number of attention in "
+                f"operation_order {operation_order}"
             )
 
         self.num_attn = num_attn

detrex/modeling/backbone/convnext.py (+2 -2)

@@ -22,12 +22,12 @@
 from functools import partial
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from timm.models.layers import DropPath, trunc_normal_
-from detectron2.modeling.backbone import Backbone
 
 from detrex.layers import LayerNorm
 
+from detectron2.modeling.backbone import Backbone
+
 
 class Block(nn.Module):
     r"""ConvNeXt Block. There are two equivalent implementations:

detrex/modeling/backbone/focalnet.py (-1)

@@ -19,7 +19,6 @@
 # https://github.com/microsoft/FocalNet/blob/main/detection/mmdet/models/backbones/focalnet.py
 # ------------------------------------------------------------------------------------------------
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

detrex/modeling/backbone/timm_backbone.py (+8 -4)

@@ -27,8 +27,8 @@
 import torch.nn as nn
 
 from detectron2.modeling.backbone import Backbone
-from detectron2.utils.logger import setup_logger
 from detectron2.utils import comm
+from detectron2.utils.logger import setup_logger
 
 try:
     import timm
@@ -135,17 +135,21 @@ def __init__(
 
         if feature_info is not None:
             output_feature_channels = {
-                "p{}".format(out_indices[i]): feature_info.channels()[i] for i in range(len(out_indices))
+                "p{}".format(out_indices[i]): feature_info.channels()[i]
+                for i in range(len(out_indices))
             }
             out_feature_strides = {
-                "p{}".format(out_indices[i]): feature_info.reduction()[i] for i in range(len(out_indices))
+                "p{}".format(out_indices[i]): feature_info.reduction()[i]
+                for i in range(len(out_indices))
             }
 
         self._out_features = {"p{}".format(out_indices[i]) for i in range(len(out_indices))}
         self._out_feature_channels = {
             feat: output_feature_channels[feat] for feat in self._out_features
         }
-        self._out_feature_strides = {feat: out_feature_strides[feat] for feat in self._out_features}
+        self._out_feature_strides = {
+            feat: out_feature_strides[feat] for feat in self._out_features
+        }
 
     def forward(self, x):
         """Forward function of `TimmBackbone`.

detrex/modeling/backbone/torchvision_backbone.py (+32 -30)

@@ -22,14 +22,15 @@
     from torchvision.models.feature_extraction import (
         create_feature_extractor,
     )
+
     has_feature_extractor = True
 except ImportError:
     has_feature_extractor = False
 
 
 class TorchvisionBackbone(Backbone):
     """A wrapper for torchvision pretrained backbones
-
+
     Please check `Feature extraction for model inspection
     <https://pytorch.org/vision/stable/feature_extraction.html>`_
     for more details.
@@ -41,51 +42,52 @@ class TorchvisionBackbone(Backbone):
         return_nodes (Dict[str, str]): The keys are the node names and the values are the
             user-specified keys for the graph module's returned dictionary.
     """
-    def __init__(self,
-                 model_name: str = "resnet50",
-                 pretrained: bool = False,
-                 return_nodes: Dict[str, str] = {
-                     "layer1": "res2",
-                     "layer2": "res3",
-                     "layer3": "res4",
-                     "layer4": "res5",
-                 },
-                 train_return_nodes: Dict[str, str] = None,
-                 eval_return_nodes: Dict[str, str] = None,
-                 tracer_kwargs: Dict[str, Any] = None,
-                 suppress_diff_warnings: bool = False,
-                 **kwargs,
-                 ):
+
+    def __init__(
+        self,
+        model_name: str = "resnet50",
+        pretrained: bool = False,
+        return_nodes: Dict[str, str] = {
+            "layer1": "res2",
+            "layer2": "res3",
+            "layer3": "res4",
+            "layer4": "res5",
+        },
+        train_return_nodes: Dict[str, str] = None,
+        eval_return_nodes: Dict[str, str] = None,
+        tracer_kwargs: Dict[str, Any] = None,
+        suppress_diff_warnings: bool = False,
+        **kwargs,
+    ):
         super(TorchvisionBackbone, self).__init__()
-
+
         # build torchvision models
-        self.model = getattr(torchvision.models, model_name)(
-            pretrained=pretrained,
-            **kwargs
-        )
-
+        self.model = getattr(torchvision.models, model_name)(pretrained=pretrained, **kwargs)
+
         if has_feature_extractor is False:
-            raise RuntimeError('Failed to import create_feature_extractor from torchvision. \
-                Please install torchvision 1.10+.')
-
+            raise RuntimeError(
+                "Failed to import create_feature_extractor from torchvision. \
+                Please install torchvision 1.10+."
+            )
+
         # turn models into feature extractor
         self.feature_extractor = create_feature_extractor(
-            model = self.model,
+            model=self.model,
             return_nodes=return_nodes,
             train_return_nodes=train_return_nodes,
            eval_return_nodes=eval_return_nodes,
             tracer_kwargs=tracer_kwargs,
-            suppress_diff_warning=suppress_diff_warnings
+            suppress_diff_warning=suppress_diff_warnings,
         )
 
     def forward(self, x):
         """Forward function of TorchvisionBackbone
-
+
         Args:
             x (torch.Tensor): the input tensor for feature extraction.
-
+
         Returns:
             dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
         """
         outs = self.feature_extractor(x)
-        return outs
+        return outs
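
A usage sketch of the reformatted wrapper: the constructor defaults are the ones in the diff, while the detrex.modeling.backbone import path is an assumption made for illustration:

import torch
from detrex.modeling.backbone import TorchvisionBackbone  # assumed import path

backbone = TorchvisionBackbone(model_name="resnet50", pretrained=False)
feats = backbone(torch.randn(1, 3, 224, 224))
for name, feat in feats.items():
    print(name, tuple(feat.shape))  # res2..res5 at strides 4, 8, 16, 32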
