diff --git a/configs/rotated_rtmdet/_base_/tta.py b/configs/rotated_rtmdet/_base_/tta.py new file mode 100644 index 000000000..5052d9e17 --- /dev/null +++ b/configs/rotated_rtmdet/_base_/tta.py @@ -0,0 +1,35 @@ +tta_model = dict( + type='RotatedTTAModel', + tta_cfg=dict(nms=dict(type='nms_rotated', iou_threshold=0.1), max_per_img=2000)) + +img_scales = [(1024, 1024), (800, 800), (1200, 1200)] +tta_pipeline = [ + dict(type='mmdet.LoadImageFromFile', file_client_args=dict(backend='disk')), + dict( + type='mmdet.TestTimeAug', + transforms=[ + [ + dict(type='mmdet.Resize', scale=s, keep_ratio=True) + for s in img_scales + ], + [ + # ``RandomFlip`` must be placed before ``Pad``, otherwise + # bounding box coordinates after flipping cannot be + # recovered correctly. + dict(type='mmdet.RandomFlip', prob=1.), + dict(type='mmdet.RandomFlip', prob=0.) + ], + [ + dict( + type='mmdet.Pad', + size=(1200, 1200), + pad_val=dict(img=(114, 114, 114))), + ], + [ + dict( + type='mmdet.PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction')) + ] + ]) +] \ No newline at end of file diff --git a/configs/rotated_rtmdet/rotated_rtmdet_l-3x-dota.py b/configs/rotated_rtmdet/rotated_rtmdet_l-3x-dota.py index 7587fbb19..db848bee1 100644 --- a/configs/rotated_rtmdet/rotated_rtmdet_l-3x-dota.py +++ b/configs/rotated_rtmdet/rotated_rtmdet_l-3x-dota.py @@ -1,6 +1,6 @@ _base_ = [ './_base_/default_runtime.py', './_base_/schedule_3x.py', - './_base_/dota_rr.py' + './_base_/dota_rr.py', './_base_/tta.py' ] checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-l_8xb256-rsb-a1-600e_in1k-6a760974.pth' # noqa diff --git a/mmrotate/models/__init__.py b/mmrotate/models/__init__.py index ca7701153..8de05b912 100644 --- a/mmrotate/models/__init__.py +++ b/mmrotate/models/__init__.py @@ -7,4 +7,5 @@ from .necks import * # noqa: F401, F403 from .roi_heads import * # noqa: F401, F403 from .task_modules import * # noqa: F401,F403 +from .test_time_augs import * # noqa: F401, F403 from .utils import * # noqa: F401, F403 diff --git a/mmrotate/models/test_time_augs/__init__.py b/mmrotate/models/test_time_augs/__init__.py new file mode 100644 index 000000000..329cb3499 --- /dev/null +++ b/mmrotate/models/test_time_augs/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .rotated_tta import RotatedTTAModel + + +__all__ = [ + 'RotatedTTAModel' +] \ No newline at end of file diff --git a/mmrotate/models/test_time_augs/rotated_tta.py b/mmrotate/models/test_time_augs/rotated_tta.py new file mode 100644 index 000000000..d019d32d0 --- /dev/null +++ b/mmrotate/models/test_time_augs/rotated_tta.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch + +from torch import Tensor + +from mmdet.models.test_time_augs import DetTTAModel +from mmrotate.registry import MODELS + + +def bbox_flip(bboxes: Tensor, + img_shape: Tuple[int], + direction: str = 'horizontal') -> Tensor: + """Flip bboxes horizontally or vertically. + + Args: + bboxes (Tensor): Shape (..., 5*k) + img_shape (Tuple[int]): Image shape. + direction (str): Flip direction, options are "horizontal", "vertical", + "diagonal". Default: "horizontal" + + Returns: + Tensor: Flipped bboxes. + """ + assert bboxes.shape[-1] % 5 == 0 + assert direction in ['horizontal', 'vertical', 'diagonal'] + flipped = bboxes.clone() + if direction == 'horizontal': + flipped[..., 0] = img_shape[1] - flipped[..., 0] + flipped[..., 4] = -flipped[..., 4] + elif direction == 'vertical': + flipped[..., 1] = img_shape[0] - flipped[..., 1] + flipped[..., 4] = -flipped[..., 4] + else: + flipped[..., 0] = img_shape[1] - flipped[..., 0] + flipped[..., 1] = img_shape[0] - flipped[..., 1] + return flipped + +@MODELS.register_module() +class RotatedTTAModel(DetTTAModel): + + def merge_aug_bboxes(self, aug_bboxes: List[Tensor], + aug_scores: List[Tensor], + img_metas: List[str]) -> Tuple[Tensor, Tensor]: + """Merge augmented detection bboxes and scores. + Args: + aug_bboxes (list[Tensor]): shape (n, 5*#class) + aug_scores (list[Tensor] or None): shape (n, #class) + Returns: + tuple[Tensor]: ``bboxes`` with shape (n,5), where + 4 represent (x, y, w, h, t) + and ``scores`` with shape (n,). + """ + recovered_bboxes = [] + for bboxes, img_info in zip(aug_bboxes, img_metas): + ori_shape = img_info['ori_shape'] + flip = img_info['flip'] + flip_direction = img_info['flip_direction'] + if flip: + bboxes = bbox_flip( + bboxes=bboxes, + img_shape=ori_shape, + direction=flip_direction) + recovered_bboxes.append(bboxes) + bboxes = torch.cat(recovered_bboxes, dim=0) + if aug_scores is None: + return bboxes + else: + scores = torch.cat(aug_scores, dim=0) + return bboxes, scores \ No newline at end of file diff --git a/tools/test.py b/tools/test.py index 4d416112d..9f5498fd7 100644 --- a/tools/test.py +++ b/tools/test.py @@ -4,7 +4,7 @@ import os.path as osp from mmdet.utils import register_all_modules as register_all_modules_mmdet -from mmengine.config import Config, DictAction +from mmengine.config import Config, ConfigDict, DictAction from mmengine.evaluator import DumpResults from mmengine.registry import RUNNERS from mmengine.runner import Runner @@ -24,6 +24,10 @@ def parse_args(): '--out', type=str, help='dump predictions to a pickle file for offline evaluation') + parser.add_argument( + '--tta', + action='store_true', + help='Whether to use test time augmentation') parser.add_argument( '--show', action='store_true', help='show prediction results') parser.add_argument( @@ -103,6 +107,19 @@ def main(): if args.show or args.show_dir: cfg = trigger_visualization_hook(cfg, args) + if args.tta: + assert 'tta_model' in cfg, 'Cannot find ``tta_model`` in config.' \ + " Can't use tta !" + assert 'tta_pipeline' in cfg, 'Cannot find ``tta_pipeline`` ' \ + "in config. Can't use tta !" + + cfg.model = ConfigDict(**cfg.tta_model, module=cfg.model) + test_data_cfg = cfg.test_dataloader.dataset + while 'dataset' in test_data_cfg: + test_data_cfg = test_data_cfg['dataset'] + + test_data_cfg.pipeline = cfg.tta_pipeline + # build the runner from config if 'runner_type' not in cfg: # build the default runner