add custom configs with swin backbone for retinanet and vfnet
dnth committed Jan 26, 2022
1 parent 08c6587 commit 129566e
Showing 6 changed files with 200 additions and 0 deletions.
16 changes: 16 additions & 0 deletions custom_configs/retinanet/retinanet_swin-b-p4-w7_fpn_1x_coco.py
@@ -0,0 +1,16 @@
_base_ = ["./retinanet_swin-t-p4-w7_fpn_1x_coco.py"]

pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth' # noqa

model = dict(
    backbone=dict(
        embed_dims=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        out_indices=(0, 1, 2, 3),
        init_cfg=dict(type='Pretrained', checkpoint=pretrained),
    ),
    neck=dict(
        in_channels=[128, 256, 512, 1024],
    ),
)
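Aside: the Swin-B file above restates only the fields it overrides (wider embeddings, a deeper third stage, more attention heads, matching FPN input channels); everything else is inherited through _base_ from the Swin-T config added further down in this commit. A minimal sketch of checking the merged result, assuming mmcv is installed and the _base_ paths referenced by these configs resolve on disk (illustrative only, not part of the diff):

from mmcv import Config

# mmcv merges this file on top of the Swin-T base config it inherits from.
cfg = Config.fromfile('custom_configs/retinanet/retinanet_swin-b-p4-w7_fpn_1x_coco.py')
print(cfg.model.backbone.embed_dims)   # 128, overridden here (96 in the Swin-T base)
print(cfg.model.backbone.window_size)  # 7, inherited from the Swin-T base
print(cfg.model.neck.in_channels)      # [128, 256, 512, 1024]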
10 changes: 10 additions & 0 deletions custom_configs/retinanet/retinanet_swin-s-p4-w7_fpn_1x_coco.py
@@ -0,0 +1,10 @@
_base_ = ["./retinanet_swin-t-p4-w7_fpn_1x_coco.py"]

pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth' # noqa

model = dict(
    backbone=dict(
        depths=[2, 2, 18, 2],
        init_cfg=dict(type='Pretrained', checkpoint=pretrained),
    ),
)
30 changes: 30 additions & 0 deletions custom_configs/retinanet/retinanet_swin-t-p4-w7_fpn_1x_coco.py
@@ -0,0 +1,30 @@
_base_ = [
    '../_base_/models/retinanet_r50_fpn.py',
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' # noqa
model = dict(
    backbone=dict(
        _delete_=True,
        type='SwinTransformer',
        embed_dims=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        patch_norm=True,
        out_indices=(1, 2, 3),
        # Only add indices that are consumed by the FPN,
        # otherwise some parameters will not be used.
        with_cp=False,
        convert_weights=True,
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(in_channels=[192, 384, 768], start_level=0, num_outs=5))

optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
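For reference (outside this diff), a minimal sketch of building a detector from the Swin-T RetinaNet config with MMDetection's Python API. It assumes an MMDetection 2.x environment with mmcv installed and that the ../_base_/ paths referenced above resolve on disk; the config path below is illustrative.

from mmcv import Config
from mmdet.models import build_detector

cfg = Config.fromfile('custom_configs/retinanet/retinanet_swin-t-p4-w7_fpn_1x_coco.py')

# Build RetinaNet with the SwinTransformer backbone defined above.
# train_cfg/test_cfg are passed explicitly for older MMDetection 2.x configs
# that keep them at the top level; cfg.get() simply returns None otherwise.
model = build_detector(cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))

# The ImageNet-pretrained Swin checkpoint referenced by init_cfg is only
# downloaded and loaded when init_weights() is called.
model.init_weights()
print(type(model).__name__, sum(p.numel() for p in model.parameters()))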
16 changes: 16 additions & 0 deletions custom_configs/vfnet/vfnet_swin-b-p4-w7_fpn_1x_coco.py
@@ -0,0 +1,16 @@
_base_ = ["./vfnet_swin-t-p4-w7_fpn_1x_coco.py"]

pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth' # noqa

model = dict(
    backbone=dict(
        embed_dims=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        out_indices=(0, 1, 2, 3),
        init_cfg=dict(type='Pretrained', checkpoint=pretrained),
    ),
    neck=dict(
        in_channels=[128, 256, 512, 1024],
    ),
)
10 changes: 10 additions & 0 deletions custom_configs/vfnet/vfnet_swin-s-p4-w7_fpn_1x_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
_base_ = ["./vfnet_swin-t-p4-w7_fpn_1x_coco.py"]

pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth' # noqa

model = dict(
    backbone=dict(
        depths=[2, 2, 18, 2],
        init_cfg=dict(type='Pretrained', checkpoint=pretrained),
    ),
)
118 changes: 118 additions & 0 deletions custom_configs/vfnet/vfnet_swin-t-p4-w7_fpn_1x_coco.py
@@ -0,0 +1,118 @@
_base_ = [
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth' # noqa
# model settings
model = dict(
    type='VFNet',
    backbone=dict(
        # _delete_=True,
        type='SwinTransformer',
        embed_dims=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.2,
        patch_norm=True,
        out_indices=(1, 2, 3),
        # Only add indices that are consumed by the FPN,
        # otherwise some parameters will not be used.
        with_cp=False,
        convert_weights=True,
        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
    neck=dict(
        type='FPN',
        in_channels=[192, 384, 768],
        out_channels=256,
        start_level=0,
        add_extra_convs='on_output',  # use P5
        num_outs=5,
        relu_before_extra_convs=True),
    bbox_head=dict(
        type='VFNetHead',
        num_classes=80,
        in_channels=256,
        stacked_convs=3,
        feat_channels=256,
        strides=[8, 16, 32, 64, 128],
        center_sampling=False,
        dcn_on_last_conv=False,
        use_atss=True,
        use_vfl=True,
        loss_cls=dict(
            type='VarifocalLoss',
            use_sigmoid=True,
            alpha=0.75,
            gamma=2.0,
            iou_weighted=True,
            loss_weight=1.0),
        loss_bbox=dict(type='GIoULoss', loss_weight=1.5),
        loss_bbox_refine=dict(type='GIoULoss', loss_weight=2.0)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(type='ATSSAssigner', topk=9),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        nms=dict(type='nms', iou_threshold=0.6),
        max_per_img=100))

# data setting
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1333, 800),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='Pad', size_divisor=32),
            dict(type='DefaultFormatBundle'),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))

# optimizer
optimizer = dict(
    lr=0.01, paramwise_cfg=dict(bias_lr_mult=2., bias_decay_mult=0.))
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.1,
    step=[8, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)
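The VFNet Swin-T file above defines its full training setup (data pipelines, optimizer overrides, learning-rate schedule), with the remaining defaults pulled in from its _base_ files. A minimal sketch of inspecting the merged schedule, assuming mmcv is installed, the _base_ paths resolve, and the referenced schedule_1x.py matches MMDetection's standard one (illustrative only):

from mmcv import Config

cfg = Config.fromfile('custom_configs/vfnet/vfnet_swin-t-p4-w7_fpn_1x_coco.py')

# type='SGD' and momentum come from schedule_1x.py; lr=0.01 and
# paramwise_cfg are the overrides written in the file above.
print(cfg.optimizer)
print(cfg.lr_config.step)        # [8, 11]: decay at epochs 8 and 11
print(cfg.runner.max_epochs)     # 12 epochs (the "1x" schedule)
print(cfg.data.samples_per_gpu)  # 2 images per GPU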
