diff --git a/README.md b/README.md
index 9eedec8..ae0f9e3 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,42 @@ Scene graph generation (SGG) in satellite imagery (SAI) benefits promoting intel
## 🛠️ Usage
-For instructions on installation, pretrained models, training and evaluation, please refer to [MMRotate 0.3.4](README_en.md).
+For more instructions on installation, pretrained models, training, and evaluation, please refer to [MMRotate 0.3.4](README_en.md).
+
+- Clone this repo:
+
+ ```bash
+ git clone https://github.com/yangxue0827/RSG-MMRotate
+ cd RSG-MMRotate/
+ ```
+
+- Create a conda virtual environment and activate it:
+
+ ```bash
+ conda create -n rsg-mmrotate python=3.8 -y
+ conda activate rsg-mmrotate
+ ```
+
+- Install Pytorch:
+
+ ```bash
+ pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
+ ```
+
+- Install requirements:
+
+ ```bash
+ pip install openmim
+ mim install mmcv-full
+ mim install mmdet
+
+ cd mmrotate
+ pip install -r requirements/build.txt
+ pip install -v -e .
+
+ pip install timm
+ pip install ipdb
+ ```
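+
+- (Optional) Verify the installation; a quick sanity check, assuming the
+  editable install above succeeded:
+
+  ```bash
+  python -c "import mmrotate; print(mmrotate.__version__)"
+  ```
+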
## 🚀 Released Models
@@ -33,6 +68,7 @@ For instructions on installation, pretrained models, training and evaluation, pl
| KLD | 25.0 | [rotated_retinanet_hbb_kld_r50_fpn_1x_rsg_oc](configs/kld/rotated_retinanet_hbb_kld_r50_fpn_1x_rsg_oc.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/rotated_retinanet_hbb_kld_r50_fpn_1x_rsg_oc.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/rotated_retinanet_hbb_kld_r50_fpn_1x_rsg_oc-343a0b83.pth?download=true) |
| GWD | 25.3 | [rotated_retinanet_hbb_gwd_r50_fpn_1x_rsg_oc](configs/gwd/rotated_retinanet_hbb_gwd_r50_fpn_1x_rsg_oc.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/rotated_retinanet_hbb_gwd_r50_fpn_1x_rsg_oc.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/rotated_retinanet_hbb_gwd_r50_fpn_1x_rsg_oc-566d2398.pth?download=true) |
| KFIoU | 25.5 | [rotated_retinanet_hbb_kfiou_r50_fpn_1x_rsg_oc](configs/kfiou/rotated_retinanet_hbb_kfiou_r50_fpn_1x_rsg_oc.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/rotated_retinanet_hbb_kfiou_r50_fpn_1x_rsg_oc.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/rotated_retinanet_hbb_kfiou_r50_fpn_1x_rsg_oc-198081a6.pth?download=true) |
+| R3Det | 23.7 | [r3det_r50_fpn_1x_rsg_oc](configs/r3det/r3det_r50_fpn_1x_rsg_oc.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/r3det_r50_fpn_1x_rsg_oc.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/r3det_r50_fpn_1x_rsg_oc-c8c4a5e5.pth?download=true) |
| S2A-Net | 27.3 | [s2anet_r50_fpn_1x_rsg_le135](configs/s2anet/s2anet_r50_fpn_1x_rsg_le135.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/s2anet_r50_fpn_1x_rsg_le135.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/s2anet_r50_fpn_1x_rsg_le135-42887a81.pth?download=true) |
| FCOS | 28.1 | [rotated_fcos_r50_fpn_1x_rsg_le90](configs/rotated_fcos/rotated_fcos_r50_fpn_1x_rsg_le90.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/rotated_fcos_r50_fpn_1x_rsg_le90.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/rotated_fcos_r50_fpn_1x_rsg_le90-a579fbf7.pth?download=true) |
| CSL | 27.4 | [rotated_fcos_csl_gaussian_r50_fpn_1x_rsg_le90](configs/rotated_fcos/rotated_fcos_csl_gaussian_r50_fpn_1x_rsg_le90.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/rotated_fcos_csl_gaussian_r50_fpn_1x_rsg_le90.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/rotated_fcos_csl_gaussian_r50_fpn_1x_rsg_le90-6ab9a42a.pth?download=true) |
diff --git a/configs/dcfl/dcfl_r50_fpn_1x_rsg_le135.py b/configs/dcfl/dcfl_r50_fpn_1x_rsg_le135.py
new file mode 100644
index 0000000..69744bc
--- /dev/null
+++ b/configs/dcfl/dcfl_r50_fpn_1x_rsg_le135.py
@@ -0,0 +1,117 @@
+_base_ = [
+ '../_base_/datasets/rsg.py', '../_base_/schedules/schedule_1x.py',
+ '../_base_/default_runtime.py'
+]
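+
+# Hypothetical launch commands (standard MMRotate 0.x entry points; the GPU
+# count below is an assumption):
+#   python tools/train.py configs/dcfl/dcfl_r50_fpn_1x_rsg_le135.py
+#   bash tools/dist_train.sh configs/dcfl/dcfl_r50_fpn_1x_rsg_le135.py 4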
+
+angle_version = 'le135'
+model = dict(
+ type='RotatedRetinaNetCrop',
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=1,
+ zero_init_residual=False,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=True,
+ style='pytorch',
+ init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ start_level=1,
+ add_extra_convs='on_input',
+ num_outs=5),
+ bbox_head=dict(
+ type='RDCFLHead',
+ num_classes=48,
+ in_channels=256,
+ stacked_convs=4,
+ feat_channels=256,
+ assign_by_circumhbbox=None,
+        dcn_assign=True,
+        dilation_rate=4,
+ anchor_generator=dict(
+ type='RotatedAnchorGenerator',
+ octave_base_scale=4,
+ scales_per_octave=1,
+ ratios=[1.0],
+ strides=[8, 16, 32, 64, 128]),
+ bbox_coder=dict(
+ type='DeltaXYWHAOBBoxCoder',
+ angle_range=angle_version,
+ norm_factor=1,
+ edge_swap=False,
+ proj_xy=True,
+ target_means=(.0, .0, .0, .0, .0),
+ target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)),
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ reg_decoded_bbox=True,
+ loss_bbox=dict(
+ type='RotatedIoULoss',
+ loss_weight=1.0)),
+ train_cfg=dict(
+ assigner=dict(
+ type='C2FAssigner',
+ ignore_iof_thr=-1,
+            gpu_assign_thr=1024,
+ iou_calculator=dict(type='RBboxMetrics2D'),
+ assign_metric='gjsd',
+ topk=16,
+ topq=12,
+ constraint='dgmm',
+ gauss_thr=0.6),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False),
+ test_cfg=dict(
+ nms_pre=2000,
+ min_bbox_size=0,
+ score_thr=0.05,
+ nms=dict(iou_thr=0.4),
+ max_per_img=2000))
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(type='RResize', img_scale=(1024, 1024)),
+ dict(
+ type='RRandomFlip',
+ flip_ratio=[0.25, 0.25, 0.25],
+ direction=['horizontal', 'vertical', 'diagonal'],
+ version=angle_version),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
+]
+data = dict(
+ samples_per_gpu=2,
+ workers_per_gpu=2,
+ train=dict(pipeline=train_pipeline, version=angle_version),
+ val=dict(version=angle_version),
+ test=dict(version=angle_version))
+
+optimizer = dict(
+ _delete_=True,
+ type='AdamW',
+ lr=0.0001,
+ betas=(0.9, 0.999),
+ weight_decay=0.05,
+ paramwise_cfg=dict(
+ custom_keys=dict(
+ absolute_pos_embed=dict(decay_mult=0.0),
+ relative_position_bias_table=dict(decay_mult=0.0),
+ norm=dict(decay_mult=0.0))))
+
+checkpoint_config = dict(interval=1, max_keep_ckpts=1)
+evaluation = dict(interval=6, metric='mAP')
\ No newline at end of file
diff --git a/configs/h2rbox_v2p/h2rbox_v2p_r50_fpn_1x_rsg_le90.py b/configs/h2rbox_v2p/h2rbox_v2p_r50_fpn_1x_rsg_le90.py
index a37320f..bbbd186 100644
--- a/configs/h2rbox_v2p/h2rbox_v2p_r50_fpn_1x_rsg_le90.py
+++ b/configs/h2rbox_v2p/h2rbox_v2p_r50_fpn_1x_rsg_le90.py
@@ -108,11 +108,6 @@
img_prefix=data_root + 'test/images/',
version=angle_version))
-data = dict(
- train=dict(pipeline=train_pipeline, version=angle_version),
- val=dict(version=angle_version),
- test=dict(version=angle_version))
-
optimizer = dict(
_delete_=True,
type='AdamW',
diff --git a/configs/h2rbox_v2p/h2rbox_v2p_r50_fpn_1x_rsg_rbox_le90.py b/configs/h2rbox_v2p/h2rbox_v2p_r50_fpn_1x_rsg_rbox_le90.py
new file mode 100644
index 0000000..590373c
--- /dev/null
+++ b/configs/h2rbox_v2p/h2rbox_v2p_r50_fpn_1x_rsg_rbox_le90.py
@@ -0,0 +1,96 @@
+_base_ = [
+ '../_base_/datasets/rsg.py', '../_base_/schedules/schedule_1x.py',
+ '../_base_/default_runtime.py'
+]
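+
+# Hypothetical evaluation command (standard MMRotate 0.x tools/test.py; the
+# checkpoint path assumes the default work_dirs naming):
+#   python tools/test.py configs/h2rbox_v2p/h2rbox_v2p_r50_fpn_1x_rsg_rbox_le90.py \
+#       work_dirs/h2rbox_v2p_r50_fpn_1x_rsg_rbox_le90/latest.pth --eval mAP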
+angle_version = 'le90'
+
+# model settings
+model = dict(
+ type='H2RBoxV2PDetectorCrop',
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=1,
+ zero_init_residual=False,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=True,
+ style='pytorch',
+ init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ start_level=1,
+ add_extra_convs='on_output', # use P5
+ num_outs=5,
+ relu_before_extra_convs=True),
+ bbox_head=dict(
+ type='H2RBoxV2PHead',
+ num_classes=48,
+ in_channels=256,
+ stacked_convs=4,
+ feat_channels=256,
+ strides=[8, 16, 32, 64, 128],
+ center_sampling=True,
+ center_sample_radius=1.5,
+ norm_on_bbox=True,
+ centerness_on_reg=True,
+ square_cls=[4, 44],
+ # resize_cls=[1],
+ scale_angle=False,
+ bbox_coder=dict(
+ type='DistanceAnglePointCoder', angle_version=angle_version),
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_bbox=dict(type='IoULoss', loss_weight=1.0),
+ loss_centerness=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+ loss_ss_symmetry=dict(
+ type='SmoothL1Loss', loss_weight=0.2, beta=0.1)),
+ # training and testing settings
+ train_cfg=None,
+ test_cfg=dict(
+ nms_pre=2000,
+ min_bbox_size=0,
+ score_thr=0.05,
+ nms=dict(iou_thr=0.1),
+ max_per_img=2000))
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(type='RResize', img_scale=(1024, 1024)),
+ dict(
+ type='RRandomFlip',
+ flip_ratio=[0.25, 0.25, 0.25],
+ direction=['horizontal', 'vertical', 'diagonal'],
+ version=angle_version),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
+]
+
+data = dict(
+ train=dict(pipeline=train_pipeline, version=angle_version),
+ val=dict(version=angle_version),
+ test=dict(version=angle_version))
+
+optimizer = dict(
+ _delete_=True,
+ type='AdamW',
+ lr=0.0001,
+ betas=(0.9, 0.999),
+ weight_decay=0.05)
+
+checkpoint_config = dict(interval=1, max_keep_ckpts=1)
+evaluation = dict(interval=6, metric='mAP')
+
diff --git a/mmrotate/core/bbox/assigners/__init__.py b/mmrotate/core/bbox/assigners/__init__.py
index ccf64f9..f2991c9 100644
--- a/mmrotate/core/bbox/assigners/__init__.py
+++ b/mmrotate/core/bbox/assigners/__init__.py
@@ -6,8 +6,11 @@
from .sas_assigner import SASAssigner
from .rotated_hungarian_assigner import Rotated_HungarianAssigner
from .ars_hungarian_assigner import ARS_HungarianAssigner
+from .coarse2fine_assigner import C2FAssigner
+from .ranking_assigner import RRankingAssigner
__all__ = [
'ConvexAssigner', 'MaxConvexIoUAssigner', 'SASAssigner', 'ATSSKldAssigner',
'ATSSObbAssigner', 'Rotated_HungarianAssigner', 'ARS_HungarianAssigner',
+ 'C2FAssigner', 'RRankingAssigner'
]
diff --git a/mmrotate/core/bbox/assigners/coarse2fine_assigner.py b/mmrotate/core/bbox/assigners/coarse2fine_assigner.py
new file mode 100644
index 0000000..b222e62
--- /dev/null
+++ b/mmrotate/core/bbox/assigners/coarse2fine_assigner.py
@@ -0,0 +1,316 @@
+import torch
+
+from ..builder import build_bbox_coder
+from mmdet.core.bbox.iou_calculators import build_iou_calculator
+from mmdet.core.bbox.assigners.assign_result import AssignResult
+from mmdet.core.bbox.assigners.base_assigner import BaseAssigner
+from ..builder import ROTATED_BBOX_ASSIGNERS
+
+
+@ROTATED_BBOX_ASSIGNERS.register_module()
+class C2FAssigner(BaseAssigner):
+ """Assign a corresponding gt bbox or background to each bbox.
+
+ Each proposals will be assigned with `-1`, or a semi-positive integer
+ indicating the ground truth index.
+
+ - -1: negative sample, no assigned gt
+ - semi-positive integer: positive sample, index (0-based) of assigned gt
+
+    Args:
+        gt_max_assign_all (bool): Whether to assign all bboxes with the same
+            highest overlap with some gt to that gt.
+        ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
+            `gt_bboxes_ignore` is specified). Negative values mean not
+            ignoring any bboxes.
+        ignore_wrt_candidates (bool): Whether to compute the iof between
+            `bboxes` and `gt_bboxes_ignore`, or the contrary.
+        gpu_assign_thr (int): The upper bound of the number of GT for GPU
+            assign. When the number of gt is above this threshold, it will
+            assign on CPU device. Negative values mean not assign on CPU.
+        iou_calculator (dict): Config of the overlaps/metric calculator.
+        assign_metric (str): Metric used to rank prior-gt pairs in the
+            coarse step, e.g. 'gjsd'.
+        topk (int): Number of coarse candidates selected per gt.
+        topq (int): Number of high-quality candidates kept per gt in the
+            fine step.
+        constraint (str | bool): Spatial constraint applied to positives,
+            e.g. 'dgmm'; ``False`` disables it.
+        gauss_thr (float): Threshold of the Gaussian constraint.
+        bbox_coder (dict): Config of the bbox coder used to decode
+            predictions when estimating candidate quality.
+    """
+
+ def __init__(self,
+ gt_max_assign_all=True,
+ ignore_iof_thr=-1,
+ ignore_wrt_candidates=True,
+ gpu_assign_thr=512,
+ iou_calculator=dict(type='BboxOverlaps2D'),
+ assign_metric='gjsd',
+ topk=1,
+ topq=1,
+ constraint=False,
+                 gauss_thr=1.0,
+ bbox_coder=dict(
+ type='DeltaXYWHAOBBoxCoder',
+ target_means=(.0, .0, .0, .0, .0),
+ target_stds=(1.0, 1.0, 1.0, 1.0, 1.0))):
+ self.gt_max_assign_all = gt_max_assign_all
+ self.ignore_iof_thr = ignore_iof_thr
+ self.ignore_wrt_candidates = ignore_wrt_candidates
+ self.gpu_assign_thr = gpu_assign_thr
+ self.iou_calculator = build_iou_calculator(iou_calculator)
+ self.assign_metric = assign_metric
+ self.topk = topk
+ self.topq = topq
+ self.constraint = constraint
+ self.gauss_thr = gauss_thr
+ self.bbox_coder = build_bbox_coder(bbox_coder)
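+
+    # Example config usage (mirrors the assigner in
+    # configs/dcfl/dcfl_r50_fpn_1x_rsg_le135.py from this PR):
+    #   assigner=dict(
+    #       type='C2FAssigner', ignore_iof_thr=-1, gpu_assign_thr=1024,
+    #       iou_calculator=dict(type='RBboxMetrics2D'),
+    #       assign_metric='gjsd', topk=16, topq=12,
+    #       constraint='dgmm', gauss_thr=0.6)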
+
+ def assign(self, cls_scores, bbox_preds, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
+ """Assign gt to bboxes.
+ """
+        assign_on_cpu = (self.gpu_assign_thr >= 0) and (
+            gt_bboxes.shape[0] > self.gpu_assign_thr)
+ # compute overlap and assign gt on CPU when number of GT is large
+ if assign_on_cpu:
+ device = bboxes.device
+ bboxes = bboxes.cpu()
+ gt_bboxes = gt_bboxes.cpu()
+ if gt_bboxes_ignore is not None:
+ gt_bboxes_ignore = gt_bboxes_ignore.cpu()
+ if gt_labels is not None:
+ gt_labels = gt_labels.cpu()
+
+ overlaps = self.iou_calculator(gt_bboxes, bboxes, mode=self.assign_metric)
+
+ if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
+ and gt_bboxes_ignore.numel() > 0 and bboxes.numel() > 0):
+ if self.ignore_wrt_candidates:
+ ignore_overlaps = self.iou_calculator(
+ bboxes, gt_bboxes_ignore, mode='iof')
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
+ else:
+ ignore_overlaps = self.iou_calculator(
+ gt_bboxes_ignore, bboxes, mode='iof')
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
+ overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
+
+ num_gts, num_bboxes = overlaps.size(0), overlaps.size(1)
+
+ # 1. assign -1 by default
+ assigned_gt_inds = overlaps.new_full((num_bboxes,),
+ -1,
+ dtype=torch.long)
+
+ if num_gts == 0 or num_bboxes == 0:
+ # No ground truth or boxes, return empty assignment
+ max_overlaps = overlaps.new_zeros((num_bboxes,))
+ if num_gts == 0:
+ # No truth, assign everything to background
+ assigned_gt_inds[:] = 0
+ if gt_labels is None:
+ assigned_labels = None
+ else:
+ assigned_labels = overlaps.new_full((num_bboxes,),
+ -1,
+ dtype=torch.long)
+ return AssignResult(
+ num_gts,
+ assigned_gt_inds,
+ max_overlaps,
+ labels=assigned_labels)
+
+ # for each anchor, which gt best overlaps with it
+ # for each anchor, the max iou of all gts
+ max_overlaps, _ = overlaps.max(dim=0)
+        # for each gt, the top-k ranked candidates over all priors
+        gt_max_overlaps, _ = overlaps.topk(
+            self.topk, dim=1, largest=True, sorted=True)  # [num_gt, k]
+ assigned_gt_inds[(max_overlaps >= 0) & (max_overlaps < 0.8)] = 0
+
+ for i in range(num_gts):
+ for j in range(self.topk):
+ max_overlap_inds = overlaps[i,:] == gt_max_overlaps[i,j]
+ assigned_gt_inds[max_overlap_inds] = i + 1
+
+ device = bboxes.device
+ bbox_preds = bbox_preds.to(device)
+ cls_scores = cls_scores.to(device)
+ bbox_preds = torch.transpose(bbox_preds, 0, 1)
+ bbox_preds = self.bbox_coder.decode(bboxes, bbox_preds)
+
+ num_gt = gt_bboxes.size(0)
+ num_bboxes = bboxes.size(0)
+
+ can_positive_mask = assigned_gt_inds > 0
+ can_positive_inds = torch.nonzero(can_positive_mask)
+
+ poscan = assigned_gt_inds[can_positive_inds].squeeze(-1)
+ can_other_mask = assigned_gt_inds <= 0
+
+ can_pos_scores = cls_scores[:,can_positive_inds].squeeze(-1)
+
+ can_pos_scores = torch.transpose(can_pos_scores, 0, 1)
+ can_bbox_pred = bbox_preds[can_positive_inds,:].squeeze(-1)
+
+        can_pos_iou = self.iou_calculator(
+            gt_bboxes.to(device), can_bbox_pred, mode='iou')
+        can_pos_iou = can_pos_iou[poscan - 1, range(poscan.size(0))]
+        can_pos_cls, _ = torch.max(can_pos_scores, 1)
+
+ can_pos_quality = can_pos_iou + can_pos_cls.sigmoid()
+ can_pos_quality = can_pos_quality.unsqueeze(0).repeat(num_gt, 1) # size of gt, pos anchors
+
+ gt_poscan = torch.zeros_like(can_pos_quality) - 100 # size of gt, pos anchors
+ gt_poscan[poscan-1,range(poscan.size(0))] = can_pos_quality[poscan-1,range(poscan.size(0))]
+
+ if self.topq >= can_pos_quality.size(1):
+ topq = can_pos_quality.size(1)
+ else:
+ topq = self.topq
+ gt_max_quality, gt_argmax_quality = gt_poscan.topk(topq, dim=1, largest=True, sorted=True) # gt_argmax_quality [num_gt, q]
+
+ assign_result_pre_gt = assigned_gt_inds
+
+ assigned_gt_inds_init = assign_result_pre_gt * can_other_mask
+ assigned_pos_prior = torch.zeros((num_gt, topq, 5),device=device)
+
+        for i in range(num_gt):
+            for j in range(topq):
+                index = gt_argmax_quality[i, j]
+                remap_inds = can_positive_inds[index, 0]
+                assigned_gt_inds_init[remap_inds] = assign_result_pre_gt[remap_inds]
+                assigned_pos_prior[i, j, :] = bboxes[remap_inds, :]
+        assigned_gt_inds = assigned_gt_inds_init
+
+ if self.constraint == 'dgmm':
+            device1 = gt_bboxes.device
+            xy_gt, sigma_t = self.xy_wh_r_2_xy_sigma(gt_bboxes)
+            # mean center of the selected positive priors of each gt
+            xy_pt = torch.mean(assigned_pos_prior[..., :2], dim=-2)
+            xy_a = bboxes[..., :2]
+            xy_gt = xy_gt[..., None, :, :2].unsqueeze(-1)
+            xy_pt = xy_pt[..., None, :, :2].unsqueeze(-1)
+            xy_a = xy_a[..., :, None, :2].unsqueeze(-1)
+ inv_sigma_t = torch.stack((sigma_t[..., 1, 1], -sigma_t[..., 0, 1],
+ -sigma_t[..., 1, 0], sigma_t[..., 0, 0]),
+ dim=-1).reshape(-1, 2, 2)
+ inv_sigma_t = inv_sigma_t / sigma_t.det().unsqueeze(-1).unsqueeze(-1)
+            gaussian_gt = torch.exp(
+                -0.5 * (xy_a - xy_gt).permute(0, 1, 3, 2).matmul(inv_sigma_t)
+                .matmul(xy_a - xy_gt)).squeeze(-1).squeeze(-1)
+            gaussian_pt = torch.exp(
+                -0.5 * (xy_a - xy_pt).permute(0, 1, 3, 2).matmul(inv_sigma_t)
+                .matmul(xy_a - xy_pt)).squeeze(-1).squeeze(-1)
+            # mixture of the gt-centered and prior-mean-centered Gaussians
+            gaussian = 0.7 * gaussian_gt + 0.3 * gaussian_pt
+
+ inside_flag = gaussian >= torch.exp(torch.tensor([-self.gauss_thr])).to(device1)
+ length = range(assigned_gt_inds.size(0))
+ inside_mask = inside_flag[length, (assigned_gt_inds-1).clamp(min=0)]
+ assigned_gt_inds *= inside_mask
+
+ if gt_labels is not None:
+ assigned_labels = assigned_gt_inds.new_full((num_bboxes,), -1)
+ pos_inds = torch.nonzero(
+ assigned_gt_inds > 0, as_tuple=False).squeeze()
+ if pos_inds.numel() > 0:
+ assigned_labels[pos_inds] = gt_labels[
+ assigned_gt_inds[pos_inds] - 1]
+ else:
+ assigned_labels = None
+
+ assign_result = AssignResult(
+ num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
+
+ if assign_on_cpu:
+ assign_result.gt_inds = assign_result.gt_inds.to(device)
+ assign_result.max_overlaps = assign_result.max_overlaps.to(device)
+ if assign_result.labels is not None:
+ assign_result.labels = assign_result.labels.to(device)
+
+ return assign_result
+
+ def assign_wrt_ranking(self, overlaps, gt_labels=None):
+ num_gts, num_bboxes = overlaps.size(0), overlaps.size(1)
+
+ # 1. assign -1 by default
+ assigned_gt_inds = overlaps.new_full((num_bboxes,),
+ -1,
+ dtype=torch.long)
+
+ if num_gts == 0 or num_bboxes == 0:
+ # No ground truth or boxes, return empty assignment
+ max_overlaps = overlaps.new_zeros((num_bboxes,))
+ if num_gts == 0:
+ # No truth, assign everything to background
+ assigned_gt_inds[:] = 0
+ if gt_labels is None:
+ assigned_labels = None
+ else:
+ assigned_labels = overlaps.new_full((num_bboxes,),
+ -1,
+ dtype=torch.long)
+ return AssignResult(
+ num_gts,
+ assigned_gt_inds,
+ max_overlaps,
+ labels=assigned_labels)
+
+ # for each anchor, which gt best overlaps with it
+ # for each anchor, the max iou of all gts
+ max_overlaps, _ = overlaps.max(dim=0)
+        # for each gt, the top-k ranked candidates over all priors
+        gt_max_overlaps, _ = overlaps.topk(
+            self.topk, dim=1, largest=True, sorted=True)  # [num_gt, k]
+
+        assigned_gt_inds[(max_overlaps >= 0)
+                         & (max_overlaps < 0.8)] = 0
+
+ for i in range(num_gts):
+ for j in range(self.topk):
+ max_overlap_inds = overlaps[i,:] == gt_max_overlaps[i,j]
+ assigned_gt_inds[max_overlap_inds] = i + 1
+
+ if gt_labels is not None:
+ assigned_labels = assigned_gt_inds.new_full((num_bboxes,), -1)
+ pos_inds = torch.nonzero(
+ assigned_gt_inds > 0, as_tuple=False).squeeze()
+ if pos_inds.numel() > 0:
+ assigned_labels[pos_inds] = gt_labels[
+ assigned_gt_inds[pos_inds] - 1]
+ else:
+ assigned_labels = None
+
+ return AssignResult(
+ num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
+
+ def xy_wh_r_2_xy_sigma(self, xywhr):
+ """Convert oriented bounding box to 2-D Gaussian distribution.
+
+ Args:
+ xywhr (torch.Tensor): rbboxes with shape (N, 5).
+
+ Returns:
+ xy (torch.Tensor): center point of 2-D Gaussian distribution
+ with shape (N, 2).
+ sigma (torch.Tensor): covariance matrix of 2-D Gaussian distribution
+ with shape (N, 2, 2).
+ """
+ _shape = xywhr.shape
+ assert _shape[-1] == 5
+ xy = xywhr[..., :2]
+ wh = xywhr[..., 2:4].clamp(min=1e-7, max=1e7).reshape(-1, 2)
+ r = xywhr[..., 4]
+ cos_r = torch.cos(r)
+ sin_r = torch.sin(r)
+ R = torch.stack((cos_r, -sin_r, sin_r, cos_r), dim=-1).reshape(-1, 2, 2)
+ S = 0.5 * torch.diag_embed(wh)
+
+ sigma = R.bmm(S.square()).bmm(R.permute(0, 2,
+ 1)).reshape(_shape[:-1] + (2, 2))
+
+ return xy, sigma
\ No newline at end of file
diff --git a/mmrotate/core/bbox/assigners/ranking_assigner.py b/mmrotate/core/bbox/assigners/ranking_assigner.py
new file mode 100644
index 0000000..aa8940a
--- /dev/null
+++ b/mmrotate/core/bbox/assigners/ranking_assigner.py
@@ -0,0 +1,263 @@
+import torch
+
+from mmdet.core.bbox.builder import BBOX_ASSIGNERS
+from mmdet.core.bbox.iou_calculators import build_iou_calculator
+from mmdet.core.bbox.assigners.assign_result import AssignResult
+from mmdet.core.bbox.assigners.base_assigner import BaseAssigner
+from mmrotate.core.bbox.transforms import hbb2obb
+
+
+@BBOX_ASSIGNERS.register_module()
+class RRankingAssigner(BaseAssigner):
+ """Assign a corresponding gt bbox or background to each bbox.
+
+ Each proposals will be assigned with `-1`, or a semi-positive integer
+ indicating the ground truth index.
+
+ - -1: negative sample, no assigned gt
+ - semi-positive integer: positive sample, index (0-based) of assigned gt
+
+    Args:
+        gt_max_assign_all (bool): Whether to assign all bboxes with the same
+            highest overlap with some gt to that gt.
+        ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
+            `gt_bboxes_ignore` is specified). Negative values mean not
+            ignoring any bboxes.
+        ignore_wrt_candidates (bool): Whether to compute the iof between
+            `bboxes` and `gt_bboxes_ignore`, or the contrary.
+        gpu_assign_thr (int): The upper bound of the number of GT for GPU
+            assign. When the number of gt is above this threshold, it will
+            assign on CPU device. Negative values mean not assign on CPU.
+        iou_calculator (dict): Config of the overlaps/metric calculator.
+        assign_metric (str): Metric used to rank prior-gt pairs.
+        topk (int): Number of top-ranked priors assigned to each gt.
+        inside_circle (str | bool): Spatial constraint on positives,
+            'circle' or 'gaussian'; ``False`` disables it.
+        gauss_thr (float): Threshold of the Gaussian constraint.
+        version (str): Angle representation, e.g. 'le135'.
+    """
+
+ def __init__(self,
+ gt_max_assign_all=True,
+ ignore_iof_thr=-1,
+ ignore_wrt_candidates=True,
+ gpu_assign_thr=512,
+ iou_calculator=dict(type='BboxOverlaps2D'),
+ assign_metric='iou',
+ topk=1,
+ inside_circle=False,
+                 gauss_thr=1.5,
+                 version='le135'):
+ self.gt_max_assign_all = gt_max_assign_all
+ self.ignore_iof_thr = ignore_iof_thr
+ self.ignore_wrt_candidates = ignore_wrt_candidates
+ self.gpu_assign_thr = gpu_assign_thr
+ self.iou_calculator = build_iou_calculator(iou_calculator)
+ self.assign_metric = assign_metric
+ self.topk = topk
+ self.inside_circle = inside_circle
+ self.gauss_thr = gauss_thr
+ self.angle_version = version
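+
+    # Hypothetical config usage (names assumed; no RSG config in this PR
+    # uses RRankingAssigner):
+    #   assigner=dict(
+    #       type='RRankingAssigner',
+    #       iou_calculator=dict(type='RBboxMetrics2D'),
+    #       assign_metric='gjsd', topk=16, inside_circle='gaussian')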
+
+ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
+ """Assign gt to bboxes.
+
+ This method assign a gt bbox to every bbox (proposal/anchor), each bbox
+ will be assigned with -1, or a semi-positive number. -1 means negative
+ sample, semi-positive number is the index (0-based) of assigned gt.
+ The assignment is done in following steps, the order matters.
+
+        1. assign every bbox to -1 (unassigned) by default
+        2. assign bboxes whose highest overlap is below 0.8 to 0 (background)
+        3. for each gt, assign its top-k ranked bboxes (by the assign
+           metric) to that gt
+
+ Args:
+ bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`, e.g., crowd boxes in COCO.
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+ Returns:
+ :obj:`AssignResult`: The assign result.
+
+ Example:
+            >>> self = RRankingAssigner(topk=1)
+ >>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
+ >>> gt_bboxes = torch.Tensor([[0, 0, 10, 9]])
+ >>> assign_result = self.assign(bboxes, gt_bboxes)
+ >>> expected_gt_inds = torch.LongTensor([1, 0])
+ >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
+ """
+
+        assign_on_cpu = (self.gpu_assign_thr >= 0) and (
+            gt_bboxes.shape[0] > self.gpu_assign_thr)
+ # compute overlap and assign gt on CPU when number of GT is large
+ if assign_on_cpu:
+ device = bboxes.device
+ bboxes = bboxes.cpu()
+ gt_bboxes = gt_bboxes.cpu()
+ if gt_bboxes_ignore is not None:
+ gt_bboxes_ignore = gt_bboxes_ignore.cpu()
+ if gt_labels is not None:
+ gt_labels = gt_labels.cpu()
+
+ overlaps = self.iou_calculator(gt_bboxes, bboxes, mode=self.assign_metric, version=self.angle_version)
+
+ if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
+ and gt_bboxes_ignore.numel() > 0 and bboxes.numel() > 0):
+ if self.ignore_wrt_candidates:
+ ignore_overlaps = self.iou_calculator(
+ bboxes, gt_bboxes_ignore, mode='iof')
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
+ else:
+ ignore_overlaps = self.iou_calculator(
+ gt_bboxes_ignore, bboxes, mode='iof')
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
+ overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
+        assign_result = self.assign_wrt_ranking(overlaps, gt_labels)
+
+ if self.inside_circle == 'circle':
+            center_distance = self.iou_calculator(
+                gt_bboxes, bboxes, mode='center_distance2')
+            width_gt = gt_bboxes[..., 2]
+            height_gt = gt_bboxes[..., 3]
+            # fixed radius ratio of the center-distance constraint
+            r = 1.5
+            gt_circle = ((width_gt / 2)**2 + (height_gt / 2)**2) * r * r
+ inside_flag = center_distance <= gt_circle[...,None]
+ length = range(assign_result.gt_inds.size(0))
+ inside_mask = inside_flag[(assign_result.gt_inds-1).clamp(min=0), length]
+ assign_result.gt_inds *= inside_mask
+
+ elif self.inside_circle == 'gaussian':
+
+ if gt_bboxes.size(-1) == 4:
+ gt_bboxes = hbb2obb(gt_bboxes, version=self.angle_version)
+ if bboxes.size(-1) == 4:
+ bboxes = hbb2obb(bboxes, version=self.angle_version)
+
+ device1 = gt_bboxes.device
+ xy_t, sigma_t = self.xy_wh_r_2_xy_sigma(gt_bboxes)
+ xy_a = bboxes[...,:2]
+ xy_t = xy_t[...,None,:,:2].unsqueeze(-1)
+ xy_a = xy_a[...,:,None,:2].unsqueeze(-1)
+ inv_sigma_t = torch.stack((sigma_t[..., 1, 1], -sigma_t[..., 0, 1],
+ -sigma_t[..., 1, 0], sigma_t[..., 0, 0]),
+ dim=-1).reshape(-1, 2, 2)
+ inv_sigma_t = inv_sigma_t / sigma_t.det().unsqueeze(-1).unsqueeze(-1)
+            # unnormalized Gaussian density (the 1/(2*pi*det(sigma)) factor is omitted)
+            gaussian = torch.exp(
+                -0.5 * (xy_a - xy_t).permute(0, 1, 3, 2).matmul(inv_sigma_t)
+                .matmul(xy_a - xy_t)).squeeze(-1).squeeze(-1)
+ inside_flag = gaussian >= torch.exp(torch.tensor([-self.gauss_thr])).to(device1)
+ length = range(assign_result.gt_inds.size(0))
+ inside_mask = inside_flag[length, (assign_result.gt_inds-1).clamp(min=0)]
+ assign_result.gt_inds *= inside_mask
+
+
+ if assign_on_cpu:
+ assign_result.gt_inds = assign_result.gt_inds.to(device)
+ assign_result.max_overlaps = assign_result.max_overlaps.to(device)
+ if assign_result.labels is not None:
+ assign_result.labels = assign_result.labels.to(device)
+ return assign_result
+
+ def assign_wrt_ranking(self, overlaps, gt_labels=None):
+ num_gts, num_bboxes = overlaps.size(0), overlaps.size(1)
+
+ # 1. assign -1 by default
+ assigned_gt_inds = overlaps.new_full((num_bboxes,),
+ -1,
+ dtype=torch.long)
+
+ if num_gts == 0 or num_bboxes == 0:
+ # No ground truth or boxes, return empty assignment
+ max_overlaps = overlaps.new_zeros((num_bboxes,))
+ if num_gts == 0:
+ # No truth, assign everything to background
+ assigned_gt_inds[:] = 0
+ if gt_labels is None:
+ assigned_labels = None
+ else:
+ assigned_labels = overlaps.new_full((num_bboxes,),
+ -1,
+ dtype=torch.long)
+ return AssignResult(
+ num_gts,
+ assigned_gt_inds,
+ max_overlaps,
+ labels=assigned_labels)
+
+ # for each anchor, which gt best overlaps with it
+ # for each anchor, the max iou of all gts
+ max_overlaps, argmax_overlaps = overlaps.max(dim=0)
+        # for each gt, the top-k ranked candidates over all proposals
+        gt_max_overlaps, gt_argmax_overlaps = overlaps.topk(
+            self.topk, dim=1, largest=True, sorted=True)  # [num_gt, k]
+
+        assigned_gt_inds[(max_overlaps >= 0)
+                         & (max_overlaps < 0.8)] = 0
+        # assign positives w.r.t. ranking
+        for i in range(num_gts):
+            for j in range(self.topk):
+                max_overlap_inds = overlaps[i, :] == gt_max_overlaps[i, j]
+                assigned_gt_inds[max_overlap_inds] = i + 1
+
+ if gt_labels is not None:
+ assigned_labels = assigned_gt_inds.new_full((num_bboxes,), -1)
+ pos_inds = torch.nonzero(
+ assigned_gt_inds > 0, as_tuple=False).squeeze()
+ if pos_inds.numel() > 0:
+ assigned_labels[pos_inds] = gt_labels[
+ assigned_gt_inds[pos_inds] - 1]
+ else:
+ assigned_labels = None
+
+ return AssignResult(
+ num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
+
+ def xy_wh_r_2_xy_sigma(self, xywhr):
+ """Convert oriented bounding box to 2-D Gaussian distribution.
+
+ Args:
+ xywhr (torch.Tensor): rbboxes with shape (N, 5).
+
+ Returns:
+ xy (torch.Tensor): center point of 2-D Gaussian distribution
+ with shape (N, 2).
+ sigma (torch.Tensor): covariance matrix of 2-D Gaussian distribution
+ with shape (N, 2, 2).
+ """
+ _shape = xywhr.shape
+ assert _shape[-1] == 5
+ xy = xywhr[..., :2]
+ wh = xywhr[..., 2:4].clamp(min=1e-7, max=1e7).reshape(-1, 2)
+ r = xywhr[..., 4]
+ cos_r = torch.cos(r)
+ sin_r = torch.sin(r)
+ R = torch.stack((cos_r, -sin_r, sin_r, cos_r), dim=-1).reshape(-1, 2, 2)
+ S = 0.5 * torch.diag_embed(wh)
+
+ sigma = R.bmm(S.square()).bmm(R.permute(0, 2,
+ 1)).reshape(_shape[:-1] + (2, 2))
+
+ return xy, sigma
\ No newline at end of file
diff --git a/mmrotate/core/bbox/iou_calculators/__init__.py b/mmrotate/core/bbox/iou_calculators/__init__.py
index 5902ab0..a7ff8b3 100644
--- a/mmrotate/core/bbox/iou_calculators/__init__.py
+++ b/mmrotate/core/bbox/iou_calculators/__init__.py
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import build_iou_calculator
from .rotate_iou2d_calculator import RBboxOverlaps2D, rbbox_overlaps
+from .rotate_metric_calculator import RBboxMetrics2D
-__all__ = ['build_iou_calculator', 'RBboxOverlaps2D', 'rbbox_overlaps']
+
+__all__ = [
+    'build_iou_calculator', 'RBboxOverlaps2D', 'rbbox_overlaps',
+    'RBboxMetrics2D'
+]
diff --git a/mmrotate/core/bbox/iou_calculators/rotate_metric_calculator.py b/mmrotate/core/bbox/iou_calculators/rotate_metric_calculator.py
new file mode 100644
index 0000000..7fd8dab
--- /dev/null
+++ b/mmrotate/core/bbox/iou_calculators/rotate_metric_calculator.py
@@ -0,0 +1,212 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcv.ops import box_iou_rotated
+
+from .builder import ROTATED_IOU_CALCULATORS
+from mmrotate.core.bbox.transforms import hbb2obb
+
+
+@ROTATED_IOU_CALCULATORS.register_module()
+class RBboxMetrics2D(object):
+ """2D Overlaps (e.g. IoUs, GIoUs) Calculator."""
+
+ def __call__(self,
+ bboxes1,
+ bboxes2,
+ mode='iou',
+ is_aligned=False,
+ version='oc'):
+ """Calculate IoU between 2D bboxes.
+
+ Args:
+            bboxes1 (torch.Tensor): bboxes have shape (m, 5) in
+                <cx, cy, w, h, a> format, or shape (m, 6) in
+                <cx, cy, w, h, a, score> format.
+            bboxes2 (torch.Tensor): bboxes have shape (n, 5) in
+                <cx, cy, w, h, a> format, shape (n, 6) in
+                <cx, cy, w, h, a, score> format, or be empty.
+                If ``is_aligned`` is ``True``, then m and n must be equal.
+ mode (str): "iou" (intersection over union), "iof" (intersection
+ over foreground), or "giou" (generalized intersection over
+ union).
+ is_aligned (bool, optional): If True, then m and n must be equal.
+ Default False.
+ version (str, optional): Angle representations. Defaults to 'oc'.
+
+ Returns:
+            Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,)
+ """
+ assert bboxes1.size(-1) in [0, 2, 4, 5, 6]
+ assert bboxes2.size(-1) in [0, 2, 4, 5, 6]
+
+ if bboxes1.size(-1) == 4:
+ bboxes1 = hbb2obb(bboxes1, version)
+ if bboxes2.size(-1) == 4:
+ bboxes2 = hbb2obb(bboxes2, version)
+
+ if bboxes2.size(-1) == 6:
+ bboxes2 = bboxes2[..., :5]
+ if bboxes1.size(-1) == 6:
+ bboxes1 = bboxes1[..., :5]
+ return rbbox_metrics(bboxes1.contiguous(), bboxes2.contiguous(), mode,
+ is_aligned)
+
+ def __repr__(self):
+ """str: a string describing the module"""
+ repr_str = self.__class__.__name__ + '()'
+ return repr_str
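+
+    # Minimal usage sketch (tensor names assumed; unaligned inputs yield a
+    # (num_gt, num_priors) matrix):
+    #   calc = RBboxMetrics2D()
+    #   sim = calc(gt_rbboxes, priors, mode='gjsd')
+    #   d2 = calc(gt_rbboxes, priors, mode='center_distance2')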
+
+
+def rbbox_metrics(bboxes1, bboxes2, mode='iou', is_aligned=False):
+ """Calculate overlap between two set of bboxes.
+
+ Args:
+        bboxes1 (torch.Tensor): shape (B, m, 5) in <cx, cy, w, h, a> format
+            or empty.
+        bboxes2 (torch.Tensor): shape (B, n, 5) in <cx, cy, w, h, a> format
+            or empty.
+        mode (str): "iou" (intersection over union), "iof" (intersection over
+            foreground), "gjsd" (a similarity derived from the generalized
+            Jensen-Shannon divergence) or "center_distance2" (squared center
+            distance). Default "iou".
+ is_aligned (bool, optional): If True, then m and n must be equal.
+ Default False.
+
+ Returns:
+ Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,)
+ """
+    assert mode in ['iou', 'iof', 'gjsd', 'center_distance2']
+    # Either the boxes are empty or the length of boxes' last dimension is 5
+ if mode in ['center_distance2']:
+ pass
+ else:
+ assert (bboxes1.size(-1) == 5 or bboxes1.size(0) == 0)
+ assert (bboxes2.size(-1) == 5 or bboxes2.size(0) == 0)
+
+ rows = bboxes1.size(0)
+ cols = bboxes2.size(0)
+ if is_aligned:
+ assert rows == cols
+
+ if rows * cols == 0:
+ return bboxes1.new(rows, 1) if is_aligned else bboxes1.new(rows, cols)
+
+    if mode in ['iou', 'iof']:
+ # resolve `rbbox_overlaps` abnormal when input rbbox is too small.
+ clamped_bboxes1 = bboxes1.detach().clone()
+ clamped_bboxes2 = bboxes2.detach().clone()
+ clamped_bboxes1[:, 2:4].clamp_(min=1e-3)
+ clamped_bboxes2[:, 2:4].clamp_(min=1e-3)
+
+ return box_iou_rotated(clamped_bboxes1, clamped_bboxes2, mode, is_aligned)
+
+ if mode == 'gjsd':
+ g_bboxes1 = xy_wh_r_2_xy_sigma(bboxes1)
+ g_bboxes2 = xy_wh_r_2_xy_sigma(bboxes2)
+        gjsd = get_gjsd(g_bboxes1, g_bboxes2)
+        # map the divergence to a (0, 1] similarity: larger means closer
+        distance = 1 / (1 + gjsd)
+
+ return distance
+
+ if mode == 'center_distance2':
+ center1 = bboxes1[..., :, None, :2]
+ center2 = bboxes2[..., None, :, :2]
+ whs = center1[..., :2] - center2[..., :2]
+
+        center_distance2 = whs[..., 0] * whs[..., 0] + \
+            whs[..., 1] * whs[..., 1] + 1e-6
+
+
+ return center_distance2
+
+
+def xy_wh_r_2_xy_sigma(xywhr):
+ """Convert oriented bounding box to 2-D Gaussian distribution.
+
+ Args:
+ xywhr (torch.Tensor): rbboxes with shape (N, 5).
+
+ Returns:
+ xy (torch.Tensor): center point of 2-D Gaussian distribution
+ with shape (N, 2).
+ sigma (torch.Tensor): covariance matrix of 2-D Gaussian distribution
+ with shape (N, 2, 2).
+ """
+ _shape = xywhr.shape
+ assert _shape[-1] == 5
+ xy = xywhr[..., :2]
+ wh = xywhr[..., 2:4].clamp(min=1e-7, max=1e7).reshape(-1, 2)
+ r = xywhr[..., 4]
+ cos_r = torch.cos(r)
+ sin_r = torch.sin(r)
+ R = torch.stack((cos_r, -sin_r, sin_r, cos_r), dim=-1).reshape(-1, 2, 2)
+ S = 0.5 * torch.diag_embed(wh)
+
+ sigma = R.bmm(S.square()).bmm(R.permute(0, 2,
+ 1)).reshape(_shape[:-1] + (2, 2))
+
+ return xy, sigma
+
+
+def get_gjsd(pred, target, alpha=0.5):
+ xy_p, Sigma_p = pred # mu_1, sigma_1
+ xy_t, Sigma_t = target # mu_2, sigma_2
+
+ Sigma_p = Sigma_p.reshape(-1, 2, 2)
+ Sigma_t = Sigma_t.reshape(-1, 2, 2)
+
+ xy_p = xy_p[...,:,None,:2]
+ xy_t = xy_t[...,None,:,:2]
+
+ # get the inverse of Sigma_p and Sigma_t
+ Sigma_p_inv = torch.stack((Sigma_p[..., 1, 1], -Sigma_p[..., 0, 1],
+ -Sigma_p[..., 1, 0], Sigma_p[..., 0, 0]),
+ dim=-1).reshape(-1, 2, 2)
+ Sigma_p_inv = Sigma_p_inv / Sigma_p.det().unsqueeze(-1).unsqueeze(-1)
+ Sigma_t_inv = torch.stack((Sigma_t[..., 1, 1], -Sigma_t[..., 0, 1],
+ -Sigma_t[..., 1, 0], Sigma_t[..., 0, 0]),
+ dim=-1).reshape(-1, 2, 2)
+ Sigma_t_inv = Sigma_t_inv / Sigma_t.det().unsqueeze(-1).unsqueeze(-1)
+
+ Sigma_p = Sigma_p[...,:,None,:2,:2]
+ Sigma_p_inv = Sigma_p_inv[...,:,None,:2,:2]
+ Sigma_t = Sigma_t[...,None,:,:2,:2]
+ Sigma_t_inv = Sigma_t_inv[...,None,:,:2,:2]
+
+ Sigma_alpha_ori = ((1-alpha)*Sigma_p_inv + alpha*Sigma_t_inv)
+
+ # get the inverse of Sigma_alpha_ori, namely Sigma_alpha
+ Sigma_alpha = torch.stack((Sigma_alpha_ori[..., 1, 1], -Sigma_alpha_ori[..., 0, 1],
+ -Sigma_alpha_ori[..., 1, 0], Sigma_alpha_ori[..., 0, 0]),
+ dim=-1).reshape(Sigma_alpha_ori.size(0), Sigma_alpha_ori.size(1), 2, 2)
+ Sigma_alpha = Sigma_alpha / Sigma_alpha_ori.det().unsqueeze(-1).unsqueeze(-1)
+ # get the inverse of Sigma_alpha, namely Sigma_alpha_inv
+ Sigma_alpha_inv = torch.stack((Sigma_alpha[..., 1, 1], -Sigma_alpha[..., 0, 1],
+ -Sigma_alpha[..., 1, 0], Sigma_alpha[..., 0, 0]),
+ dim=-1).reshape(Sigma_alpha.size(0),Sigma_alpha.size(1), 2, 2)
+ Sigma_alpha_inv = Sigma_alpha_inv / Sigma_alpha.det().unsqueeze(-1).unsqueeze(-1)
+
+ # mu_alpha
+ xy_p = xy_p.unsqueeze(-1)
+ xy_t = xy_t.unsqueeze(-1)
+
+ mu_alpha_1 = (1-alpha)* Sigma_p_inv.matmul(xy_p) + alpha * Sigma_t_inv.matmul(xy_t)
+ mu_alpha = Sigma_alpha.matmul(mu_alpha_1)
+
+ # the first part of GJSD
+    first_part = (
+        (1 - alpha) * xy_p.permute(0, 1, 3, 2).matmul(Sigma_p_inv).matmul(xy_p)
+        + alpha * xy_t.permute(0, 1, 3, 2).matmul(Sigma_t_inv).matmul(xy_t)
+        - mu_alpha.permute(0, 1, 3, 2).matmul(Sigma_alpha_inv).matmul(mu_alpha))
+ second_part = ((Sigma_p.det() ** (1-alpha))*(Sigma_t.det() ** alpha))/(Sigma_alpha.det())
+ second_part = second_part.log()
+
+    if first_part.is_cuda:
+        # the CUDA path casts to half precision
+        gjsd = 0.5 * (first_part.half().squeeze(-1).squeeze(-1) +
+                      second_part.half())
+    else:
+        gjsd = 0.5 * (first_part.squeeze(-1).squeeze(-1) + second_part)
+
+ return gjsd
diff --git a/mmrotate/models/dense_heads/__init__.py b/mmrotate/models/dense_heads/__init__.py
index ad92a47..a0ff23e 100644
--- a/mmrotate/models/dense_heads/__init__.py
+++ b/mmrotate/models/dense_heads/__init__.py
@@ -24,6 +24,7 @@
from .psc_rotated_fcos_head import PSCRFCOSHead
from .kld_reppoints_head import KLDRepPointsHead
from .h2rbox_v2p_head import H2RBoxV2PHead
+from .dcfl_head import RDCFLHead
__all__ = [
'RotatedAnchorHead', 'RotatedRetinaHead', 'RotatedRPNHead',
@@ -33,5 +34,6 @@
'RotatedATSSHead', 'RotatedAnchorFreeHead', 'RotatedFCOSHead',
'CSLRFCOSHead', 'OrientedRepPointsHead', 'RotatedDETRHead',
'RotatedDeformableDETRHead', 'ARSDeformableDETRHead', 'DNARSDeformableDETRHead',
- 'H2RBoxHead', 'PSCRFCOSHead', 'KLDRepPointsHead', 'H2RBoxV2PHead'
+ 'H2RBoxHead', 'PSCRFCOSHead', 'KLDRepPointsHead', 'H2RBoxV2PHead',
+ 'RDCFLHead'
]
diff --git a/mmrotate/models/dense_heads/dcfl_head.py b/mmrotate/models/dense_heads/dcfl_head.py
new file mode 100644
index 0000000..f50a845
--- /dev/null
+++ b/mmrotate/models/dense_heads/dcfl_head.py
@@ -0,0 +1,895 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch
+from mmcv.ops import DeformConv2d, ModulatedDeformConv2d, modulated_deform_conv2d
+from mmcv.cnn import ConvModule
+from mmcv.runner import force_fp32
+from mmcv.utils import print_log
+from mmdet.core import images_to_levels, multi_apply, unmap
+from mmrotate.core import obb2hbb, rotated_anchor_inside_flags
+
+from ..builder import ROTATED_HEADS, build_loss
+from .rotated_anchor_head import RotatedAnchorHead
+
+
+class ModulatedDeformConvG(ModulatedDeformConv2d):
+ """A ModulatedDeformable Conv Encapsulation that acts as normal Conv
+ layers.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int): Same as nn.Conv2d, while tuple is not supported.
+ padding (int): Same as nn.Conv2d, while tuple is not supported.
+ dilation (int): Same as nn.Conv2d, while tuple is not supported.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(ModulatedDeformConvG, self).__init__(*args, **kwargs)
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation,
+ bias=True)
+ self.init_weights()
+
+ def init_weights(self):
+ super(ModulatedDeformConvG, self).init_weights()
+ if hasattr(self, 'conv_offset'):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ out = self.conv_offset(x)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+ out = modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
+ self.stride, self.padding,
+ self.dilation, self.groups,
+ self.deform_groups)
+ return out, offset, mask
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ version = local_metadata.get('version', None)
+
+ if version is None or version < 2:
+ # the key is different in early versions
+ # In version < 2, ModulatedDeformConvPack
+ # loads previous benchmark models.
+ if (prefix + 'conv_offset.weight' not in state_dict
+ and prefix[:-1] + '_offset.weight' in state_dict):
+ state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
+ prefix[:-1] + '_offset.weight')
+ if (prefix + 'conv_offset.bias' not in state_dict
+ and prefix[:-1] + '_offset.bias' in state_dict):
+ state_dict[prefix +
+ 'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
+ '_offset.bias')
+
+ if version is not None and version > 1:
+ print_log(
+ f'ModulatedDeformConvPack {prefix.rstrip(".")} is upgraded to '
+ 'version 2.',
+ logger='root')
+
+ super()._load_from_state_dict(state_dict, prefix, local_metadata,
+ strict, missing_keys, unexpected_keys,
+ error_msgs)
+
+
+@ROTATED_HEADS.register_module()
+class RDCFLHead(RotatedAnchorHead):
+ r"""An anchor-based head used in `RotatedRetinaNet
+ `_.
+
+ The head contains two subnetworks. The first classifies anchor boxes and
+ the second regresses deltas for the anchors.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ stacked_convs (int, optional): Number of stacked convolutions.
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: None.
+ anchor_generator (dict): Config dict for anchor generator
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ """ # noqa: W605
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ stacked_convs=4,
+ conv_cfg=None,
+ norm_cfg=None,
+                 dcn_assign=False,
+                 dilation_rate=2,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ octave_base_scale=4,
+ scales_per_octave=3,
+ ratios=[0.5, 1.0, 2.0],
+ strides=[8, 16, 32, 64, 128]),
+ init_cfg=dict(
+ type='Normal',
+ layer='Conv2d',
+ std=0.01,
+ override=dict(
+ type='Normal',
+ name='retina_cls',
+ std=0.01,
+ bias_prob=0.01)),
+ **kwargs):
+ self.stacked_convs = stacked_convs
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.dcn_assign = dcn_assign
+ self.dilation_rate = dilation_rate
+ super(RDCFLHead, self).__init__(
+ num_classes,
+ in_channels,
+ anchor_generator=anchor_generator,
+ init_cfg=init_cfg,
+ **kwargs)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU(inplace=True)
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+        if not self.dcn_assign:
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ else:
+ for i in range(self.stacked_convs-2):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ for i in range(1):
+ self.cls_convs.append(
+ DeformConv2d(
+ self.feat_channels,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ groups=1,
+ bias=False))
+ self.reg_convs.append(
+ DeformConv2d(
+ self.feat_channels,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ groups=1,
+ bias=False))
+ for i in range(1):
+ self.cls_convs.append(
+ ModulatedDeformConv2d(
+ self.feat_channels,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ groups=1,
+ bias=False))
+ self.reg_convs.append(
+ ModulatedDeformConvG(
+ self.feat_channels,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ groups=1,
+ bias=False))
+ self.retina_cls = nn.Conv2d(
+ self.feat_channels,
+ self.num_anchors * self.cls_out_channels,
+ 3,
+ padding=1)
+ self.retina_reg = nn.Conv2d(
+ self.feat_channels, self.num_anchors * 5, 3, padding=1)
+
+
+ def forward_single(self, x):
+ """Forward feature of a single scale level.
+
+ Args:
+ x (torch.Tensor): Features of a single scale level.
+
+ Returns:
+ tuple (torch.Tensor):
+
+ - cls_score (torch.Tensor): Cls scores for a single scale \
+ level the channels number is num_anchors * num_classes.
+ - bbox_pred (torch.Tensor): Box energies / deltas for a \
+ single scale level, the channels number is num_anchors * 4.
+ """
+ cls_feat = x
+ reg_feat = x
+        if not self.dcn_assign:
+ for cls_conv in self.cls_convs:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs:
+ reg_feat = reg_conv(reg_feat)
+ else:
+ for reg_conv in self.reg_convs[:-2]:
+ reg_feat = reg_conv(reg_feat)
+
+            # per-pixel residual offsets for a dilated 3x3 sampling grid,
+            # ordered as (y, x) pairs for the 9 kernel locations
+            size = (reg_feat.size(0), 1, reg_feat.size(-2), reg_feat.size(-1))
+            item = reg_feat.new_full(size, self.dilation_rate - 1)
+            zeros = reg_feat.new_zeros(size)
+            sampling_loc = torch.cat(
+                (-item, -item, -item, zeros, -item, item,
+                 zeros, -item, zeros, zeros, zeros, item,
+                 item, -item, item, zeros, item, item), dim=1)
+
+ reg_feat = self.reg_convs[self.stacked_convs - 2](reg_feat, sampling_loc)
+ reg_feat, offsets_reg, mask_reg = self.reg_convs[self.stacked_convs - 1](reg_feat)
+
+ for cls_conv in self.cls_convs[:-2]:
+ cls_feat = cls_conv(cls_feat)
+ cls_feat = self.cls_convs[self.stacked_convs - 2](cls_feat, sampling_loc)
+ cls_feat = self.cls_convs[self.stacked_convs - 1](cls_feat, offsets_reg, mask_reg)
+
+ # offset batch_size * 18 * feature_size * feature_size (128,64,32,16,8), offset [y0, x0, y1, x1, y2, x2, ..., y8, x8]
+ cls_score = self.retina_cls(cls_feat)
+ bbox_pred = self.retina_reg(reg_feat)
+ return cls_score, bbox_pred, offsets_reg
+
+
+    def loss_single(self, cls_score, bbox_pred, anchors, labels,
+                    label_weights, bbox_targets, bbox_weights,
+                    num_total_samples):
+ """Compute loss of a single scale level.
+
+ Args:
+ cls_score (torch.Tensor): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W).
+ bbox_pred (torch.Tensor): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 5, H, W).
+ anchors (torch.Tensor): Box reference for each scale level with
+ shape (N, num_total_anchors, 5).
+ labels (torch.Tensor): Labels of each anchors with shape
+ (N, num_total_anchors).
+ label_weights (torch.Tensor): Label weights of each anchor with
+ shape (N, num_total_anchors)
+            bbox_targets (torch.Tensor): BBox regression targets of each
+                anchor with shape (N, num_total_anchors, 5).
+ bbox_weights (torch.Tensor): BBox regression loss weights of each
+ anchor with shape (N, num_total_anchors, 5).
+ num_total_samples (int): If sampling, num total samples equal to
+ the number of total anchors; Otherwise, it is the number of
+ positive anchors.
+
+ Returns:
+ tuple (torch.Tensor):
+
+ - loss_cls (torch.Tensor): cls. loss for each scale level.
+ - loss_bbox (torch.Tensor): reg. loss for each scale level.
+ """
+ # classification loss
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+ cls_score = cls_score.permute(0, 2, 3,
+ 1).reshape(-1, self.cls_out_channels)
+
+ # regression loss
+ bbox_targets = bbox_targets.reshape(-1, 5)
+ bbox_weights = bbox_weights.reshape(-1, 5)
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 5)
+
+ if self.reg_decoded_bbox:
+ anchors = anchors.reshape(-1, 5)
+ bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
+
+ loss_cls = self.loss_cls(
+ cls_score, labels, label_weights, avg_factor=num_total_samples)
+
+ loss_bbox = self.loss_bbox(
+ bbox_pred,
+ bbox_targets,
+ bbox_weights,
+ avg_factor=num_total_samples)
+
+ return loss_cls, loss_bbox
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ offsets,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+            bbox_preds (list[Tensor]): Box energies / deltas for each scale
+                level with shape (N, num_anchors * 5, H, W)
+            offsets (list[Tensor]): DCN offsets for each scale level, used
+                to shift anchor centers during target assignment
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 5) in [cx, cy, w, h, a] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss. Default: None
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.anchor_generator.num_levels
+
+ device = cls_scores[0].device
+
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+ cls_reg_targets = self.get_targets(
+ cls_scores,
+ bbox_preds,
+ anchor_list,
+ valid_flag_list,
+ offsets,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg) = cls_reg_targets
+ num_total_samples = (
+ num_total_pos + num_total_neg if self.sampling else num_total_pos)
+
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ concat_anchor_list = []
+ for i, _ in enumerate(anchor_list):
+ concat_anchor_list.append(torch.cat(anchor_list[i]))
+ all_anchor_list = images_to_levels(concat_anchor_list,
+ num_level_anchors)
+ losses_cls, losses_bbox = multi_apply(
+ self.loss_single,
+ cls_scores,
+ bbox_preds,
+ all_anchor_list,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ num_total_samples=num_total_samples)
+ return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
+
+ def _get_targets_single(self,
+ flat_cls_scores,
+ flat_bbox_preds,
+ flat_anchors,
+ valid_flags,
+ offsets,
+ offsets_ori,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=1,
+ unmap_outputs=False):
+ """Compute regression and classification targets for anchors in a
+ single image.
+
+ Args:
+            flat_cls_scores (torch.Tensor): Classification scores of all
+                anchors of the image.
+            flat_bbox_preds (torch.Tensor): Box predictions of all anchors
+                of the image.
+            flat_anchors (torch.Tensor): Multi-level anchors of the image,
+                which are concatenated into a single tensor of shape
+                (num_anchors, 5)
+            valid_flags (torch.Tensor): Multi level valid flags of the image,
+                which are concatenated into a single tensor of
+                shape (num_anchors,).
+            offsets (torch.Tensor): DCN offsets used to shift anchor centers
+                before assignment.
+            offsets_ori (torch.Tensor): Original DCN offsets (unused in this
+                implementation).
+            gt_bboxes (torch.Tensor): Ground truth bboxes of the image,
+                shape (num_gts, 5).
+            gt_bboxes_ignore (torch.Tensor): Ground truth bboxes to be
+                ignored, shape (num_ignored_gts, 5).
+            gt_labels (torch.Tensor): Ground truth labels of each box,
+                shape (num_gts,).
+            img_meta (dict): Meta info of the image.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+ tuple (list[Tensor]):
+
+ - labels_list (list[Tensor]): Labels of each level
+ - label_weights_list (list[Tensor]): Label weights of each \
+ level
+ - bbox_targets_list (list[Tensor]): BBox targets of each level
+ - bbox_weights_list (list[Tensor]): BBox weights of each level
+ - num_total_pos (int): Number of positive samples in all images
+ - num_total_neg (int): Number of negative samples in all images
+ """
+ inside_flags = rotated_anchor_inside_flags(
+ flat_anchors, valid_flags, img_meta['img_shape'][:2],
+ self.train_cfg.allowed_border)
+ if not inside_flags.any():
+ return (None, ) * 7
+ anchors = flat_anchors
+        detached_offsets = offsets.detach()
+        # average the 9 (y, x) DCN offsets to get a single shift per anchor
+        dy = detached_offsets.new_zeros(1, detached_offsets.size(1))
+        dx = detached_offsets.new_zeros(1, detached_offsets.size(1))
+        for i in range(9):
+            dy += detached_offsets[2 * i] / 9
+            dx += detached_offsets[2 * i + 1] / 9
+
+        # translate anchor centers by the mean offset
+        flat_anchors[..., 0] = flat_anchors[..., 0] + dx
+        flat_anchors[..., 1] = flat_anchors[..., 1] + dy
+
+ deformable_anchors = flat_anchors
+
+
+ if self.assign_by_circumhbbox is not None:
+ gt_bboxes_assign = obb2hbb(gt_bboxes, self.assign_by_circumhbbox)
+ assign_result = self.assigner.assign(
+ flat_cls_scores, flat_bbox_preds, deformable_anchors, gt_bboxes_assign, gt_bboxes_ignore,
+ None if self.sampling else gt_labels)
+ else:
+ assign_result = self.assigner.assign(
+ flat_cls_scores, flat_bbox_preds, deformable_anchors, gt_bboxes, gt_bboxes_ignore,
+ None if self.sampling else gt_labels)
+
+ sampling_result = self.sampler.sample(assign_result, anchors,
+ gt_bboxes)
+
+ num_valid_anchors = anchors.shape[0]
+ bbox_targets = torch.zeros_like(anchors)
+ bbox_weights = torch.zeros_like(anchors)
+ labels = anchors.new_full((num_valid_anchors, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+
+ if len(pos_inds) > 0:
+ if not self.reg_decoded_bbox:
+ pos_bbox_targets = self.bbox_coder.encode(
+                    anchors[pos_inds, :], sampling_result.pos_gt_bboxes
+                )  # WARNING: `encode` may modify `anchors` in place
+ else:
+ pos_bbox_targets = sampling_result.pos_gt_bboxes
+ bbox_targets[pos_inds, :] = pos_bbox_targets
+ bbox_weights[pos_inds, :] = 1.0
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class since v2.5.0
+ labels[pos_inds] = 0
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if self.train_cfg.pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = self.train_cfg.pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # map up to original set of anchors
+ if unmap_outputs:
+ num_total_anchors = flat_anchors.size(0)
+ labels = unmap(
+ labels, num_total_anchors, inside_flags,
+ fill=self.num_classes) # fill bg label
+ label_weights = unmap(label_weights, num_total_anchors,
+ inside_flags)
+ bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
+ bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
+
+ return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
+ neg_inds, sampling_result)
+
+ def get_targets(self,
+ cls_scores_list,
+ bbox_pred_list,
+ anchor_list,
+ valid_flag_list,
+ offsets,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True,
+ return_sampling_results=False):
+ """Compute regression and classification targets for anchors in
+ multiple images.
+
+ Args:
+            cls_scores_list (list[Tensor]): Classification scores of each
+                level, each of shape (N, num_anchors * num_classes, H, W).
+            bbox_pred_list (list[Tensor]): Box predictions of each level,
+                each of shape (N, num_anchors * 5, H, W).
+            anchor_list (list[list[Tensor]]): Multi level anchors of each
+                image. The outer list indicates images, and the inner list
+                corresponds to feature levels of the image. Each element of
+                the inner list is a tensor of shape (num_anchors, 5).
+ valid_flag_list (list[list[Tensor]]): Multi level valid flags of
+ each image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_anchors, )
+ offsets (list[list[Tensor]]): Offsets of DCN.
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta info of each image.
+ gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
+ ignored.
+ gt_labels_list (list[Tensor]): Ground truth labels of each box.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+ tuple: Usually returns a tuple containing learning targets.
+
+ - labels_list (list[Tensor]): Labels of each level.
+ - label_weights_list (list[Tensor]): Label weights of each \
+ level.
+ - bbox_targets_list (list[Tensor]): BBox targets of each level.
+ - bbox_weights_list (list[Tensor]): BBox weights of each level.
+ - num_total_pos (int): Number of positive samples in all \
+ images.
+ - num_total_neg (int): Number of negative samples in all \
+ images.
+
+            additional_returns: This function enables user-defined returns
+                from `self._get_targets_single`. These returns are currently
+                refined to properties at each feature map (i.e. having HxW
+                dimension). The extra results are regrouped into per-level
+                lists via `images_to_levels` before being returned.
+        """
+ num_imgs = len(img_metas)
+ assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+        # number of anchors at each level, e.g. [128^2, 64^2, 32^2, 16^2, 8^2]
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+
+        # concat all level anchors of each image into a single tensor
+ concat_anchor_list = []
+ concat_valid_flag_list = []
+ for i in range(num_imgs):
+ assert len(anchor_list[i]) == len(valid_flag_list[i])
+ concat_anchor_list.append(torch.cat(anchor_list[i])) # a list whose len is batch size, each element is a tensor
+ concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))
+
+        # concat all level offsets to a single tensor
+        concat_offsets = []
+        concat_offsets_ori = []
+        lvl_offsets = []
+        lvl_offsets_ori = []
+        # per-level scale applied to the DCN offsets (doubled per level
+        # below), presumably mapping them from feature-map units towards
+        # the input-image scale
+        factor = img_metas[0]['img_shape'][0] / 256
+
+ # concat all level cls/reg results to a single tensor
+ lvl_scores = []
+ lvl_bboxes = []
+ concat_cls_scores_list = []
+ concat_bbox_pred_list = []
+
+        for i in range(len(cls_scores_list)):
+            reshaped_scores = cls_scores_list[i].detach().reshape(
+                num_imgs, self.num_classes, -1)
+            reshaped_bboxes = bbox_pred_list[i].detach().reshape(
+                num_imgs, 5, -1)
+            lvl_scores.append(reshaped_scores)
+            lvl_bboxes.append(reshaped_bboxes)
+        cat_lvl_scores = torch.cat(lvl_scores, dim=-1)
+        cat_lvl_bboxes = torch.cat(lvl_bboxes, dim=-1)
+
+        for j in range(num_imgs):
+            concat_cls_scores_list.append(cat_lvl_scores[j, ...])
+            concat_bbox_pred_list.append(cat_lvl_bboxes[j, ...])
+
+        # scale the offsets of each level to a common (image) scale;
+        # 18 channels = (dy, dx) for each of the 3x3 DCN sampling points
+        for k in range(len(offsets)):
+            reshaped_offsets_ori = offsets[k].reshape(num_imgs, 18, -1)
+            reshaped_offsets = reshaped_offsets_ori * factor
+            lvl_offsets_ori.append(reshaped_offsets_ori)
+            lvl_offsets.append(reshaped_offsets)
+            factor = factor * 2
+        cat_lvl_offsets = torch.cat(lvl_offsets, dim=2)
+
+        for j in range(num_imgs):
+            concat_offsets.append(cat_lvl_offsets[j, ...])
+
+        # concat the unscaled multi-level offsets as well
+        cat_lvl_offsets_ori = torch.cat(lvl_offsets_ori, dim=2)
+        for j in range(num_imgs):
+            concat_offsets_ori.append(cat_lvl_offsets_ori[j, ...])
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ if gt_labels_list is None:
+ gt_labels_list = [None for _ in range(num_imgs)]
+ results = multi_apply(
+ self._get_targets_single,
+ concat_cls_scores_list,
+ concat_bbox_pred_list,
+ concat_anchor_list,
+ concat_valid_flag_list,
+ concat_offsets,
+ concat_offsets_ori,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ label_channels=label_channels,
+ unmap_outputs=unmap_outputs)
+ (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
+ pos_inds_list, neg_inds_list, sampling_results_list) = results[:7]
+ rest_results = list(results[7:]) # user-added return values
+ # no valid anchors
+        if any(labels is None for labels in all_labels):
+ return None
+ # sampled anchors of all images
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ # split targets to a list w.r.t. multiple levels
+ labels_list = images_to_levels(all_labels, num_level_anchors)
+ label_weights_list = images_to_levels(all_label_weights,
+ num_level_anchors)
+ bbox_targets_list = images_to_levels(all_bbox_targets,
+ num_level_anchors)
+ bbox_weights_list = images_to_levels(all_bbox_weights,
+ num_level_anchors)
+ res = (labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg)
+ if return_sampling_results:
+ res = res + (sampling_results_list, )
+ for i, r in enumerate(rest_results): # user-added return values
+ rest_results[i] = images_to_levels(r, num_level_anchors)
+
+ return res + tuple(rest_results)
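+
+    # Flow of get_targets, as read from the code above: per-level
+    # predictions, anchors and DCN offsets are flattened per image,
+    # `_get_targets_single` computes per-image targets using the shifted
+    # anchors, and the per-image results are regrouped per FPN level for
+    # the level-wise loss.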
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def filter_bboxes(self, cls_scores, bbox_preds):
+        """Filter predicted bounding boxes at each position of the feature
+        maps. Only the bounding box with the highest score is kept at each
+        position. This filter is used in R3Det prior to the first feature
+        refinement stage.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 5, H, W)
+
+ Returns:
+ list[list[Tensor]]: best or refined rbboxes of each level \
+ of each image.
+ """
+ num_levels = len(cls_scores)
+ assert num_levels == len(bbox_preds)
+
+ num_imgs = cls_scores[0].size(0)
+
+ for i in range(num_levels):
+ assert num_imgs == cls_scores[i].size(0) == bbox_preds[i].size(0)
+
+ device = cls_scores[0].device
+ featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
+ mlvl_anchors = self.anchor_generator.grid_priors(
+ featmap_sizes, device=device)
+
+ bboxes_list = [[] for _ in range(num_imgs)]
+
+ for lvl in range(num_levels):
+ cls_score = cls_scores[lvl]
+ bbox_pred = bbox_preds[lvl]
+
+ anchors = mlvl_anchors[lvl]
+
+ cls_score = cls_score.permute(0, 2, 3, 1)
+ cls_score = cls_score.reshape(num_imgs, -1, self.num_anchors,
+ self.cls_out_channels)
+
+ cls_score, _ = cls_score.max(dim=-1, keepdim=True)
+ best_ind = cls_score.argmax(dim=-2, keepdim=True)
+ best_ind = best_ind.expand(-1, -1, -1, 5)
+
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1)
+ bbox_pred = bbox_pred.reshape(num_imgs, -1, self.num_anchors, 5)
+ best_pred = bbox_pred.gather(
+ dim=-2, index=best_ind).squeeze(dim=-2)
+
+ anchors = anchors.reshape(-1, self.num_anchors, 5)
+
+ for img_id in range(num_imgs):
+ best_ind_i = best_ind[img_id]
+ best_pred_i = best_pred[img_id]
+ best_anchor_i = anchors.gather(
+ dim=-2, index=best_ind_i).squeeze(dim=-2)
+ best_bbox_i = self.bbox_coder.decode(best_anchor_i,
+ best_pred_i)
+ bboxes_list[img_id].append(best_bbox_i.detach())
+
+ return bboxes_list
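+
+    # Illustrative shapes (sketch): with cls_scores[lvl] of shape
+    # (N, num_anchors * num_classes, H, W), the returned
+    # bboxes_list[img_id][lvl] has shape (H * W, 5), i.e. one decoded
+    # rbbox (cx, cy, w, h, a) per spatial position.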
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def refine_bboxes(self, cls_scores, bbox_preds):
+        """Decode box predictions into refined rbboxes at each position of
+        the feature maps. This function is used in S2ANet, where
+        ``num_anchors == 1``.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, 5, H, W)
+
+ Returns:
+ list[list[Tensor]]: refined rbboxes of each level of each image.
+ """
+ num_levels = len(cls_scores)
+ assert num_levels == len(bbox_preds)
+ num_imgs = cls_scores[0].size(0)
+ for i in range(num_levels):
+ assert num_imgs == cls_scores[i].size(0) == bbox_preds[i].size(0)
+
+ device = cls_scores[0].device
+ featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
+ mlvl_anchors = self.anchor_generator.grid_priors(
+ featmap_sizes, device=device)
+
+ bboxes_list = [[] for _ in range(num_imgs)]
+
+ for lvl in range(num_levels):
+ bbox_pred = bbox_preds[lvl]
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1)
+ bbox_pred = bbox_pred.reshape(num_imgs, -1, 5)
+ anchors = mlvl_anchors[lvl]
+
+ for img_id in range(num_imgs):
+ bbox_pred_i = bbox_pred[img_id]
+ decode_bbox_i = self.bbox_coder.decode(anchors, bbox_pred_i)
+ bboxes_list[img_id].append(decode_bbox_i.detach())
+
+ return bboxes_list
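+
+    # The decoded boxes are detached above, so downstream refinement
+    # stages treat them as fixed priors rather than back-propagating
+    # through the decoding step.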
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ offsets,
+ img_metas,
+ cfg=None,
+ rescale=False,
+ with_nms=True):
+ """Transform network output for a batch into bbox predictions.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 5, H, W)
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ cfg (mmcv.Config | None): Test / postprocessing configuration,
+ if None, test_cfg would be used
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+
+ Returns:
+ list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
+ The first item is an (n, 6) tensor, where the first 5 columns
+ are bounding box positions (cx, cy, w, h, a) and the
+ 6-th column is a score between 0 and 1. The second item is a
+ (n,) tensor where each item is the predicted class label of the
+ corresponding box.
+
+ Example:
+ >>> import mmcv
+ >>> self = AnchorHead(
+ >>> num_classes=9,
+ >>> in_channels=1,
+ >>> anchor_generator=dict(
+ >>> type='AnchorGenerator',
+ >>> scales=[8],
+ >>> ratios=[0.5, 1.0, 2.0],
+ >>> strides=[4,]))
+ >>> img_metas = [{'img_shape': (32, 32, 3), 'scale_factor': 1}]
+ >>> cfg = mmcv.Config(dict(
+ >>> score_thr=0.00,
+ >>> nms=dict(type='nms', iou_thr=1.0),
+ >>> max_per_img=10))
+ >>> feat = torch.rand(1, 1, 3, 3)
+ >>> cls_score, bbox_pred = self.forward_single(feat)
+            >>> # note the input lists are over different levels, not
+            >>> # images; the example exercises the base AnchorHead API,
+            >>> # while this head additionally takes per-level DCN offsets
+ >>> cls_scores, bbox_preds = [cls_score], [bbox_pred]
+ >>> result_list = self.get_bboxes(cls_scores, bbox_preds,
+ >>> img_metas, cfg)
+ >>> det_bboxes, det_labels = result_list[0]
+ >>> assert len(result_list) == 1
+ >>> assert det_bboxes.shape[1] == 5
+ >>> assert len(det_bboxes) == len(det_labels) == cfg.max_per_img
+ """
+ assert len(cls_scores) == len(bbox_preds)
+ num_levels = len(cls_scores)
+
+ device = cls_scores[0].device
+ featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
+ mlvl_anchors = self.anchor_generator.grid_priors(
+ featmap_sizes, device=device)
+
+ result_list = []
+ for img_id, _ in enumerate(img_metas):
+            # NOTE: offsets are sliced per image here but are not passed on
+            # to _get_bboxes_single below
+            offset_list = [
+                offsets[i][img_id].detach() for i in range(num_levels)
+            ]
+ cls_score_list = [
+ cls_scores[i][img_id].detach() for i in range(num_levels)
+ ]
+ bbox_pred_list = [
+ bbox_preds[i][img_id].detach() for i in range(num_levels)
+ ]
+ img_shape = img_metas[img_id]['img_shape']
+ scale_factor = img_metas[img_id]['scale_factor']
+ if with_nms:
+ # some heads don't support with_nms argument
+ proposals = self._get_bboxes_single(cls_score_list,
+ bbox_pred_list,
+ mlvl_anchors, img_shape,
+ scale_factor, cfg, rescale)
+ else:
+ proposals = self._get_bboxes_single(cls_score_list,
+ bbox_pred_list,
+ mlvl_anchors, img_shape,
+ scale_factor, cfg, rescale,
+ with_nms)
+ result_list.append(proposals)
+
+ return result_list
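+
+    # Minimal usage sketch (assumed calling context): given the per-level
+    # outputs cls_scores, bbox_preds and offsets of forward(), each entry
+    # of result_list is a tuple (det_bboxes of shape (n, 6), det_labels of
+    # shape (n,)) for one image, with boxes as (cx, cy, w, h, a, score).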
+