diff --git a/README.md b/README.md index cf99525..2d36a03 100644 --- a/README.md +++ b/README.md @@ -24,9 +24,12 @@ For instructions on installation, pretrained models, training and evaluation, pl ### Oriented Object Detection -| Detector | mAP | Configs | Download | None | +| Detector | mAP | Configs | Download | Note | | :--------: |:---:|:-------:|:--------:|:----:| +| Deformable DETR | 17.1 | [deformable_detr_r50_1x_rsg](configs/ars_detr/deformable_detr_r50_1x_rsg.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/deformable_detr_r50_1x_rsg.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/deformable_detr_r50_1x_rsg-fe862bb3.pth?download=true) | +| ARS-DETR | 28.1 | [dn_arw_arm_arcsl_rdetr_r50_1x_rsg](configs/ars_detr/dn_arw_arm_arcsl_rdetr_r50_1x_rsg.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/dn_arw_arm_arcsl_rdetr_r50_1x_rsg.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/dn_arw_arm_arcsl_rdetr_r50_1x_rsg-cbb34897.pth?download=true) | | RetinaNet | 21.8 | [rotated_retinanet_hbb_r50_fpn_1x_rsg_oc](configs/rotated_retinanet/rotated_retinanet_hbb_r50_fpn_1x_rsg_oc.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/rotated_retinanet_hbb_r50_fpn_1x_rsg_oc.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/rotated_retinanet_hbb_r50_fpn_1x_rsg_oc-3ec35d77.pth?download=true) | +| ATSS | 20.4 | [rotated_atss_hbb_r50_fpn_1x_rsg_oc](configs/rotated_atss/rotated_atss_hbb_r50_fpn_1x_rsg_oc.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/rotated_atss_hbb_r50_fpn_1x_rsg_oc.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/rotated_atss_hbb_r50_fpn_1x_rsg_oc-f65f07c2.pth?download=true) | | KLD | 25.0 | [rotated_retinanet_hbb_kld_r50_fpn_1x_rsg_oc](configs/kld/rotated_retinanet_hbb_kld_r50_fpn_1x_rsg_oc.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/rotated_retinanet_hbb_kld_r50_fpn_1x_rsg_oc.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/rotated_retinanet_hbb_kld_r50_fpn_1x_rsg_oc-343a0b83.pth?download=true) | | GWD | 25.3 | [rotated_retinanet_hbb_gwd_r50_fpn_1x_rsg_oc](configs/gwd/rotated_retinanet_hbb_gwd_r50_fpn_1x_rsg_oc.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/rotated_retinanet_hbb_gwd_r50_fpn_1x_rsg_oc.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/rotated_retinanet_hbb_gwd_r50_fpn_1x_rsg_oc-566d2398.pth?download=true) | | KFIoU | 25.5 | [rotated_retinanet_hbb_kfiou_r50_fpn_1x_rsg_oc](configs/kfiou/rotated_retinanet_hbb_kfiou_r50_fpn_1x_rsg_oc.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/rotated_retinanet_hbb_kfiou_r50_fpn_1x_rsg_oc.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/rotated_retinanet_hbb_kfiou_r50_fpn_1x_rsg_oc-198081a6.pth?download=true) | @@ -41,7 +44,8 @@ For instructions on installation, pretrained models, training and evaluation, pl | Gliding Vertex | 30.7 | [gliding_vertex_r50_fpn_1x_rsg_le90](configs/gliding_vertex/gliding_vertex_r50_fpn_1x_rsg_le90.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/gliding_vertex_r50_fpn_1x_rsg_le90.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/gliding_vertex_r50_fpn_1x_rsg_le90-5c0bc879.pth?download=true) | | Oriented RCNN | 33.2 | [oriented_rcnn_r50_fpn_1x_rsg_le90](configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_rsg_le90.py) | 
[log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/oriented_rcnn_r50_fpn_1x_rsg_le90.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/oriented_rcnn_r50_fpn_1x_rsg_le90-0b66f6a4.pth?download=true) | | RoI Transformer | 35.7 | [roi_trans_r50_fpn_1x_rsg_le90](configs/roi_trans/roi_trans_r50_fpn_1x_rsg_le90.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/roi_trans_r50_fpn_1x_rsg_le90.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/roi_trans_r50_fpn_1x_rsg_le90-e42f64d6.pth?download=true) | -| ReDet | 39.1 | [redet_re50_refpn_1x_rsg_le90](configs/redet/redet_re50_refpn_1x_rsg_le90.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/redet_re50_refpn_1x_rsg_le90.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/redet_re50_refpn_1x_rsg_le90-d163f450.pth?download=true) | +| ReDet | 39.1 | [redet_re50_refpn_1x_rsg_le90](configs/redet/redet_re50_refpn_1x_rsg_le90.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/redet_re50_refpn_1x_rsg_le90.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/redet_re50_refpn_1x_rsg_le90-d163f450.pth?download=true) | [ReResNet50](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/re_resnet50_c8_batch256-25b16846.pth?download=true) | +| Oriented RCNN | 40.7 | [oriented_rcnn_swin-l_fpn_1x_rsg_le90](configs/oriented_rcnn/oriented_rcnn_swin-l_fpn_1x_rsg_le90.py) | [log](https://huggingface.co/yangxue/RSG-MMRotate/raw/main/oriented_rcnn_swin-l_fpn_1x_rsg_le90.log) \| [ckpt](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/oriented_rcnn_swin-l_fpn_1x_rsg_le90-fe6f9e2d.pth?download=true) | [Swin-L](https://huggingface.co/yangxue/RSG-MMRotate/resolve/main/swin_large_patch4_window7_224_22k_20220412-aeecf2aa.pth?download=true) | ## 🖊️ Citation @@ -50,7 +54,7 @@ If you find this work helpful for your research, please consider giving this rep ```bibtex @article{li2024scene, title={Scene Graph Generation in Large-Size VHR Satellite Imagery: A Large-Scale Dataset and A Context-Aware Approach}, - author={L1, Yansheng and Wang, Linlin and Wang, Tingzhu and Wang, Qi and Sun, Xian and Yang, Xue and Wang, Wenbin and Luo, Junwei and Deng, Youming and Li, Haifeng and Dang, Bo and Zhang, Yongjun and Yan Junchi}, + author={Li, Yansheng and Wang, Linlin and Wang, Tingzhu and Yang, Xue and Wang, Qi and Sun, Xian and Wang, Wenbin and Luo, Junwei and Deng, Youming and Li, Haifeng and Dang, Bo and Zhang, Yongjun and Yan, Junchi}, journal={arXiv preprint arXiv:}, year={2024} } diff --git a/configs/h2rbox/h2rbox_r50_fpn_1x_rsg_le90_adamw5e-5.py b/configs/h2rbox/h2rbox_r50_fpn_1x_rsg_le90_adamw5e-5.py new file mode 100644 index 0000000..cc70a8e --- /dev/null +++ b/configs/h2rbox/h2rbox_r50_fpn_1x_rsg_le90_adamw5e-5.py @@ -0,0 +1,135 @@ +_base_ = [ + '../_base_/datasets/rsg.py', + '../_base_/schedules/schedule_1x.py', + '../_base_/default_runtime.py' +] +angle_version = 'le90' + +# model settings +model = dict( + type='H2RBox', + crop_size=(1024, 1024), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + zero_init_residual=False, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', # use P5 + num_outs=5, + relu_before_extra_convs=True), + bbox_head=dict(
+ type='H2RBoxHead', + num_classes=48, + in_channels=256, + stacked_convs=4, + feat_channels=256, + strides=[8, 16, 32, 64, 128], + center_sampling=True, + center_sample_radius=1.5, + norm_on_bbox=True, + centerness_on_reg=True, + separate_angle=False, + scale_angle=True, + reassigner='one2one', + rect_classes=[4, 44], + bbox_coder=dict( + type='DistanceAnglePointCoder', angle_version=angle_version), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='IoULoss', loss_weight=1.0), + loss_bbox_aug=dict( + type='H2RBoxLoss', + loss_weight=0.4, + center_loss_cfg=dict(type='L1Loss', loss_weight=0.0), + shape_loss_cfg=dict(type='IoULoss', loss_weight=1.0), + angle_loss_cfg=dict(type='L1Loss', loss_weight=1.0)), + loss_centerness=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), + # training and testing settings + train_cfg=None, + test_cfg=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(iou_thr=0.1), + max_per_img=2000)) + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='FilterNoCenterObject', img_scale=(1024, 1024), crop_size=(1024, 1024)), + dict(type='RResize', img_scale=(1024, 1024)), + dict( + type='RRandomFlip', + flip_ratio=[0.25, 0.25, 0.25], + direction=['horizontal', 'vertical', 'diagonal'], + version=angle_version), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=1), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + scale_factor=1.0, + flip=False, + transforms=[ + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=64), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img']) + ]) +] + +data_root = 'data/RSG/' +data = dict( + train=dict(type='RSGWSOODDataset', pipeline=train_pipeline, + ann_file=data_root + 'train/annfiles/', + img_prefix=data_root + 'train/images/', + version=angle_version), + val=dict(type='RSGWSOODDataset', pipeline=test_pipeline, + ann_file=data_root + 'test/annfiles/', + img_prefix=data_root + 'test/images/', + version=angle_version), + test=dict(type='RSGWSOODDataset', pipeline=test_pipeline, + ann_file=data_root + 'test/annfiles/', + img_prefix=data_root + 'test/images/', + version=angle_version)) + +log_config = dict(interval=50) + +optimizer = dict( + _delete_=True, + type='AdamW', + lr=0.00005, + betas=(0.9, 0.999), + weight_decay=0.05, + paramwise_cfg=dict( + custom_keys={ + 'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) 
+ })) + +checkpoint_config = dict(interval=1, max_keep_ckpts=1) +evaluation = dict(interval=6, metric='mAP') diff --git a/configs/h2rbox/h2rbox_r50_fpn_3x_rsg_le90.py b/configs/h2rbox/h2rbox_r50_fpn_3x_rsg_le90.py new file mode 100644 index 0000000..8962160 --- /dev/null +++ b/configs/h2rbox/h2rbox_r50_fpn_3x_rsg_le90.py @@ -0,0 +1,135 @@ +_base_ = [ + '../_base_/datasets/rsg.py', + '../_base_/schedules/schedule_3x.py', + '../_base_/default_runtime.py' +] +angle_version = 'le90' + +# model settings +model = dict( + type='H2RBox', + crop_size=(1024, 1024), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + zero_init_residual=False, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', # use P5 + num_outs=5, + relu_before_extra_convs=True), + bbox_head=dict( + type='H2RBoxHead', + num_classes=48, + in_channels=256, + stacked_convs=4, + feat_channels=256, + strides=[8, 16, 32, 64, 128], + center_sampling=True, + center_sample_radius=1.5, + norm_on_bbox=True, + centerness_on_reg=True, + separate_angle=False, + scale_angle=True, + reassigner='one2one', + rect_classes=[4, 44], + bbox_coder=dict( + type='DistanceAnglePointCoder', angle_version=angle_version), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='IoULoss', loss_weight=1.0), + loss_bbox_aug=dict( + type='H2RBoxLoss', + loss_weight=0.4, + center_loss_cfg=dict(type='L1Loss', loss_weight=0.0), + shape_loss_cfg=dict(type='IoULoss', loss_weight=1.0), + angle_loss_cfg=dict(type='L1Loss', loss_weight=1.0)), + loss_centerness=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), + # training and testing settings + train_cfg=None, + test_cfg=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(iou_thr=0.1), + max_per_img=2000)) + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='FilterNoCenterObject', img_scale=(1024, 1024), crop_size=(1024, 1024)), + dict(type='RResize', img_scale=(1024, 1024)), + dict( + type='RRandomFlip', + flip_ratio=[0.25, 0.25, 0.25], + direction=['horizontal', 'vertical', 'diagonal'], + version=angle_version), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=1), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + scale_factor=1.0, + flip=False, + transforms=[ + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=64), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img']) + ]) +] + +data_root = 'data/RSG/' +data = dict( + train=dict(type='RSGWSOODDataset', pipeline=train_pipeline, + ann_file=data_root + 'train/annfiles/', + img_prefix=data_root + 'train/images/', + version=angle_version), + val=dict(type='RSGWSOODDataset', pipeline=test_pipeline, + ann_file=data_root + 'test/annfiles/', + img_prefix=data_root + 'test/images/', + version=angle_version), + test=dict(type='RSGWSOODDataset', pipeline=test_pipeline, + ann_file=data_root + 'test/annfiles/', + 
img_prefix=data_root + 'test/images/', + version=angle_version)) + +log_config = dict(interval=50) + +optimizer = dict( + _delete_=True, + type='AdamW', + lr=0.0001, + betas=(0.9, 0.999), + weight_decay=0.05, + paramwise_cfg=dict( + custom_keys={ + 'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.) + })) + +checkpoint_config = dict(interval=1, max_keep_ckpts=1) +evaluation = dict(interval=6, metric='mAP') diff --git a/mmrotate/models/detectors/__init__.py b/mmrotate/models/detectors/__init__.py index 3ec4f8b..c6fd9e3 100644 --- a/mmrotate/models/detectors/__init__.py +++ b/mmrotate/models/detectors/__init__.py @@ -15,6 +15,7 @@ from .two_stage import RotatedTwoStageDetector from .two_stage_crop import RotatedTwoStageDetectorCrop from .rotated_detr import RotatedDETR +from .rotated_detr_crop import RotatedDETRCrop from .rotated_deformable_detr import RotatedDeformableDETR from .ars_detr import ARSDETR from .h2rbox import H2RBox @@ -26,5 +27,6 @@ 'GlidingVertex', 'ReDet', 'R3Det', 'S2ANet', 'RotatedRepPoints', 'RotatedBaseDetector', 'RotatedSingleStageDetectorCrop', 'RotatedTwoStageDetector', 'RotatedSingleStageDetector', 'RotatedFCOS', 'RotatedTwoStageDetectorCrop', - 'RotatedDETR', 'RotatedDeformableDETR', 'ARSDETR', 'H2RBox','R3DetCrop','S2ANetCrop' + 'RotatedDETR', 'RotatedDeformableDETR', 'ARSDETR', 'H2RBox','R3DetCrop','S2ANetCrop', + 'RotatedDETRCrop' ] diff --git a/mmrotate/models/detectors/ars_detr.py b/mmrotate/models/detectors/ars_detr.py index 8fbfb74..cfb2025 100644 --- a/mmrotate/models/detectors/ars_detr.py +++ b/mmrotate/models/detectors/ars_detr.py @@ -2,11 +2,11 @@ # from ..builder import DETECTORS # from mmdet.models.builder import DETECTORS from ..builder import ROTATED_DETECTORS -from .rotated_detr import RotatedDETR +from .rotated_detr_crop import RotatedDETRCrop @ROTATED_DETECTORS.register_module() -class ARSDETR(RotatedDETR): +class ARSDETR(RotatedDETRCrop): def __init__(self, *args, **kwargs): - super(RotatedDETR, self).__init__(*args, **kwargs) + super(RotatedDETRCrop, self).__init__(*args, **kwargs) diff --git a/mmrotate/models/detectors/r3det_crop.py b/mmrotate/models/detectors/r3det_crop.py index 63bbd23..0b78198 100644 --- a/mmrotate/models/detectors/r3det_crop.py +++ b/mmrotate/models/detectors/r3det_crop.py @@ -8,7 +8,8 @@ from .base import RotatedBaseDetector from .utils import FeatureRefineModule -from .single_stage_img_split_bridge_tools import * + +from mmrotate.models.detectors.single_stage_img_split_bridge_tools import * from mmdet.utils import get_device def resize_bboxes(bboxes,scale): @@ -44,8 +45,10 @@ def list2tensor(img_lists): ''' inputs = [] for img in img_lists: - inputs.append(img.cpu()) + inputs.append(img.cpu()) # 转移到cpu上,否则torch.stack内存不足 inputs = torch.stack(inputs, dim=0) + # inputs = torch.stack(inputs, dim=0).to(get_device()) + return inputs def FullImageCrop(self, imgs, bboxes, labels, patch_shape, @@ -62,17 +65,18 @@ def FullImageCrop(self, imgs, bboxes, labels, patch_shape, Returns: dict[str, Tensor]: A dictionary of loss components. 
""" - out_imgs=[] - out_bboxes=[] - out_labels=[] - out_metas=[] + out_imgs = [] + out_bboxes = [] + out_labels = [] + out_metas = [] device = get_device() img_rate_thr = 0.6 # 图片与wins窗口的交并比阈值 iof_thr = 0.1 # 裁剪后的标签占原标签的比值阈值 + padding_value = [0.0081917211329, -0.004901960784, 0.0055655449953] # 归一化后的padding值 if mode == 'train': # for i in range(imgs.shape[0]): - for img, bbox, label in zip(imgs, bboxes, labels): + for img, bbox, label in zip(imgs, [bboxes], [labels]): p_imgs = [] p_bboxes = [] p_labels = [] @@ -80,14 +84,14 @@ def FullImageCrop(self, imgs, bboxes, labels, patch_shape, img = img.cpu() # patch info = dict() - info['labels'] = np.array(torch.tensor(label, device='cpu',requires_grad=False)) + info['labels'] = np.array(torch.tensor(label, device='cpu', requires_grad=False)) info['ann'] = {'bboxes': {}} - info['width'] = img.shape[1] - info['height'] = img.shape[2] + info['width'] = img.shape[2] + info['height'] = img.shape[1] tmp_boxes = torch.tensor(bbox, device='cpu', requires_grad=False) - info['ann']['bboxes'] = np.array(obb2poly(tmp_boxes, self.version)) - + info['ann']['bboxes'] = np.array(obb2poly(tmp_boxes, self.version)) # 这里将OBB转换为8点表示形式 + bbbox = info['ann']['bboxes'] sizes = [patch_shape[0]] # gaps=[0] windows = get_sliding_window(info, sizes, gaps, img_rate_thr) @@ -95,20 +99,29 @@ def FullImageCrop(self, imgs, bboxes, labels, patch_shape, patchs, patch_infos = crop_and_save_img(info, windows, window_anns, img, no_padding=True, - padding_value=[104, 116, 124]) + # no_padding=False, + padding_value=padding_value) + # 对每张大图分解成的子图集合中的每张子图遍历 for i, patch_info in enumerate(patch_infos): if jump_empty_patch: - if patch_info['labels'] == [-1]: - continue + # 如果该patch中不含有效标签,将其跳过不输出,可在训练时使用 + if patch_info['labels'] == [-1]: + # print('Patch does not contain box.\n') + continue obj = patch_info['ann'] - tmp_boxes = poly2obb(torch.tensor(obj['bboxes']), self.version) + if min(obj['bboxes'].shape) == 0: # 张量为空 + tmp_boxes = poly2obb(torch.tensor(obj['bboxes']), 'oc') # oc转化可以处理空张量 + else: + tmp_boxes = poly2obb(torch.tensor(obj['bboxes']), self.version) # 转化回5参数 p_bboxes.append(tmp_boxes.to(device)) + # p_trunc.append(torch.tensor(obj['trunc'],device=device)) # 是否截断,box全部在win内部时为false + ## 若box超出win范围则trunc为true p_labels.append(torch.tensor(patch_info['labels'], device=device)) p_metas.append({'x_start': torch.tensor(patch_info['x_start'], device=device), 'y_start': torch.tensor(patch_info['y_start'], device=device), - 'shape': patch_shape, 'trunc': torch.tensor(obj['trunc'], device=device)}) + 'shape': patch_shape, 'trunc': torch.tensor(obj['trunc'], device=device),'img_shape': patch_shape, 'scale_factor': 1}) patch = patchs[i] p_imgs.append(patch.to(device)) @@ -118,7 +131,11 @@ def FullImageCrop(self, imgs, bboxes, labels, patch_shape, out_labels.append(p_labels) out_metas.append(p_metas) - elif mode =='test': + #### change for sgdet + # poly2obb(out_bboxes, self.version) + return out_imgs, out_bboxes, out_labels, out_metas + + elif mode == 'test': p_imgs = [] p_metas = [] img = imgs.cpu().squeeze(0) @@ -126,20 +143,21 @@ def FullImageCrop(self, imgs, bboxes, labels, patch_shape, info = dict() info['labels'] = np.array(torch.tensor([], device='cpu')) info['ann'] = {'bboxes': {}} - info['width'] = img.shape[1] - info['height'] = img.shape[2] + info['width'] = img.shape[2] + info['height'] = img.shape[1] sizes = [patch_shape[0]] # gaps=[0] windows = get_sliding_window(info, sizes, gaps, img_rate_thr) - patchs, patch_infos = crop_img_withoutann(info, windows,img, + patchs, 
patch_infos = crop_img_withoutann(info, windows, img, no_padding=False, - padding_value=[104, 116, 124]) + padding_value=padding_value) + # 对每张大图分解成的子图集合中的每张子图遍历 for i, patch_info in enumerate(patch_infos): p_metas.append({'x_start': torch.tensor(patch_info['x_start'], device=device), 'y_start': torch.tensor(patch_info['y_start'], device=device), - 'shape': patch_shape,'img_shape':patch_shape, 'scale_factor':1}) + 'shape': patch_shape, 'img_shape': patch_shape, 'scale_factor': 1}) patch = patchs[i] p_imgs.append(patch.to(device)) @@ -149,15 +167,7 @@ def FullImageCrop(self, imgs, bboxes, labels, patch_shape, return out_imgs, out_metas - - return out_imgs, out_bboxes,out_labels, out_metas - -def get_single_img(fea_g_necks, i): - fea_g_neck=[] - for idx in range(len(fea_g_necks)): - fea_g_neck.append(fea_g_necks[idx][i]) - - return tuple(fea_g_neck) + return out_imgs, out_bboxes, out_labels, out_metas def relocate(idx, local_bboxes, patch_meta): # put patches' local bboxes to full img via patch_meta @@ -174,8 +184,6 @@ def relocate(idx, local_bboxes, patch_meta): return - - @ROTATED_DETECTORS.register_module() class R3DetCrop(RotatedBaseDetector): """Rotated Refinement RetinaNet.""" @@ -309,7 +317,8 @@ def forward_train(self, # ] # return bbox_results - def simple_test(self, img, img_meta, rescale=False): + def simple_test(self, img, img_metas, rescale=False): + # 为了使用裁剪小图策略推理标准模型 """Test function without test time augmentation. Args: @@ -323,6 +332,7 @@ def simple_test(self, img, img_meta, rescale=False): The outer list corresponds to each image. The inner list \ corresponds to each class. """ + # print('single stage infetence!!!!!!') gaps = [200] patch_shape = (1024, 1024) p_bs = 4 # patch batchsize @@ -335,7 +345,9 @@ def simple_test(self, img, img_meta, rescale=False): local_bboxes_lists=[] for i in range(img.shape[0]): j = 0 - patches = list2tensor(p_imgs[i]) + # patches = list2tensor(p_imgs[i]) # list to tensor,此时放在cpu上 + p_imgs[i]=torch.stack(p_imgs[i], dim=0) + patches=p_imgs[i] patches_meta = p_metas[i] # patch batchsize while j < len(p_imgs[i]): @@ -344,36 +356,36 @@ def simple_test(self, img, img_meta, rescale=False): patch_meta = patches_meta[j:] else: patch = patches[j:j + p_bs] - patch_meta = patches_meta[j:j + p_bs] + patch_meta = patches_meta[j:j + p_bs] # x_start and y_start + with torch.no_grad(): - patch=patch.cuda() - x = self.extract_feat(patch) - outs = self.bbox_head(x) - rois = self.bbox_head.filter_bboxes(*outs) + fea_l_neck = self.extract_feat(patch) + outs_local = self.bbox_head(fea_l_neck) + rois_local = self.bbox_head.filter_bboxes(*outs_local) # rois: list(indexed by images) of list(indexed by levels) for i in range(self.num_refine_stages): - x_refine = self.feat_refine_module[i](x, rois) - outs = self.refine_head[i](x_refine) + x_refine = self.feat_refine_module[i](fea_l_neck, rois_local) + outs_local = self.refine_head[i](x_refine) if i + 1 in range(self.num_refine_stages): - rois = self.refine_head[i].refine_bboxes(*outs, rois=rois) - bbox_inputs = outs + (patch_meta, self.test_cfg, rescale) - local_bbox_list = self.refine_head[-1].get_bboxes(*bbox_inputs, rois=rois) - + rois_local = self.refine_head[i].refine_bboxes(*outs_local, rois=rois_local) + + bbox_inputs = outs_local + (patch_meta, self.test_cfg, True) + local_bbox_list = self.refine_head[-1].get_bboxes(*bbox_inputs, rois=rois_local) + for idx, res_list in enumerate(local_bbox_list): det_bboxes, det_labels = res_list relocate(idx, det_bboxes, patch_meta) local_bboxes_lists.append(local_bbox_list) j = 
j+p_bs - + bbox_list = [merge_results([local_bboxes_lists],iou_thr=0.4)] bbox_results = [ - rbbox2result(det_bboxes, det_labels, - self.refine_head[-1].num_classes) + rbbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) for det_bboxes, det_labels in bbox_list ] return bbox_results - + def aug_test(self, imgs, img_metas, **kwargs): """Test function with test time augmentation.""" pass diff --git a/mmrotate/models/detectors/r3det_crop_bak.py b/mmrotate/models/detectors/r3det_crop_bak.py new file mode 100644 index 0000000..51393f7 --- /dev/null +++ b/mmrotate/models/detectors/r3det_crop_bak.py @@ -0,0 +1,381 @@ +# Copyright (c) SJTU. All rights reserved. +import warnings + +from mmcv.runner import ModuleList + +from mmrotate.core import rbbox2result +from ..builder import ROTATED_DETECTORS, build_backbone, build_head, build_neck +from .base import RotatedBaseDetector +from .utils import FeatureRefineModule + +from .single_stage_img_split_bridge_tools import * +from mmdet.utils import get_device + +def resize_bboxes(bboxes,scale): + """Resize bounding boxes with scales.""" + + orig_shape = bboxes.shape + out_boxxes=bboxes.clone().reshape((-1, 5)) + # bboxes = bboxes.reshape((-1, 5)) + w_scale = scale + h_scale = scale + out_boxxes[:, 0] *= w_scale + out_boxxes[:, 1] *= h_scale + out_boxxes[:, 2:4] *= np.sqrt(w_scale * h_scale) + + return out_boxxes + +def resize(images, shape, label=False): + ''' + resize PIL images + shape: (w, h) + ''' + resized = list(images) + for i in range(len(images)): + if label: + resized[i] = images[i].resize(shape, Image.NEAREST) + else: + resized[i] = images[i].resize(shape, Image.BILINEAR) + return resized + +def list2tensor(img_lists): + ''' + images: list of list of tensor images + ''' + inputs = [] + for img in img_lists: + inputs.append(img.cpu()) + inputs = torch.stack(inputs, dim=0) + return inputs + +def FullImageCrop(self, imgs, bboxes, labels, patch_shape, + gaps, + jump_empty_patch=False, + mode='train'): + """ + Args: + imgs (Tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. + bboxes (list[Tensor]): Each item are the truth boxes for each + image in [tl_x, tl_y, br_x, br_y] format. + labels (list[Tensor]): Class indices corresponding to each box + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+ """ + out_imgs=[] + out_bboxes=[] + out_labels=[] + out_metas=[] + device = get_device() + img_rate_thr = 0.6 # 图片与wins窗口的交并比阈值 + iof_thr = 0.1 # 裁剪后的标签占原标签的比值阈值 + padding_value = [0.0081917211329, -0.004901960784, 0.0055655449953] # 归一化后的padding值 + + + if mode == 'train': + # for i in range(imgs.shape[0]): + for img, bbox, label in zip(imgs, bboxes, labels): + p_imgs = [] + p_bboxes = [] + p_labels = [] + p_metas = [] + img = img.cpu() + # patch + info = dict() + info['labels'] = np.array(torch.tensor(label, device='cpu',requires_grad=False)) + info['ann'] = {'bboxes': {}} + info['width'] = img.shape[2] + info['height'] = img.shape[1] + + tmp_boxes = torch.tensor(bbox, device='cpu', requires_grad=False) + info['ann']['bboxes'] = np.array(obb2poly(tmp_boxes, self.version)) + + sizes = [patch_shape[0]] + # gaps=[0] + windows = get_sliding_window(info, sizes, gaps, img_rate_thr) + window_anns = get_window_obj(info, windows, iof_thr) + patchs, patch_infos = crop_and_save_img(info, windows, window_anns, + img, + no_padding=True, + padding_value=padding_value) + + for i, patch_info in enumerate(patch_infos): + if jump_empty_patch: + if patch_info['labels'] == [-1]: + continue + + obj = patch_info['ann'] + tmp_boxes = poly2obb(torch.tensor(obj['bboxes']), self.version) + p_bboxes.append(tmp_boxes.to(device)) + p_labels.append(torch.tensor(patch_info['labels'], device=device)) + p_metas.append({'x_start': torch.tensor(patch_info['x_start'], device=device), + 'y_start': torch.tensor(patch_info['y_start'], device=device), + 'shape': patch_shape, 'trunc': torch.tensor(obj['trunc'], device=device)}) + + patch = patchs[i] + p_imgs.append(patch.to(device)) + + out_imgs.append(p_imgs) + out_bboxes.append(p_bboxes) + out_labels.append(p_labels) + out_metas.append(p_metas) + + elif mode =='test': + p_imgs = [] + p_metas = [] + img = imgs.cpu().squeeze(0) + # patch + info = dict() + info['labels'] = np.array(torch.tensor([], device='cpu')) + info['ann'] = {'bboxes': {}} + info['width'] = img.shape[2] + info['height'] = img.shape[1] + + sizes = [patch_shape[0]] + # gaps=[0] + windows = get_sliding_window(info, sizes, gaps, img_rate_thr) + patchs, patch_infos = crop_img_withoutann(info, windows,img, + no_padding=False, + padding_value=padding_value) + + for i, patch_info in enumerate(patch_infos): + p_metas.append({'x_start': torch.tensor(patch_info['x_start'], device=device), + 'y_start': torch.tensor(patch_info['y_start'], device=device), + 'shape': patch_shape,'img_shape':patch_shape, 'scale_factor':1}) + + patch = patchs[i] + p_imgs.append(patch.to(device)) + + out_imgs.append(p_imgs) + out_metas.append(p_metas) + + return out_imgs, out_metas + + + return out_imgs, out_bboxes,out_labels, out_metas + +def get_single_img(fea_g_necks, i): + fea_g_neck=[] + for idx in range(len(fea_g_necks)): + fea_g_neck.append(fea_g_necks[idx][i]) + + return tuple(fea_g_neck) + +def relocate(idx, local_bboxes, patch_meta): + # put patches' local bboxes to full img via patch_meta + meta=patch_meta[idx] + top = meta['y_start'] + left = meta['x_start'] + + for i in range(len(local_bboxes)): + bbox = local_bboxes[i] + if bbox.size()[0] == 0: + continue + bbox[0] += left + bbox[1] += top + + return + + + +@ROTATED_DETECTORS.register_module() +class R3DetCrop(RotatedBaseDetector): + """Rotated Refinement RetinaNet.""" + + def __init__(self, + num_refine_stages, + backbone, + neck=None, + bbox_head=None, + frm_cfgs=None, + refine_heads=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None): + 
super(R3DetCrop, self).__init__(init_cfg) + if pretrained: + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + backbone.pretrained = pretrained + self.backbone = build_backbone(backbone) + self.num_refine_stages = num_refine_stages + if neck is not None: + self.neck = build_neck(neck) + if train_cfg is not None: + bbox_head.update(train_cfg=train_cfg['s0']) + bbox_head.update(test_cfg=test_cfg) + self.bbox_head = build_head(bbox_head) + self.feat_refine_module = ModuleList() + self.refine_head = ModuleList() + for i, (frm_cfg, + refine_head) in enumerate(zip(frm_cfgs, refine_heads)): + self.feat_refine_module.append(FeatureRefineModule(**frm_cfg)) + if train_cfg is not None: + refine_head.update(train_cfg=train_cfg['sr'][i]) + refine_head.update(test_cfg=test_cfg) + self.refine_head.append(build_head(refine_head)) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + def extract_feat(self, img): + """Directly extract features from the backbone+neck.""" + x = self.backbone(img) + if self.with_neck: + x = self.neck(x) + return x + + def forward_dummy(self, img): + """Used for computing network flops. + + See `mmedetection/tools/get_flops.py` + """ + x = self.extract_feat(img) + outs = self.bbox_head(x) + rois = self.bbox_head.filter_bboxes(*outs) + # rois: list(indexed by images) of list(indexed by levels) + for i in range(self.num_refine_stages): + x_refine = self.feat_refine_module[i](x, rois) + outs = self.refine_head[i](x_refine) + if i + 1 in range(self.num_refine_stages): + rois = self.refine_head[i].refine_bboxes(*outs, rois=rois) + return outs + + def forward_train(self, + img, + img_metas, + gt_bboxes, + gt_labels, + gt_bboxes_ignore=None): + """Forward function.""" + losses = dict() + x = self.extract_feat(img) + + outs = self.bbox_head(x) + + loss_inputs = outs + (gt_bboxes, gt_labels, img_metas) + loss_base = self.bbox_head.loss( + *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) + for name, value in loss_base.items(): + losses[f's0.{name}'] = value + + rois = self.bbox_head.filter_bboxes(*outs) + # rois: list(indexed by images) of list(indexed by levels) + for i in range(self.num_refine_stages): + lw = self.train_cfg.stage_loss_weights[i] + + x_refine = self.feat_refine_module[i](x, rois) + outs = self.refine_head[i](x_refine) + loss_inputs = outs + (gt_bboxes, gt_labels, img_metas) + loss_refine = self.refine_head[i].loss( + *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore, rois=rois) + for name, value in loss_refine.items(): + losses[f'sr{i}.{name}'] = ([v * lw for v in value] + if 'loss' in name else value) + + if i + 1 in range(self.num_refine_stages): + rois = self.refine_head[i].refine_bboxes(*outs, rois=rois) + + return losses + + # def simple_test(self, img, img_meta, rescale=False): + # """Test function without test time augmentation. + + # Args: + # imgs (list[torch.Tensor]): List of multiple images + # img_metas (list[dict]): List of image information. + # rescale (bool, optional): Whether to rescale the results. + # Defaults to False. + + # Returns: + # list[list[np.ndarray]]: BBox results of each image and classes. \ + # The outer list corresponds to each image. The inner list \ + # corresponds to each class. 
+ # """ + # x = self.extract_feat(img) + # outs = self.bbox_head(x) + # rois = self.bbox_head.filter_bboxes(*outs) + # # rois: list(indexed by images) of list(indexed by levels) + # for i in range(self.num_refine_stages): + # x_refine = self.feat_refine_module[i](x, rois) + # outs = self.refine_head[i](x_refine) + # if i + 1 in range(self.num_refine_stages): + # rois = self.refine_head[i].refine_bboxes(*outs, rois=rois) + + # bbox_inputs = outs + (img_meta, self.test_cfg, rescale) + # bbox_list = self.refine_head[-1].get_bboxes(*bbox_inputs, rois=rois) + # bbox_results = [ + # rbbox2result(det_bboxes, det_labels, + # self.refine_head[-1].num_classes) + # for det_bboxes, det_labels in bbox_list + # ] + # return bbox_results + + def simple_test(self, img, img_meta, rescale=False): + """Test function without test time augmentation. + + Args: + imgs (list[torch.Tensor]): List of multiple images + img_metas (list[dict]): List of image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[list[np.ndarray]]: BBox results of each image and classes. \ + The outer list corresponds to each image. The inner list \ + corresponds to each class. + """ + gaps = [200] + patch_shape = (1024, 1024) + p_bs = 4 # patch batchsize + # Crop full img into patches + gt_bboxes=[] + gt_labels=[] + p_imgs, p_metas = FullImageCrop(self, img, gt_bboxes, gt_labels, + patch_shape=patch_shape, + gaps=gaps, mode='test') + local_bboxes_lists=[] + for i in range(img.shape[0]): + j = 0 + patches = list2tensor(p_imgs[i]) + patches_meta = p_metas[i] + # patch batchsize + while j < len(p_imgs[i]): + if (j+p_bs) >= len(p_imgs[i]): + patch = patches[j:] + patch_meta = patches_meta[j:] + else: + patch = patches[j:j + p_bs] + patch_meta = patches_meta[j:j + p_bs] + with torch.no_grad(): + patch=patch.cuda() + x = self.extract_feat(patch) + outs = self.bbox_head(x) + rois = self.bbox_head.filter_bboxes(*outs) + # rois: list(indexed by images) of list(indexed by levels) + for i in range(self.num_refine_stages): + x_refine = self.feat_refine_module[i](x, rois) + outs = self.refine_head[i](x_refine) + if i + 1 in range(self.num_refine_stages): + rois = self.refine_head[i].refine_bboxes(*outs, rois=rois) + bbox_inputs = outs + (patch_meta, self.test_cfg, rescale) + local_bbox_list = self.refine_head[-1].get_bboxes(*bbox_inputs, rois=rois) + + for idx, res_list in enumerate(local_bbox_list): + det_bboxes, det_labels = res_list + relocate(idx, det_bboxes, patch_meta) + local_bboxes_lists.append(local_bbox_list) + + j = j+p_bs + + bbox_list = [merge_results([local_bboxes_lists],iou_thr=0.4)] + bbox_results = [ + rbbox2result(det_bboxes, det_labels, + self.refine_head[-1].num_classes) + for det_bboxes, det_labels in bbox_list + ] + return bbox_results + + def aug_test(self, imgs, img_metas, **kwargs): + """Test function with test time augmentation.""" + pass diff --git a/mmrotate/models/detectors/rotated_deformable_detr.py b/mmrotate/models/detectors/rotated_deformable_detr.py index 63175f4..549917c 100644 --- a/mmrotate/models/detectors/rotated_deformable_detr.py +++ b/mmrotate/models/detectors/rotated_deformable_detr.py @@ -2,11 +2,11 @@ # from ..builder import DETECTORS # from mmdet.models.builder import DETECTORS from ..builder import ROTATED_DETECTORS -from .rotated_detr import RotatedDETR +from .rotated_detr_crop import RotatedDETRCrop @ROTATED_DETECTORS.register_module() -class RotatedDeformableDETR(RotatedDETR): +class RotatedDeformableDETR(RotatedDETRCrop): def 
__init__(self, *args, **kwargs): - super(RotatedDETR, self).__init__(*args, **kwargs) \ No newline at end of file + super(RotatedDETRCrop, self).__init__(*args, **kwargs) \ No newline at end of file diff --git a/mmrotate/models/detectors/rotated_detr_crop.py b/mmrotate/models/detectors/rotated_detr_crop.py new file mode 100644 index 0000000..7bde4b7 --- /dev/null +++ b/mmrotate/models/detectors/rotated_detr_crop.py @@ -0,0 +1,401 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings + +import torch +import numpy as np +import torch.nn as nn + +from ..builder import ROTATED_DETECTORS, build_head +from .single_stage_crop import RotatedSingleStageDetectorCrop +from .utils import FeatureRefineModule +#from mmdet.core import choose_best_match_batch, gt_mask_bp_obbs_list, choose_best_Rroi_batch + + +from mmrotate.models.detectors.single_stage_img_split_bridge_tools import * +from mmdet.utils import get_device + +def resize_bboxes(bboxes,scale): + """Resize bounding boxes with scales.""" + + orig_shape = bboxes.shape + out_boxxes=bboxes.clone().reshape((-1, 5)) + # bboxes = bboxes.reshape((-1, 5)) + w_scale = scale + h_scale = scale + out_boxxes[:, 0] *= w_scale + out_boxxes[:, 1] *= h_scale + out_boxxes[:, 2:4] *= np.sqrt(w_scale * h_scale) + + return out_boxxes + +def resize(images, shape, label=False): + ''' + resize PIL images + shape: (w, h) + ''' + resized = list(images) + for i in range(len(images)): + if label: + resized[i] = images[i].resize(shape, Image.NEAREST) + else: + resized[i] = images[i].resize(shape, Image.BILINEAR) + return resized + +def list2tensor(img_lists): + ''' + images: list of list of tensor images + ''' + inputs = [] + for img in img_lists: + inputs.append(img.cpu()) # 转移到cpu上,否则torch.stack内存不足 + inputs = torch.stack(inputs, dim=0) + # inputs = torch.stack(inputs, dim=0).to(get_device()) + + return inputs + +def FullImageCrop(self, imgs, bboxes, labels, patch_shape, + gaps, + jump_empty_patch=False, + mode='train'): + """ + Args: + imgs (Tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. + bboxes (list[Tensor]): Each item are the truth boxes for each + image in [tl_x, tl_y, br_x, br_y] format. + labels (list[Tensor]): Class indices corresponding to each box + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+ """ + out_imgs = [] + out_bboxes = [] + out_labels = [] + out_metas = [] + device = get_device() + img_rate_thr = 0.6 # IoU threshold between each sliding window and the image + iof_thr = 0.1 # IoF threshold: minimum fraction of a GT box that must fall inside a patch + padding_value = [0.0081917211329, -0.004901960784, 0.0055655449953] # padding values after normalization + + if mode == 'train': + # for i in range(imgs.shape[0]): + for img, bbox, label in zip(imgs, [bboxes], [labels]): + p_imgs = [] + p_bboxes = [] + p_labels = [] + p_metas = [] + img = img.cpu() + # patch + info = dict() + info['labels'] = np.array(torch.tensor(label, device='cpu', requires_grad=False)) + info['ann'] = {'bboxes': {}} + info['width'] = img.shape[2] + info['height'] = img.shape[1] + + tmp_boxes = torch.tensor(bbox, device='cpu', requires_grad=False) + info['ann']['bboxes'] = np.array(obb2poly(tmp_boxes, self.version)) # convert the OBBs to the 8-point (polygon) representation here + bbbox = info['ann']['bboxes'] + sizes = [patch_shape[0]] + # gaps=[0] + windows = get_sliding_window(info, sizes, gaps, img_rate_thr) + window_anns = get_window_obj(info, windows, iof_thr) + patchs, patch_infos = crop_and_save_img(info, windows, window_anns, + img, + no_padding=True, + # no_padding=False, + padding_value=padding_value) + + # iterate over every patch cropped from the full image + for i, patch_info in enumerate(patch_infos): + if jump_empty_patch: + # if the patch contains no valid labels, skip it (can be used during training) + + if patch_info['labels'] == [-1]: + # print('Patch does not contain box.\n') + continue + obj = patch_info['ann'] + if min(obj['bboxes'].shape) == 0: # empty tensor + tmp_boxes = poly2obb(torch.tensor(obj['bboxes']), 'oc') # the 'oc' conversion can handle empty tensors + else: + tmp_boxes = poly2obb(torch.tensor(obj['bboxes']), self.version) # convert back to the 5-parameter representation + p_bboxes.append(tmp_boxes.to(device)) + # p_trunc.append(torch.tensor(obj['trunc'],device=device)) # truncation flag: False when the box lies entirely inside the window + ## True when the box extends beyond the window + p_labels.append(torch.tensor(patch_info['labels'], device=device)) + p_metas.append({'x_start': torch.tensor(patch_info['x_start'], device=device), + 'y_start': torch.tensor(patch_info['y_start'], device=device), + 'shape': patch_shape, 'trunc': torch.tensor(obj['trunc'], device=device), + 'img_shape': patch_shape, 'scale_factor': 1}) + + patch = patchs[i] + p_imgs.append(patch.to(device)) + + out_imgs.append(p_imgs) + out_bboxes.append(p_bboxes) + out_labels.append(p_labels) + out_metas.append(p_metas) + + #### change for sgdet + # poly2obb(out_bboxes, self.version) + return out_imgs, out_bboxes, out_labels, out_metas + + elif mode == 'test': + p_imgs = [] + p_metas = [] + img = imgs.cpu().squeeze(0) + # patch + info = dict() + info['labels'] = np.array(torch.tensor([], device='cpu')) + info['ann'] = {'bboxes': {}} + info['width'] = img.shape[2] + info['height'] = img.shape[1] + + sizes = [patch_shape[0]] + # gaps=[0] + windows = get_sliding_window(info, sizes, gaps, img_rate_thr) + patchs, patch_infos = crop_img_withoutann(info, windows, img, + no_padding=False, + padding_value=padding_value) + + # iterate over every patch cropped from the full image + for i, patch_info in enumerate(patch_infos): + p_metas.append({'x_start': torch.tensor(patch_info['x_start'], device=device), + 'y_start': torch.tensor(patch_info['y_start'], device=device), + 'shape': patch_shape, 'img_shape': patch_shape + (3,), + 'batch_input_shape': patch_shape, 'scale_factor': 1}) + + patch = patchs[i] + p_imgs.append(patch.to(device)) + + out_imgs.append(p_imgs) + out_metas.append(p_metas) + + return out_imgs, out_metas + + return out_imgs, out_bboxes, out_labels, out_metas + + + +def relocate(idx, local_bboxes, patch_meta): + # put patches' local bboxes to full img via patch_meta + meta=patch_meta[idx] + top =
meta['y_start'] + left = meta['x_start'] + + for i in range(len(local_bboxes)): + bbox = local_bboxes[i] + if bbox.size()[0] == 0: + continue + bbox[0] += left + bbox[1] += top + + return + +@ROTATED_DETECTORS.register_module() +class RotatedDETRCrop(RotatedSingleStageDetectorCrop): + r"""Implementation of `DETR: End-to-End Object Detection with + Transformers `_""" + + def __init__(self, + backbone, + bbox_head, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None): + super(RotatedDETRCrop, self).__init__(backbone, None, bbox_head, train_cfg, + test_cfg, pretrained, init_cfg) + + # over-write `forward_dummy` because: + # the forward of bbox_head requires img_metas + def rbox2result(self, bboxes, labels, num_classes): + """Convert detection results to a list of numpy arrays. + + Args: + bboxes (torch.Tensor | np.ndarray): shape (n, 5) + labels (torch.Tensor | np.ndarray): shape (n, ) + num_classes (int): class number, including background class + + Returns: + list(ndarray): bbox results of each class + """ + if bboxes.shape[0] == 0: + return [np.zeros((0, 9), dtype=np.float32) for i in range(num_classes)] # TODOinsert + else: + if isinstance(bboxes, torch.Tensor): + bboxes = bboxes.detach().cpu().numpy() # dets,rboxes[keep],scores_k[keep] + labels = labels.detach().cpu().numpy() + + return [bboxes[labels == i, :] for i in range(num_classes)] + + def forward_dummy(self, img): + """Used for computing network flops. + + See `mmdetection/tools/analysis_tools/get_flops.py` + """ + warnings.warn('Warning! MultiheadAttention in DETR does not ' + 'support flops computation! Do not use the ' + 'results in your papers!') + + batch_size, _, height, width = img.shape + dummy_img_metas = [ + dict( + batch_input_shape=(height, width), + img_shape=(height, width, 3)) for _ in range(batch_size) + ] + x = self.extract_feat(img) + outs = self.bbox_head(x, dummy_img_metas) + return outs + + # over-write `onnx_export` because: + # (1) the forward of bbox_head requires img_metas + # (2) the different behavior (e.g. construction of `masks`) between + # torch and ONNX model, during the forward of bbox_head + def onnx_export(self, img, img_metas): + """Test function for exporting to ONNX, without test time augmentation. + + Args: + img (torch.Tensor): input images. + img_metas (list[dict]): List of image information. + + Returns: + tuple[Tensor, Tensor]: dets of shape [N, num_det, 5] + and class labels of shape [N, num_det]. + """ + x = self.extract_feat(img) + # forward of this head requires img_metas + outs = self.bbox_head.forward_onnx(x, img_metas) + # get shape as tensor + img_shape = torch._shape_as_tensor(img)[2:] + img_metas[0]['img_shape_for_onnx'] = img_shape + + det_bboxes, det_labels = self.bbox_head.onnx_export(*outs, img_metas) + + return det_bboxes, det_labels + + def forward_train(self, + img, + img_metas, + gt_bboxes, + gt_labels, + gt_bboxes_ignore=None): + """ + Args: + img (Tensor): Input images of shape (N, C, H, W). + Typically these should be mean centered and std scaled. + img_metas (list[dict]): A List of image info dict where each dict + has: 'img_shape', 'scale_factor', 'flip', and may also contain + 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. + For details on the values of these keys see + :class:`mmdet.datasets.pipelines.Collect`. + gt_bboxes (list[Tensor]): Each item are the truth boxes for each + image in [tl_x, tl_y, br_x, br_y] format. 
+ gt_labels (list[Tensor]): Class indices corresponding to each box + gt_bboxes_ignore (None | list[Tensor]): Specify which bounding + boxes can be ignored when computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + super(RotatedSingleStageDetectorCrop, self).forward_train(img, img_metas) + x = self.extract_feat(img) + cost_matrix = np.asarray(x[0].cpu().detach()) + contain_nan = (True in np.isnan(cost_matrix)) + if contain_nan: + a = 1 + print('Find!!!') + for i in range(len(img_metas)): + print('The image is', img_metas[i]['file_name']) + losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes, + gt_labels, gt_bboxes_ignore) + return losses + + def imshow_gpu_tensor(self, tensor): + from PIL import Image + from torchvision import transforms + device = tensor[0].device + mean = torch.tensor([123.675, 116.28, 103.53]) + std = torch.tensor([58.395, 57.12, 57.375]) + mean = mean.to(device) + std = std.to(device) + tensor = (tensor[0].squeeze() * std[:, None, None]) + mean[:, None, None] + tensor = tensor[0:1] + if len(tensor.shape) == 4: + image = tensor.permute(0, 2, 3, 1).cpu().clone().numpy() + else: + image = tensor.permute(1, 2, 0).cpu().clone().numpy() + image = image.astype(np.uint8).squeeze() + image = transforms.ToPILImage()(image) + image = image.resize((256, 256), Image.ANTIALIAS) + image.show(image) + + # def simple_test(self, img, img_metas, rescale=False): + # cfg = self.test_cfg + + # feat = self.extract_feat(img) + # results_list = self.bbox_head.simple_test_bboxes(feat, img_metas, rescale=rescale) + # # bbox_list = self.bbox_head.get_bboxes( + # # *outs, img_metas, rescale=rescale) + # # skip post-processing when exporting to ONNX + # bbox_results = [ + # self.rbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) + # for det_bboxes, det_labels in results_list] + # return bbox_results + + + def simple_test(self, img, img_metas, rescale=False): + # 为了使用裁剪小图策略推理标准模型 + """Test function without test time augmentation. + Args: + imgs (list[torch.Tensor]): List of multiple images + img_metas (list[dict]): List of image information. + rescale (bool, optional): Whether to rescale the results. + Defaults to False. + + Returns: + list[list[np.ndarray]]: BBox results of each image and classes. \ + The outer list corresponds to each image. The inner list \ + corresponds to each class. 
+ """ + gaps = [200] + patch_shape = (1024, 1024) + p_bs = 4 # patch batchsize + gt_bboxes=[] + gt_labels=[] + p_imgs, p_metas = FullImageCrop(self, img, gt_bboxes, gt_labels, + patch_shape=patch_shape, + gaps=gaps, mode='test') + + local_bboxes_lists=[] + for i in range(img.shape[0]): + j = 0 + p_imgs[i]=torch.stack(p_imgs[i], dim=0) + patches=p_imgs[i] + patches_meta = p_metas[i] + + while j < len(p_imgs[i]): + if (j+p_bs) >= len(p_imgs[i]): + patch = patches[j:] + patch_meta = patches_meta[j:] + else: + patch = patches[j:j + p_bs] + patch_meta = patches_meta[j:j + p_bs] # x_start and y_start + + with torch.no_grad(): + fea_l_neck = self.extract_feat(patch) + local_results_list = self.bbox_head.simple_test_bboxes(fea_l_neck, patch_meta, rescale=False) + + for idx, res_list in enumerate(local_results_list): + det_bboxes, det_labels = res_list + relocate(idx, det_bboxes, patch_meta) + local_bboxes_lists.append(local_results_list) + j = j+p_bs + + bbox_list = [merge_results([local_bboxes_lists],iou_thr=0.4)] + + bbox_results = [ + rbbox2result(det_bboxes, det_labels, self.bbox_head.num_classes) + for det_bboxes, det_labels in bbox_list + ] + + return bbox_results diff --git a/mmrotate/models/detectors/s2anet_crop.py b/mmrotate/models/detectors/s2anet_crop.py index f9dae48..5da6bd8 100644 --- a/mmrotate/models/detectors/s2anet_crop.py +++ b/mmrotate/models/detectors/s2anet_crop.py @@ -343,7 +343,6 @@ def simple_test(self, img, img_metas, rescale=False): # patch batchsize while j < len(p_imgs[i]): if (j+p_bs) >= len(p_imgs[i]): - patch = patches[j:] patch = patches[j:] patch_meta = patches_meta[j:] else:
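The *Crop detectors touched by this patch (R3DetCrop, RotatedDETRCrop, S2ANetCrop) all follow the same inference pattern: slice the full scene into overlapping 1024x1024 patches (gap 200), run the detector on batches of p_bs patches, shift each patch's local boxes back to full-image coordinates in relocate (adding x_start/y_start to the box centres), and suppress duplicates across overlapping patches with merge_results(iou_thr=0.4). The sketch below is a minimal, self-contained illustration of that flow under stated assumptions: detect_patch is a hypothetical stand-in for the per-patch head call (e.g. bbox_head.simple_test_bboxes), boxes are assumed to be (cx, cy, w, h, angle, score), and the greedy NMS on enclosing horizontal boxes only approximates the class-wise rotated NMS that merge_results actually performs.

```python
# Minimal sketch of the crop -> detect -> relocate -> merge inference strategy.
# `detect_patch` is a hypothetical callable standing in for the per-patch model
# forward; it is NOT the repository's real API.
import torch


def sliding_windows(height, width, patch=1024, gap=200):
    """Top-left (y, x) corners of overlapping windows covering the full image."""
    stride = patch - gap

    def starts(size):
        last = max(size - patch, 0)
        s = list(range(0, last + 1, stride))
        if s[-1] != last:          # add a final window flush with the border
            s.append(last)
        return s

    return [(y, x) for y in starts(height) for x in starts(width)]


def approx_merge(dets, iou_thr=0.4):
    """Greedy NMS on the axis-aligned boxes enclosing each rotated box.
    Only an approximation of the rotated, class-wise NMS done by merge_results."""
    if dets.numel() == 0:
        return dets
    cx, cy, w, h, _, score = dets.unbind(dim=1)
    half = 0.5 * (w.abs() + h.abs())            # conservative enclosing half-extent
    x1, y1, x2, y2 = cx - half, cy - half, cx + half, cy + half
    order = score.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        i = order[0]
        keep.append(i)
        rest = order[1:]
        if rest.numel() == 0:
            break
        iw = (torch.minimum(x2[i], x2[rest]) - torch.maximum(x1[i], x1[rest])).clamp(min=0)
        ih = (torch.minimum(y2[i], y2[rest]) - torch.maximum(y1[i], y1[rest])).clamp(min=0)
        inter = iw * ih
        union = (x2[i] - x1[i]) * (y2[i] - y1[i]) + (x2[rest] - x1[rest]) * (y2[rest] - y1[rest]) - inter
        order = rest[inter / union <= iou_thr]
    return dets[torch.stack(keep)]


def crop_and_merge_inference(img, detect_patch, patch=1024, gap=200, p_bs=4):
    """img: (C, H, W) tensor of a full scene (padding of edge patches omitted).
    detect_patch: callable mapping a (B, C, h, w) batch to a list of (N_i, 6)
    tensors in patch-local coordinates -- an assumed interface for illustration."""
    _, height, width = img.shape
    windows = sliding_windows(height, width, patch, gap)
    collected = []
    for j in range(0, len(windows), p_bs):         # patch batch size, like p_bs above
        chunk = windows[j:j + p_bs]
        patches = torch.stack([img[:, y:y + patch, x:x + patch] for y, x in chunk])
        with torch.no_grad():
            results = detect_patch(patches)
        for (y, x), boxes in zip(chunk, results):  # "relocate": back to full-image coords
            if boxes.numel() > 0:
                boxes = boxes.clone()
                boxes[:, 0] += x                   # cx += x_start
                boxes[:, 1] += y                   # cy += y_start
                collected.append(boxes)
    if not collected:
        return torch.zeros((0, 6))
    return approx_merge(torch.cat(collected), iou_thr=0.4)
```

The 200-pixel overlap keeps most objects fully contained in at least one patch, and the cross-patch NMS removes the partial duplicates produced where windows overlap; keeping patches on the CPU until they are batched (as the list2tensor comments note) trades speed for memory on very large scenes.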