Skip to content

Commit

Permalink
support dcfl
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxue0827 committed Jun 13, 2024
1 parent 2c68020 commit be49417
Show file tree
Hide file tree
Showing 11 changed files with 1,945 additions and 8 deletions.
38 changes: 37 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,42 @@ Scene graph generation (SGG) in satellite imagery (SAI) benefits promoting intel

## 🛠️ Usage

For instructions on installation, pretrained models, training and evaluation, please refer to [MMRotate 0.3.4](README_en.md).
More instructions on installation, pretrained models, training and evaluation, please refer to [MMRotate 0.3.4](README_en.md).

- Clone this repo:

```bash
git clone https://github.com/yangxue0827/RSG-MMRotate
cd RSG-MMRotate/
```

- Create a conda virtual environment and activate it:

```bash
conda create -n rsg-mmrotate python=3.8 -y
conda activate rsg-mmrotate
```

- Install Pytorch:

```bash
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
```

- Install requirements:

```bash
pip install openmim
mim install mmcv-full
mim install mmdet

cd mmrotate
pip install -r requirements/build.txt
pip install -v -e .

pip install timm
pip install ipdb
```

## 🚀 Released Models

Expand All @@ -33,6 +68,7 @@ For instructions on installation, pretrained models, training and evaluation, pl
| KLD | 25.0 | [rotated_retinanet_hbb_kld_r50_fpn_1x_rsg_oc](configs/kld/rotated_retinanet_hbb_kld_r50_fpn_1x_rsg_oc.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/rotated_retinanet_hbb_kld_r50_fpn_1x_rsg_oc.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/rotated_retinanet_hbb_kld_r50_fpn_1x_rsg_oc-343a0b83.pth?download=true) |
| GWD | 25.3 | [rotated_retinanet_hbb_gwd_r50_fpn_1x_rsg_oc](configs/gwd/rotated_retinanet_hbb_gwd_r50_fpn_1x_rsg_oc.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/rotated_retinanet_hbb_gwd_r50_fpn_1x_rsg_oc.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/rotated_retinanet_hbb_gwd_r50_fpn_1x_rsg_oc-566d2398.pth?download=true) |
| KFIoU | 25.5 | [rotated_retinanet_hbb_kfiou_r50_fpn_1x_rsg_oc](configs/kfiou/rotated_retinanet_hbb_kfiou_r50_fpn_1x_rsg_oc.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/rotated_retinanet_hbb_kfiou_r50_fpn_1x_rsg_oc.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/rotated_retinanet_hbb_kfiou_r50_fpn_1x_rsg_oc-198081a6.pth?download=true) |
| R<sup>3</sup>Det | 23.7 | [r3det_r50_fpn_1x_rsg_oc](configs/r3det/r3det_r50_fpn_1x_rsg_oc.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/r3det_r50_fpn_1x_rsg_oc.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/r3det_r50_fpn_1x_rsg_oc-c8c4a5e5.pth?download=true) |
| S2A-Net | 27.3 | [s2anet_r50_fpn_1x_rsg_le135](configs/s2anet/s2anet_r50_fpn_1x_rsg_le135.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/s2anet_r50_fpn_1x_rsg_le135.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/s2anet_r50_fpn_1x_rsg_le135-42887a81.pth?download=true) |
| FCOS | 28.1 | [rotated_fcos_r50_fpn_1x_rsg_le90](configs/rotated_fcos/rotated_fcos_r50_fpn_1x_rsg_le90.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/rotated_fcos_r50_fpn_1x_rsg_le90.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/rotated_fcos_r50_fpn_1x_rsg_le90-a579fbf7.pth?download=true) |
| CSL | 27.4 | [rotated_fcos_csl_gaussian_r50_fpn_1x_rsg_le90](configs/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_rsg_le90.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/rotated_fcos_csl_gaussian_r50_fpn_1x_rsg_le90.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/rotated_fcos_csl_gaussian_r50_fpn_1x_rsg_le90-6ab9a42a.pth?download=true) |
Expand Down
117 changes: 117 additions & 0 deletions configs/dcfl/dcfl_r50_fpn_1x_rsg_le135.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
_base_ = [
'../_base_/datasets/rsg.py', '../_base_/schedules/schedule_1x.py',
'../_base_/default_runtime.py'
]

angle_version = 'le135'
model = dict(
type='RotatedRetinaNetCrop',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
zero_init_residual=False,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_input',
num_outs=5),
bbox_head=dict(
type='RDCFLHead',
num_classes=48,
in_channels=256,
stacked_convs=4,
feat_channels=256,
assign_by_circumhbbox=None,
dcn_assign = True,
dilation_rate = 4,
anchor_generator=dict(
type='RotatedAnchorGenerator',
octave_base_scale=4,
scales_per_octave=1,
ratios=[1.0],
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHAOBBoxCoder',
angle_range=angle_version,
norm_factor=1,
edge_swap=False,
proj_xy=True,
target_means=(.0, .0, .0, .0, .0),
target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
reg_decoded_bbox=True,
loss_bbox=dict(
type='RotatedIoULoss',
loss_weight=1.0)),
train_cfg=dict(
assigner=dict(
type='C2FAssigner',
ignore_iof_thr=-1,
gpu_assign_thr= 1024,
iou_calculator=dict(type='RBboxMetrics2D'),
assign_metric='gjsd',
topk=16,
topq=12,
constraint='dgmm',
gauss_thr=0.6),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=2000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(iou_thr=0.4),
max_per_img=2000))

img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RResize', img_scale=(1024, 1024)),
dict(
type='RRandomFlip',
flip_ratio=[0.25, 0.25, 0.25],
direction=['horizontal', 'vertical', 'diagonal'],
version=angle_version),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(pipeline=train_pipeline, version=angle_version),
val=dict(version=angle_version),
test=dict(version=angle_version))

optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.0001,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg=dict(
custom_keys=dict(
absolute_pos_embed=dict(decay_mult=0.0),
relative_position_bias_table=dict(decay_mult=0.0),
norm=dict(decay_mult=0.0))))

checkpoint_config = dict(interval=1, max_keep_ckpts=1)
evaluation = dict(interval=6, metric='mAP')
5 changes: 0 additions & 5 deletions configs/h2rbox_v2p/h2rbox_v2p_r50_fpn_1x_rsg_le90.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,6 @@
img_prefix=data_root + 'test/images/',
version=angle_version))

data = dict(
train=dict(pipeline=train_pipeline, version=angle_version),
val=dict(version=angle_version),
test=dict(version=angle_version))

optimizer = dict(
_delete_=True,
type='AdamW',
Expand Down
96 changes: 96 additions & 0 deletions configs/h2rbox_v2p/h2rbox_v2p_r50_fpn_1x_rsg_rbox_le90.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
_base_ = [
'../_base_/datasets/rsg.py', '../_base_/schedules/schedule_1x.py',
'../_base_/default_runtime.py'
]
angle_version = 'le90'

# model settings
model = dict(
type='H2RBoxV2PDetectorCrop',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
zero_init_residual=False,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output', # use P5
num_outs=5,
relu_before_extra_convs=True),
bbox_head=dict(
type='H2RBoxV2PHead',
num_classes=48,
in_channels=256,
stacked_convs=4,
feat_channels=256,
strides=[8, 16, 32, 64, 128],
center_sampling=True,
center_sample_radius=1.5,
norm_on_bbox=True,
centerness_on_reg=True,
square_cls=[4, 44],
# resize_cls=[1],
scale_angle=False,
bbox_coder=dict(
type='DistanceAnglePointCoder', angle_version=angle_version),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='IoULoss', loss_weight=1.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
loss_ss_symmetry=dict(
type='SmoothL1Loss', loss_weight=0.2, beta=0.1)),
# training and testing settings
train_cfg=None,
test_cfg=dict(
nms_pre=2000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(iou_thr=0.1),
max_per_img=2000))

img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RResize', img_scale=(1024, 1024)),
dict(
type='RRandomFlip',
flip_ratio=[0.25, 0.25, 0.25],
direction=['horizontal', 'vertical', 'diagonal'],
version=angle_version),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]

data = dict(
train=dict(pipeline=train_pipeline, version=angle_version),
val=dict(version=angle_version),
test=dict(version=angle_version))

optimizer = dict(
_delete_=True,
type='AdamW',
lr=0.0001,
betas=(0.9, 0.999),
weight_decay=0.05)

checkpoint_config = dict(interval=1, max_keep_ckpts=1)
evaluation = dict(interval=6, metric='mAP')

3 changes: 3 additions & 0 deletions mmrotate/core/bbox/assigners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
from .sas_assigner import SASAssigner
from .rotated_hungarian_assigner import Rotated_HungarianAssigner
from .ars_hungarian_assigner import ARS_HungarianAssigner
from .coarse2fine_assigner import C2FAssigner
from .ranking_assigner import RRankingAssigner

__all__ = [
'ConvexAssigner', 'MaxConvexIoUAssigner', 'SASAssigner', 'ATSSKldAssigner',
'ATSSObbAssigner', 'Rotated_HungarianAssigner', 'ARS_HungarianAssigner',
'C2FAssigner', 'RRankingAssigner'
]
Loading

0 comments on commit be49417

Please sign in to comment.