From 92ec70c28c5fc123bbbe2c1bc318e32c47f2c0ce Mon Sep 17 00:00:00 2001 From: zytx121 <592267829@qq.com> Date: Wed, 15 Feb 2023 10:46:50 +0800 Subject: [PATCH 1/3] init --- projects/rotated_mask_rcnn/README.md | 111 ++++++++ .../configs/_base_/dota_mask.py | 117 +++++++++ .../rotated-mask-orcnn_r50_fpn_1x_dota.py | 163 ++++++++++++ .../rotated-mask-rcnn_r50_fpn_1x_dota.py | 154 +++++++++++ projects/rotated_mask_rcnn/module/__init__.py | 4 + .../rotated_mask_rcnn/module/mask_target.py | 153 +++++++++++ .../module/rotated_fcn_mask_head.py | 246 ++++++++++++++++++ .../module/standard_roi_head.py | 72 +++++ 8 files changed, 1020 insertions(+) create mode 100644 projects/rotated_mask_rcnn/README.md create mode 100644 projects/rotated_mask_rcnn/configs/_base_/dota_mask.py create mode 100644 projects/rotated_mask_rcnn/configs/rotated-mask-orcnn_r50_fpn_1x_dota.py create mode 100644 projects/rotated_mask_rcnn/configs/rotated-mask-rcnn_r50_fpn_1x_dota.py create mode 100644 projects/rotated_mask_rcnn/module/__init__.py create mode 100644 projects/rotated_mask_rcnn/module/mask_target.py create mode 100644 projects/rotated_mask_rcnn/module/rotated_fcn_mask_head.py create mode 100644 projects/rotated_mask_rcnn/module/standard_roi_head.py diff --git a/projects/rotated_mask_rcnn/README.md b/projects/rotated_mask_rcnn/README.md new file mode 100644 index 000000000..af2e7f027 --- /dev/null +++ b/projects/rotated_mask_rcnn/README.md @@ -0,0 +1,111 @@ +# Rotated Mask RCNN + +## Description + + + +This project implements a Mask RCNN for rotated boxes. Benefiting from the BoxType design, we only need to modify the code slightly in mmrotate, and then we can support the instance segmentation task. + +## Usage + + + +### Training commands + +In MMRotate's root directory, run the following command to train the model: + +```bash +python tools/train.py projects/rotated_mask_rcnn/configs/rotated-mask-rcnn_r50_fpn_1x_dota.py +``` + +### Testing commands + +In MMRotate's root directory, run the following command to test the model: + +```bash +python tools/test.py projects/rotated_mask_rcnn/configs/rotated-mask-rcnn_r50_fpn_1x_dota.py ${CHECKPOINT_PATH} +``` + +## Results + + + +| Backbone | mAP | Angle | lr schd | Mem (GB) | Inf Time (fps) | Aug | Batch Size | Configs | Download | +| :----------------------: | :---: | :---: | :-----: | :------: | :------------: | :-: | :--------: | :--------------------------------------------------------------------------------: | :----------------------: | +| ResNet50 (1024,1024,200) | 72.71 | le90 | 1x | - | - | - | 2 | [rotated-mask-rcnn_r50_fpn_1x_dota](confsigs/rotated-mask-rcnn_r50_fpn_1x_dota.py) | [model](<>) \| [log](<>) | +| ResNet50 (1024,1024,200) | 70.74 | le90 | 1x | - | - | - | 2 | [rotated-mask-orcnn_r50_fpn_1x_dota](confsigs/rotated-mask-orcnn_r50_fpn_1x_dota.py) | [model](<>) \| [log](<>) | + +Although the rotated box indicator will drop slightly after adding mask head, it may help improve the instance segmentation task. We hope this project can inspire you and welcome you to explore more uses of mmrotate! + +## Citation + + + +```bibtex +@article{He_2017, + title={Mask R-CNN}, + journal={2017 IEEE International Conference on Computer Vision (ICCV)}, + publisher={IEEE}, + author={He, Kaiming and Gkioxari, Georgia and Dollar, Piotr and Girshick, Ross}, + year={2017}, + month={Oct} +} +``` + +## Checklist + + + +- [ ] Milestone 1: PR-ready, and acceptable to be one of the `projects/`. + + - [ ] Finish the code + + + + - [ ] Basic docstrings & proper citation + + + + - [ ] Test-time correctness + + + + - [ ] A full README + + + +- [ ] Milestone 2: Indicates a successful model implementation. + + - [ ] Training-time correctness + + + +- [ ] Milestone 3: Good to be a part of our core package! + + - [ ] Type hints and docstrings + + + + - [ ] Unit tests + + + + - [ ] Code polishing + + + + - [ ] Metafile.yml + + + +- [ ] Move your modules into the core package following the codebase's file hierarchy structure. + + + +- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure. diff --git a/projects/rotated_mask_rcnn/configs/_base_/dota_mask.py b/projects/rotated_mask_rcnn/configs/_base_/dota_mask.py new file mode 100644 index 000000000..4e40faad2 --- /dev/null +++ b/projects/rotated_mask_rcnn/configs/_base_/dota_mask.py @@ -0,0 +1,117 @@ +# dataset settings +dataset_type = 'mmdet.CocoDataset' +data_root = 'data/split_ms_dota/' +file_client_args = dict(backend='disk') + +train_pipeline = [ + dict(type='mmdet.LoadImageFromFile', file_client_args=file_client_args), + dict( + type='mmdet.LoadAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict(type='ConvertMask2BoxType', box_type='rbox', keep_mask=True), + dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), + dict( + type='mmdet.RandomFlip', + prob=0.75, + direction=['horizontal', 'vertical', 'diagonal']), + dict(type='mmdet.PackDetInputs') +] +val_pipeline = [ + dict(type='mmdet.LoadImageFromFile', file_client_args=file_client_args), + dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), + # avoid bboxes being resized + dict( + type='mmdet.LoadAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict(type='ConvertMask2BoxType', box_type='qbox', keep_mask=True), + dict( + type='mmdet.PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'instances')) +] +test_pipeline = [ + dict(type='mmdet.LoadImageFromFile', file_client_args=file_client_args), + dict(type='mmdet.Resize', scale=(1024, 1024), keep_ratio=True), + dict( + type='mmdet.PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] + +metainfo = dict( + classes=('plane', 'baseball-diamond', 'bridge', 'ground-track-field', + 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court', + 'basketball-court', 'storage-tank', 'soccer-ball-field', + 'roundabout', 'harbor', 'swimming-pool', 'helicopter')) + +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + batch_sampler=None, + dataset=dict( + type=dataset_type, + metainfo=metainfo, + data_root=data_root, + ann_file='train/train.json', + data_prefix=dict(img='train/images/'), + filter_cfg=dict(filter_empty_gt=True), + pipeline=train_pipeline)) +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + metainfo=metainfo, + data_root=data_root, + ann_file='val/val.json', + data_prefix=dict(img='val/images/'), + test_mode=True, + pipeline=val_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict(type='RotatedCocoMetric', metric=['bbox', 'segm'], classwise=True) +test_evaluator = val_evaluator + +# inference on test dataset and format the output results +# for submission. Note: the test set has no annotation. +# test_dataloader = dict( +# batch_size=1, +# num_workers=2, +# persistent_workers=True, +# drop_last=False, +# sampler=dict(type='DefaultSampler', shuffle=False), +# dataset=dict( +# type=dataset_type, +# ann_file='test/test.json', +# data_prefix=dict(img='test/images/'), +# test_mode=True, +# pipeline=test_pipeline)) +# test_evaluator = dict( +# type='DOTAMetric', +# format_only=True, +# merge_patches=True, +# outfile_prefix='./work_dirs/dota/Task1') + +# If you don't have test.json, please use this test_dataloader. +# test_dataloader = dict( +# batch_size=1, +# num_workers=2, +# persistent_workers=True, +# drop_last=False, +# sampler=dict(type='DefaultSampler', shuffle=False), +# dataset=dict( +# type='DOTADataset', +# data_root=data_root, +# data_prefix=dict(img_path='test/images/'), +# img_shape=(1024, 1024), +# test_mode=True, +# pipeline=test_pipeline)) \ No newline at end of file diff --git a/projects/rotated_mask_rcnn/configs/rotated-mask-orcnn_r50_fpn_1x_dota.py b/projects/rotated_mask_rcnn/configs/rotated-mask-orcnn_r50_fpn_1x_dota.py new file mode 100644 index 000000000..5bb623867 --- /dev/null +++ b/projects/rotated_mask_rcnn/configs/rotated-mask-orcnn_r50_fpn_1x_dota.py @@ -0,0 +1,163 @@ +_base_ = [ + './_base_/dota_mask.py', 'mmrotate::_base_/schedules/schedule_1x.py', + 'mmrotate::_base_/default_runtime.py' +] + +custom_imports = dict(imports=['projects.mask_rcnn_rbox.module']) + +angle_version = 'le90' +model = dict( + type='mmdet.MaskRCNN', + data_preprocessor=dict( + type='mmdet.DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_mask=True, + pad_size_divisor=32, + boxtype2tensor=False), + backbone=dict( + type='mmdet.ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='mmdet.FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='OrientedRPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='mmdet.AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64], + use_box_type=True), + bbox_coder=dict( + type='MidpointOffsetCoder', + angle_version=angle_version, + target_means=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + target_stds=[1.0, 1.0, 1.0, 1.0, 0.5, 0.5]), + loss_cls=dict( + type='mmdet.CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict( + type='mmdet.SmoothL1Loss', + beta=0.1111111111111111, + loss_weight=1.0)), + roi_head=dict( + type='mmdet.StandardRoIHead', + bbox_roi_extractor=dict( + type='RotatedSingleRoIExtractor', + roi_layer=dict( + type='RoIAlignRotated', + out_size=7, + sample_num=2, + clockwise=True), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='mmdet.Shared2FCBBoxHead', + predict_box_type='rbox', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=15, + reg_predictor_cfg=dict(type='mmdet.Linear'), + cls_predictor_cfg=dict(type='mmdet.Linear'), + bbox_coder=dict( + type='DeltaXYWHTRBBoxCoder', + angle_version=angle_version, + norm_factor=None, + edge_swap=True, + proj_xy=True, + target_means=(.0, .0, .0, .0, .0), + target_stds=(0.1, 0.1, 0.2, 0.2, 0.1)), + reg_class_agnostic=True, + loss_cls=dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict( + type='mmdet.SmoothL1Loss', beta=1.0, loss_weight=1.0)), + mask_roi_extractor=dict( + type='RotatedSingleRoIExtractor', + roi_layer=dict( + type='RoIAlignRotated', + out_size=14, + sample_num=2, + clockwise=True), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + mask_head=dict( + type='ORCNNFCNMaskHead', + num_convs=4, + in_channels=256, + conv_out_channels=256, + num_classes=15, + loss_mask=dict( + type='mmdet.CrossEntropyLoss', use_mask=True, loss_weight=1.0))), + train_cfg=dict( + rpn=dict( + assigner=dict( + type='mmdet.MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1, + iou_calculator=dict(type='RBbox2HBboxOverlaps2D')), + sampler=dict( + type='mmdet.RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=2000, + nms=dict(type='nms', iou_threshold=0.8), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='mmdet.MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=True, + ignore_iof_thr=-1, + iou_calculator=dict(type='RBboxOverlaps2D')), + sampler=dict( + type='mmdet.RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=2000, + max_per_img=2000, + nms=dict(type='nms', iou_threshold=0.8), + min_bbox_size=0), + rcnn=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms_rotated', iou_threshold=0.1), + max_per_img=2000, + mask_thr_binary=0.5))) + +optim_wrapper = dict(optimizer=dict(lr=0.005)) diff --git a/projects/rotated_mask_rcnn/configs/rotated-mask-rcnn_r50_fpn_1x_dota.py b/projects/rotated_mask_rcnn/configs/rotated-mask-rcnn_r50_fpn_1x_dota.py new file mode 100644 index 000000000..7f03b7f5d --- /dev/null +++ b/projects/rotated_mask_rcnn/configs/rotated-mask-rcnn_r50_fpn_1x_dota.py @@ -0,0 +1,154 @@ +_base_ = [ + './_base_/dota_mask.py', 'mmrotate::_base_/schedules/schedule_1x.py', + 'mmrotate::_base_/default_runtime.py' +] + +custom_imports = dict(imports=['projects.mask_rcnn_rbox.module']) + +angle_version = 'le90' +model = dict( + type='mmdet.MaskRCNN', + data_preprocessor=dict( + type='mmdet.DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_mask=True, + pad_size_divisor=32, + boxtype2tensor=False), + backbone=dict( + type='mmdet.ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='mmdet.FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + rpn_head=dict( + type='mmdet.RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='mmdet.AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64], + use_box_type=True), + bbox_coder=dict( + type='DeltaXYWHHBBoxCoder', + target_means=[0.0, 0.0, 0.0, 0.0], + target_stds=[1.0, 1.0, 1.0, 1.0], + use_box_type=True), + loss_cls=dict( + type='mmdet.CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict( + type='mmdet.SmoothL1Loss', + beta=0.1111111111111111, + loss_weight=1.0)), + roi_head=dict( + type='RStandardRoIHead', + bbox_roi_extractor=dict( + type='mmdet.SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='mmdet.Shared2FCBBoxHead', + predict_box_type='rbox', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=15, + reg_predictor_cfg=dict(type='mmdet.Linear'), + cls_predictor_cfg=dict(type='mmdet.Linear'), + bbox_coder=dict( + type='DeltaXYWHTHBBoxCoder', + angle_version=angle_version, + norm_factor=2, + edge_swap=True, + target_means=(.0, .0, .0, .0, .0), + target_stds=(0.1, 0.1, 0.2, 0.2, 0.1)), + reg_class_agnostic=True, + loss_cls=dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0), + loss_bbox=dict( + type='mmdet.SmoothL1Loss', beta=1.0, loss_weight=1.0)), + mask_roi_extractor=dict( + type='mmdet.SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + mask_head=dict( + type='RotatedFCNMaskHead', + num_convs=4, + in_channels=256, + conv_out_channels=256, + num_classes=15, + loss_mask=dict( + type='mmdet.CrossEntropyLoss', use_mask=True, loss_weight=1.0))), + train_cfg=dict( + rpn=dict( + assigner=dict( + type='mmdet.MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1, + iou_calculator=dict(type='RBbox2HBboxOverlaps2D')), + sampler=dict( + type='mmdet.RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=0, + pos_weight=-1, + debug=False), + rpn_proposal=dict( + nms_pre=2000, + max_per_img=2000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + assigner=dict( + type='mmdet.MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=True, + ignore_iof_thr=-1, + iou_calculator=dict(type='RBbox2HBboxOverlaps2D')), + sampler=dict( + type='mmdet.RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + mask_size=28, + pos_weight=-1, + debug=False)), + test_cfg=dict( + rpn=dict( + nms_pre=2000, + max_per_img=2000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms_rotated', iou_threshold=0.1), + max_per_img=2000, + mask_thr_binary=0.5))) + +optim_wrapper = dict(optimizer=dict(lr=0.005)) diff --git a/projects/rotated_mask_rcnn/module/__init__.py b/projects/rotated_mask_rcnn/module/__init__.py new file mode 100644 index 000000000..9e601aaab --- /dev/null +++ b/projects/rotated_mask_rcnn/module/__init__.py @@ -0,0 +1,4 @@ +from .rotated_fcn_mask_head import RotatedFCNMaskHead, ORCNNFCNMaskHead +from .standard_roi_head import RStandardRoIHead + +__all__ = ['RotatedFCNMaskHead', 'RStandardRoIHead', 'ORCNNFCNMaskHead'] diff --git a/projects/rotated_mask_rcnn/module/mask_target.py b/projects/rotated_mask_rcnn/module/mask_target.py new file mode 100644 index 000000000..97ec950cd --- /dev/null +++ b/projects/rotated_mask_rcnn/module/mask_target.py @@ -0,0 +1,153 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch +from torch.nn.modules.utils import _pair +from mmcv.ops.roi_align_rotated import roi_align_rotated +from mmrotate.structures import rbox2hbox + + +def mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list, + cfg): + """Compute mask target for positive proposals in multiple images. + Args: + pos_proposals_list (list[Tensor]): Positive proposals in multiple + images, each has shape (num_pos, 4). + pos_assigned_gt_inds_list (list[Tensor]): Assigned GT indices for each + positive proposals, each has shape (num_pos,). + gt_masks_list (list[:obj:`BaseInstanceMasks`]): Ground truth masks of + each image. + cfg (dict): Config dict that specifies the mask size. + Returns: + Tensor: Mask target of each image, has shape (num_pos, w, h). + Example: + >>> from mmengine.config import Config + >>> import mmdet + >>> from mmdet.data_elements.mask import BitmapMasks + >>> from mmdet.data_elements.mask.mask_target import * + >>> H, W = 17, 18 + >>> cfg = Config({'mask_size': (13, 14)}) + >>> rng = np.random.RandomState(0) + >>> # Positive proposals (tl_x, tl_y, br_x, br_y) for each image + >>> pos_proposals_list = [ + >>> torch.Tensor([ + >>> [ 7.2425, 5.5929, 13.9414, 14.9541], + >>> [ 7.3241, 3.6170, 16.3850, 15.3102], + >>> ]), + >>> torch.Tensor([ + >>> [ 4.8448, 6.4010, 7.0314, 9.7681], + >>> [ 5.9790, 2.6989, 7.4416, 4.8580], + >>> [ 0.0000, 0.0000, 0.1398, 9.8232], + >>> ]), + >>> ] + >>> # Corresponding class index for each proposal for each image + >>> pos_assigned_gt_inds_list = [ + >>> torch.LongTensor([7, 0]), + >>> torch.LongTensor([5, 4, 1]), + >>> ] + >>> # Ground truth mask for each true object for each image + >>> gt_masks_list = [ + >>> BitmapMasks(rng.rand(8, H, W), height=H, width=W), + >>> BitmapMasks(rng.rand(6, H, W), height=H, width=W), + >>> ] + >>> mask_targets = mask_target( + >>> pos_proposals_list, pos_assigned_gt_inds_list, + >>> gt_masks_list, cfg) + >>> assert mask_targets.shape == (5,) + cfg['mask_size'] + """ + cfg_list = [cfg for _ in range(len(pos_proposals_list))] + mask_targets = map(mask_target_single, pos_proposals_list, + pos_assigned_gt_inds_list, gt_masks_list, cfg_list) + mask_targets = list(mask_targets) + if len(mask_targets) > 0: + mask_targets = torch.cat(mask_targets) + return mask_targets + + +def mask_target_single(pos_proposals, pos_assigned_gt_inds, gt_masks, cfg): + """Compute mask target for each positive proposal in the image. + Args: + pos_proposals (Tensor): Positive proposals. + pos_assigned_gt_inds (Tensor): Assigned GT inds of positive proposals. + gt_masks (:obj:`BaseInstanceMasks`): GT masks in the format of Bitmap + or Polygon. + cfg (dict): Config dict that indicate the mask size. + Returns: + Tensor: Mask target of each positive proposals in the image. + Example: + >>> from mmengine.config import Config + >>> import mmdet + >>> from mmdet.data_elements.mask import BitmapMasks + >>> from mmdet.data_elements.mask.mask_target import * # NOQA + >>> H, W = 32, 32 + >>> cfg = Config({'mask_size': (7, 11)}) + >>> rng = np.random.RandomState(0) + >>> # Masks for each ground truth box (relative to the image) + >>> gt_masks_data = rng.rand(3, H, W) + >>> gt_masks = BitmapMasks(gt_masks_data, height=H, width=W) + >>> # Predicted positive boxes in one image + >>> pos_proposals = torch.FloatTensor([ + >>> [ 16.2, 5.5, 19.9, 20.9], + >>> [ 17.3, 13.6, 19.3, 19.3], + >>> [ 14.8, 16.4, 17.0, 23.7], + >>> [ 0.0, 0.0, 16.0, 16.0], + >>> [ 4.0, 0.0, 20.0, 16.0], + >>> ]) + >>> # For each predicted proposal, its assignment to a gt mask + >>> pos_assigned_gt_inds = torch.LongTensor([0, 1, 2, 1, 1]) + >>> mask_targets = mask_target_single( + >>> pos_proposals, pos_assigned_gt_inds, gt_masks, cfg) + >>> assert mask_targets.shape == (5,) + cfg['mask_size'] + """ + device = pos_proposals.device + mask_size = _pair(cfg.mask_size) + binarize = not cfg.get('soft_mask_target', False) + num_pos = pos_proposals.size(0) + if num_pos > 0: + pos_proposals = rbox2hbox(pos_proposals.tensor) # convert to hbox + proposals_np = pos_proposals.cpu().numpy() + maxh, maxw = gt_masks.height, gt_masks.width + proposals_np[:, [0, 2]] = np.clip(proposals_np[:, [0, 2]], 0, maxw) + proposals_np[:, [1, 3]] = np.clip(proposals_np[:, [1, 3]], 0, maxh) + # proposals_np[:, 2] = np.clip(proposals_np[:, 2], 0, maxw) + # proposals_np[:, 3] = np.clip(proposals_np[:, 3], 0, maxh) + pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy() + + mask_targets = gt_masks.crop_and_resize( + proposals_np, + mask_size, + device=device, + inds=pos_assigned_gt_inds, + binarize=binarize).to_ndarray() + + # out_h, out_w = out_shape + # resized_masks = [] + # for i in range(len(proposals_np)): + # mask = gt_masks[pos_assigned_gt_inds[i]] + # bbox = proposals_np[i, :] + # # x1, y1, x2, y2 = bbox + # x, y, w, h, t = bbox + # w = np.maximum(w, 1) + # h = np.maximum(h, 1) + # h_scale = out_h / max(h, 0.1) # avoid too large scale + # w_scale = out_w / max(w, 0.1) + + # resized_mask = [] + # for p in mask: + # p = p.copy() + # # crop + # # pycocotools will clip the boundary + # p[0::2] = p[0::2] - bbox[0] + # p[1::2] = p[1::2] - bbox[1] + + # # resize + # p[0::2] = p[0::2] * w_scale + # p[1::2] = p[1::2] * h_scale + # resized_mask.append(p) + # resized_masks.append(resized_mask) + # mask_targets = PlygonMasks(resized_masks, *out_shape) + + mask_targets = torch.from_numpy(mask_targets).float().to(device) + else: + mask_targets = pos_proposals.new_zeros((0, ) + mask_size) + + return mask_targets \ No newline at end of file diff --git a/projects/rotated_mask_rcnn/module/rotated_fcn_mask_head.py b/projects/rotated_mask_rcnn/module/rotated_fcn_mask_head.py new file mode 100644 index 000000000..bf9247197 --- /dev/null +++ b/projects/rotated_mask_rcnn/module/rotated_fcn_mask_head.py @@ -0,0 +1,246 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.config import ConfigDict +from torch import Tensor + +from mmdet.models.task_modules.samplers import SamplingResult +from mmdet.models.roi_heads.mask_heads import FCNMaskHead +from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig +from mmrotate.registry import MODELS +from .mask_target import mask_target +from mmdet.structures.bbox import scale_boxes +from mmrotate.structures.bbox import RotatedBoxes +from mmrotate.structures import rbox2hbox + +BYTES_PER_FLOAT = 4 +# TODO: This memory limit may be too much or too little. It would be better to +# determine it based on available resources. +GPU_MEM_LIMIT = 1024**3 # 1 GB memory limit + + +@MODELS.register_module() +class RotatedFCNMaskHead(FCNMaskHead): + + def _predict_by_feat_single(self, + mask_preds: Tensor, + bboxes: Tensor, + labels: Tensor, + img_meta: dict, + rcnn_test_cfg: ConfigDict, + rescale: bool = False, + activate_map: bool = False) -> Tensor: + """Get segmentation masks from mask_preds and bboxes. + Args: + mask_preds (Tensor): Predicted foreground masks, has shape + (n, num_classes, h, w). + bboxes (Tensor): Predicted bboxes, has shape (n, 4) + labels (Tensor): Labels of bboxes, has shape (n, ) + img_meta (dict): image information. + rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + activate_map (book): Whether get results with augmentations test. + If True, the `mask_preds` will not process with sigmoid. + Defaults to False. + Returns: + Tensor: Encoded masks, has shape (n, img_w, img_h) + Example: + >>> from mmengine.config import Config + >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import * # NOQA + >>> N = 7 # N = number of extracted ROIs + >>> C, H, W = 11, 32, 32 + >>> # Create example instance of FCN Mask Head. + >>> self = FCNMaskHead(num_classes=C, num_convs=0) + >>> inputs = torch.rand(N, self.in_channels, H, W) + >>> mask_preds = self.forward(inputs) + >>> # Each input is associated with some bounding box + >>> bboxes = torch.Tensor([[1, 1, 42, 42 ]] * N) + >>> labels = torch.randint(0, C, size=(N,)) + >>> rcnn_test_cfg = Config({'mask_thr_binary': 0, }) + >>> ori_shape = (H * 4, W * 4) + >>> scale_factor = (1, 1) + >>> rescale = False + >>> img_meta = {'scale_factor': scale_factor, + ... 'ori_shape': ori_shape} + >>> # Encoded masks are a list for each category. + >>> encoded_masks = self._get_seg_masks_single( + ... mask_preds, bboxes, labels, + ... img_meta, rcnn_test_cfg, rescale) + >>> assert encoded_masks.size()[0] == N + >>> assert encoded_masks.size()[1:] == ori_shape + """ + bboxes = RotatedBoxes(bboxes) + scale_factor = bboxes.new_tensor(img_meta['scale_factor']).repeat( + (1, 2)) + img_h, img_w = img_meta['ori_shape'][:2] + device = bboxes.device + + if not activate_map: + mask_preds = mask_preds.sigmoid() + else: + # In AugTest, has been activated before + mask_preds = bboxes.new_tensor(mask_preds) + + if rescale: # in-placed rescale the bboxes + scale_factor = [1 / s for s in img_meta['scale_factor']] + bboxes = scale_boxes(bboxes, scale_factor) + else: + w_scale, h_scale = scale_factor[0, 0], scale_factor[0, 1] + img_h = np.round(img_h * h_scale.item()).astype(np.int32) + img_w = np.round(img_w * w_scale.item()).astype(np.int32) + + N = len(mask_preds) + # The actual implementation split the input into chunks, + # and paste them chunk by chunk. + if device.type == 'cpu': + # CPU is most efficient when they are pasted one by one with + # skip_empty=True, so that it performs minimal number of + # operations. + num_chunks = N + else: + # GPU benefits from parallelism for larger chunks, + # but may have memory issue + # the types of img_w and img_h are np.int32, + # when the image resolution is large, + # the calculation of num_chunks will overflow. + # so we need to change the types of img_w and img_h to int. + # See https://github.com/open-mmlab/mmdetection/pull/5191 + num_chunks = int( + np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT / + GPU_MEM_LIMIT)) + assert (num_chunks <= + N), 'Default GPU_MEM_LIMIT is too small; try increasing it' + chunks = torch.chunk(torch.arange(N, device=device), num_chunks) + + threshold = rcnn_test_cfg.mask_thr_binary + im_mask = torch.zeros( + N, + img_h, + img_w, + device=device, + dtype=torch.bool if threshold >= 0 else torch.uint8) + + if not self.class_agnostic: + mask_preds = mask_preds[range(N), labels][:, None] + + for inds in chunks: + masks_chunk, spatial_inds = _do_paste_mask( + mask_preds[inds], + bboxes[inds], + img_h, + img_w, + skip_empty=device.type == 'cpu') + + if threshold >= 0: + masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool) + else: + # for visualization and debugging + masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8) + + im_mask[(inds, ) + spatial_inds] = masks_chunk + return im_mask + +def _do_paste_mask(masks: Tensor, + boxes: Tensor, + img_h: int, + img_w: int, + skip_empty: bool = True) -> tuple: + """Paste instance masks according to boxes. + This implementation is modified from + https://github.com/facebookresearch/detectron2/ + Args: + masks (Tensor): N, 1, H, W + boxes (Tensor): N, 4 + img_h (int): Height of the image to be pasted. + img_w (int): Width of the image to be pasted. + skip_empty (bool): Only paste masks within the region that + tightly bound all boxes, and returns the results this region only. + An important optimization for CPU. + Returns: + tuple: (Tensor, tuple). The first item is mask tensor, the second one + is the slice object. + If skip_empty == False, the whole image will be pasted. It will + return a mask of shape (N, img_h, img_w) and an empty tuple. + If skip_empty == True, only area around the mask will be pasted. + A mask of shape (N, h', w') and its start and end coordinates + in the original image will be returned. + """ + boxes = rbox2hbox(boxes.tensor) + # On GPU, paste all masks together (up to chunk size) + # by using the entire image to sample the masks + # Compared to pasting them one by one, + # this has more operations but is faster on COCO-scale dataset. + device = masks.device + if skip_empty: + x0_int, y0_int = torch.clamp( + boxes.min(dim=0).values.floor()[:2] - 1, + min=0).to(dtype=torch.int32) + x1_int = torch.clamp( + boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32) + y1_int = torch.clamp( + boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32) + else: + x0_int, y0_int = 0, 0 + x1_int, y1_int = img_w, img_h + x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1 + + N = masks.shape[0] + + img_y = torch.arange(y0_int, y1_int, device=device).to(torch.float32) + 0.5 + img_x = torch.arange(x0_int, x1_int, device=device).to(torch.float32) + 0.5 + img_y = (img_y - y0) / (y1 - y0) * 2 - 1 + img_x = (img_x - x0) / (x1 - x0) * 2 - 1 + # img_x, img_y have shapes (N, w), (N, h) + # IsInf op is not supported with ONNX<=1.7.0 + if not torch.onnx.is_in_onnx_export(): + if torch.isinf(img_x).any(): + inds = torch.where(torch.isinf(img_x)) + img_x[inds] = 0 + if torch.isinf(img_y).any(): + inds = torch.where(torch.isinf(img_y)) + img_y[inds] = 0 + + gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1)) + gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1)) + grid = torch.stack([gx, gy], dim=3) + + img_masks = F.grid_sample( + masks.to(dtype=torch.float32), grid, align_corners=False) + + if skip_empty: + return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int)) + else: + return img_masks[:, 0], () + +@MODELS.register_module() +class ORCNNFCNMaskHead(RotatedFCNMaskHead): + + def get_targets(self, sampling_results: List[SamplingResult], + batch_gt_instances: InstanceList, + rcnn_train_cfg: ConfigDict) -> Tensor: + """Calculate the ground truth for all samples in a batch according to + the sampling_results. + Args: + sampling_results (List[obj:SamplingResult]): Assign results of + all images in a batch after sampling. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes``, ``labels``, and + ``masks`` attributes. + rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN. + Returns: + Tensor: Mask target of each positive proposals in the image. + """ + pos_proposals = [res.pos_priors for res in sampling_results] + pos_assigned_gt_inds = [ + res.pos_assigned_gt_inds for res in sampling_results + ] + gt_masks = [res.masks for res in batch_gt_instances] + mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds, + gt_masks, rcnn_train_cfg) + return mask_targets diff --git a/projects/rotated_mask_rcnn/module/standard_roi_head.py b/projects/rotated_mask_rcnn/module/standard_roi_head.py new file mode 100644 index 000000000..74cc3f203 --- /dev/null +++ b/projects/rotated_mask_rcnn/module/standard_roi_head.py @@ -0,0 +1,72 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Optional, Tuple + +import torch +from torch import Tensor + +from mmrotate.registry import MODELS +from mmdet.structures.bbox import bbox2roi +from mmdet.utils import InstanceList +from mmdet.models.utils import empty_instances + +from mmdet.models.roi_heads import StandardRoIHead +from mmrotate.structures import rbox2hbox + +@MODELS.register_module() +class RStandardRoIHead(StandardRoIHead): + + def predict_mask(self, + x: Tuple[Tensor], + batch_img_metas: List[dict], + results_list: InstanceList, + rescale: bool = False) -> InstanceList: + """Perform forward propagation of the mask head and predict detection + results on the features of the upstream network. + + Args: + x (tuple[Tensor]): Feature maps of all scale level. + batch_img_metas (list[dict]): List of image information. + results_list (list[:obj:`InstanceData`]): Detection results of + each image. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + + Returns: + list[:obj:`InstanceData`]: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, H, W). + """ + # don't need to consider aug_test. + bboxes = [rbox2hbox(res.bboxes) for res in results_list] + mask_rois = bbox2roi(bboxes) + if mask_rois.shape[0] == 0: + results_list = empty_instances( + batch_img_metas, + mask_rois.device, + task_type='mask', + instance_results=results_list, + mask_thr_binary=self.test_cfg.mask_thr_binary) + return results_list + + mask_results = self._mask_forward(x, mask_rois) + mask_preds = mask_results['mask_preds'] + # split batch mask prediction back to each image + num_mask_rois_per_img = [len(res) for res in results_list] + mask_preds = mask_preds.split(num_mask_rois_per_img, 0) + + # TODO: Handle the case where rescale is false + results_list = self.mask_head.predict_by_feat( + mask_preds=mask_preds, + results_list=results_list, + batch_img_metas=batch_img_metas, + rcnn_test_cfg=self.test_cfg, + rescale=rescale) + return results_list From 920b732ed0da3586bdd1f8b6dcffb88dc6a0a375 Mon Sep 17 00:00:00 2001 From: zytx121 <592267829@qq.com> Date: Wed, 15 Feb 2023 11:01:53 +0800 Subject: [PATCH 2/3] upload img --- projects/rotated_mask_rcnn/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/projects/rotated_mask_rcnn/README.md b/projects/rotated_mask_rcnn/README.md index af2e7f027..9f8feae55 100644 --- a/projects/rotated_mask_rcnn/README.md +++ b/projects/rotated_mask_rcnn/README.md @@ -7,6 +7,10 @@ Author: @xxx. This is an implementation of \[XXX\]. --> This project implements a Mask RCNN for rotated boxes. Benefiting from the BoxType design, we only need to modify the code slightly in mmrotate, and then we can support the instance segmentation task. +