diff --git a/mmdet/datasets/transforms/__init__.py b/mmdet/datasets/transforms/__init__.py index c8c40f3660c..b5ab3758382 100644 --- a/mmdet/datasets/transforms/__init__.py +++ b/mmdet/datasets/transforms/__init__.py @@ -18,8 +18,8 @@ MinIoURandomCrop, MixUp, Mosaic, Pad, PhotoMetricDistortion, RandomAffine, RandomCenterCropPad, RandomCrop, RandomErasing, - RandomFlip, RandomShift, Resize, SegRescale, - YOLOXHSVRandomAug) + RandomFlip, RandomShift, Resize, ResizeShortestEdge, + SegRescale, YOLOXHSVRandomAug) from .wrappers import MultiBranch, ProposalBroadcaster, RandomOrder __all__ = [ @@ -37,5 +37,5 @@ 'LoadEmptyAnnotations', 'RandomOrder', 'CachedMosaic', 'CachedMixUp', 'FixShapeResize', 'ProposalBroadcaster', 'InferencerLoader', 'LoadTrackAnnotations', 'BaseFrameSample', 'UniformRefFrameSample', - 'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize' + 'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize', 'ResizeShortestEdge' ] diff --git a/mmdet/datasets/transforms/transforms.py b/mmdet/datasets/transforms/transforms.py index f128bd82674..9768d71eaed 100644 --- a/mmdet/datasets/transforms/transforms.py +++ b/mmdet/datasets/transforms/transforms.py @@ -60,8 +60,7 @@ def _fixed_scale_size( def rescale_size(old_size: tuple, scale: Union[float, int, tuple], - return_scale: bool = False, - short_side_mode: bool = False) -> tuple: + return_scale: bool = False) -> tuple: """Calculate the new size to be rescaled to. Args: @@ -84,12 +83,8 @@ def rescale_size(old_size: tuple, elif isinstance(scale, tuple): max_long_edge = max(scale) max_short_edge = min(scale) - if short_side_mode: - short, long = (w, h) if w <= h else (h, w) - scale_factor = max_short_edge / short - else: - scale_factor = min(max_long_edge / max(h, w), - max_short_edge / min(h, w)) + scale_factor = min(max_long_edge / max(h, w), + max_short_edge / min(h, w)) else: raise TypeError( f'Scale must be a number or tuple of int, but got {type(scale)}') @@ -107,8 +102,7 @@ def imrescale( scale: Union[float, Tuple[int, int]], return_scale: bool = False, interpolation: str = 'bilinear', - backend: Optional[str] = None, - short_side_mode: bool = False, + backend: Optional[str] = None ) -> Union[np.ndarray, Tuple[np.ndarray, float]]: """Resize image while keeping the aspect ratio. @@ -127,10 +121,7 @@ def imrescale( ndarray: The rescaled image. """ h, w = img.shape[:2] - new_size, scale_factor = rescale_size((w, h), - scale, - return_scale=True, - short_side_mode=short_side_mode) + new_size, scale_factor = rescale_size((w, h), scale, return_scale=True) rescaled_img = imresize( img, new_size, interpolation=interpolation, backend=backend) if return_scale: @@ -259,17 +250,6 @@ class FixScaleResize(Resize): """Compared to Resize, FixScaleResize fixes the scaling issue when `keep_ratio=true`.""" - def __init__( - self, - *args, - short_side_mode: bool = False, - **kwargs, - ) -> None: - super().__init__(*args, **kwargs) - self.short_side_mode = short_side_mode - if short_side_mode is True: - assert self.scale and self.keep_ratio is True - def _resize_img(self, results): """Resize images with ``results['scale']``.""" if results.get('img', None) is not None: @@ -279,8 +259,7 @@ def _resize_img(self, results): results['scale'], interpolation=self.interpolation, return_scale=True, - backend=self.backend, - short_side_mode=self.short_side_mode) + backend=self.backend) new_h, new_w = img.shape[:2] h, w = results['img'].shape[:2] w_scale = new_w / w @@ -298,6 +277,81 @@ def _resize_img(self, results): results['keep_ratio'] = self.keep_ratio +@TRANSFORMS.register_module() +class ResizeShortestEdge(BaseTransform): + """Resize the image and mask while keeping the aspect ratio unchanged. + + Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/augmentation_impl.py#L130 # noqa:E501 + + This transform attempts to scale the shorter edge to the given + `scale`, as long as the longer edge does not exceed `max_size`. + If `max_size` is reached, then downscale so that the longer + edge does not exceed `max_size`. + + Required Keys: + - img + - gt_seg_map (optional) + Modified Keys: + - img + - img_shape + - gt_seg_map (optional)) + Added Keys: + - scale + - scale_factor + - keep_ratio + + Args: + scale (Union[int, Tuple[int, int]]): The target short edge length. + If it's tuple, will select the min value as the short edge length. + max_size (int): The maximum allowed longest edge length. + """ + + def __init__(self, + scale: Union[int, Tuple[int, int]], + max_size: Optional[int] = None, + resize_type: str = 'Resize', + **resize_kwargs) -> None: + super().__init__() + self.scale = scale + self.max_size = max_size + + self.resize_cfg = dict(type=resize_type, **resize_kwargs) + self.resize = TRANSFORMS.build({'scale': 0, **self.resize_cfg}) + + def _get_output_shape(self, img, short_edge_length) -> Tuple[int, int]: + """Compute the target image shape with the given `short_edge_length`. + + Args: + img (np.ndarray): The input image. + short_edge_length (Union[int, Tuple[int, int]]): The target short + edge length. If it's tuple, will select the min value as the + short edge length. + """ + h, w = img.shape[:2] + if isinstance(short_edge_length, int): + size = short_edge_length * 1.0 + elif isinstance(short_edge_length, tuple): + size = min(short_edge_length) * 1.0 + scale = size / min(h, w) + if h < w: + new_h, new_w = size, scale * w + else: + new_h, new_w = scale * h, size + + if self.max_size and max(new_h, new_w) > self.max_size: + scale = self.max_size * 1.0 / max(new_h, new_w) + new_h *= scale + new_w *= scale + + new_h = int(new_h + 0.5) + new_w = int(new_w + 0.5) + return (new_w, new_h) + + def transform(self, results: dict) -> dict: + self.resize.scale = self._get_output_shape(results['img'], self.scale) + return self.resize(results) + + @TRANSFORMS.register_module() class FixShapeResize(Resize): """Resize images & bbox & seg to the specified size. diff --git a/projects/XDecoder/README.md b/projects/XDecoder/README.md index 473424ac28b..9b235b55798 100644 --- a/projects/XDecoder/README.md +++ b/projects/XDecoder/README.md @@ -8,7 +8,9 @@ We present X-Decoder, a generalized decoding model that can predict pixel-level segmentation and language tokens seamlessly. X-Decodert takes as input two types of queries: (i) generic non-semantic queries and (ii) semantic queries induced from text inputs, to decode different pixel-level and token-level outputs in the same semantic space. With such a novel design, X-Decoder is the first work that provides a unified way to support all types of image segmentation and a variety of vision-language (VL) tasks. Further, our design enables seamless interactions across tasks at different granularities and brings mutual benefits by learning a common and rich pixel-level visual-semantic understanding space, without any pseudo-labeling. After pretraining on a mixed set of a limited amount of segmentation data and millions of image-text pairs, X-Decoder exhibits strong transferability to a wide range of downstream tasks in both zero-shot and finetuning settings. Notably, it achieves (1) state-of-the-art results on open-vocabulary segmentation and referring segmentation on eight datasets; (2) better or competitive finetuned performance to other generalist and specialist models on segmentation and VL tasks; and (3) flexibility for efficient finetuning and novel task composition (e.g., referring captioning and image editing). -![img](https://raw.githubusercontent.com/microsoft/X-Decoder/main/images/teaser_new.png) +
+ +
## Installation @@ -22,17 +24,96 @@ mim install mmdet[multimodal] ## How to use it? -## Models and results - For convenience, you can download the weights to the `mmdetection` root dir ```shell -wget https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_best_openseg.pt wget https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_last_novg.pt +wget https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_best_openseg.pt ``` The above two weights are directly copied from the official website without any modification. The specific source is https://github.com/microsoft/X-Decoder +For convenience of demonstration, please download [the folder](https://github.com/microsoft/X-Decoder/tree/main/images) and place it in the root directory of mmdetection. + +**(1) Open Vocabulary Semantic Segmentation** + +```shell +cd projects/XDecoder +python demo.py ../../images/animals.png configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py --weights ../../xdecoder_focalt_last_novg.pt --texts zebra.giraffe +``` + +
+ +
+ +**(2) Open Vocabulary Instance Segmentation** + +```shell +cd projects/XDecoder +python demo.py ../../images/owls.jpeg configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py --weights ../../xdecoder_focalt_last_novg.pt --texts owl +``` + +
+ +
+ +**(3) Open Vocabulary Panoptic Segmentation** + +```shell +cd projects/XDecoder +python demo.py ../../images/street.jpg configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_coco.py --weights ../../xdecoder_focalt_last_novg.pt --text car.person --stuff-text tree.sky +``` + +
+ +
+ +**(4) Referring Expression Segmentation** + +```shell +cd projects/XDecoder +python demo.py ../../images/fruit.jpg configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcocog.py --weights ../../xdecoder_focalt_last_novg.pt --text "The larger watermelon. The front white flower. White tea pot." +``` + +**(5) Image Caption** + +```shell +cd projects/XDecoder +python demo.py ../../images/penguin.jpeg configs/xdecoder-tiny_zeroshot_caption_coco2014.py --weights ../../xdecoder_focalt_last_novg.pt +``` + +
+ +
+ +**(6) Referring Expression Image Caption** + +```shell +cd projects/XDecoder +python demo.py ../../images/fruit.jpg configs/xdecoder-tiny_zeroshot_ref-caption.py --weights ../../xdecoder_focalt_last_novg.pt --text 'White tea pot' +``` + +
+ +
+ +**(7) Text Image Region Retrieval** + +```shell +cd projects/XDecoder +python demo.py ../../images/coco configs/xdecoder-tiny_zeroshot_text-image-retrieval.py --weights ../../xdecoder_focalt_last_novg.pt --text 'pizza on the plate' +``` + +```text +The image that best matches the given text is ../../images/coco/000.jpg and probability is 0.998 +``` + +
+ +
+ +## Models and results + ### Semantic segmentation on ADE20K Prepare your dataset according to the [docs](https://mmsegmentation.readthedocs.io/en/latest/user_guides/2_dataset_prepare.html#ade20k). @@ -64,8 +145,8 @@ Prepare your dataset according to the [docs](https://mmdetection.readthedocs.io/ ``` | Model | mIOU | mIOU(official) | Config | -| :------------------------------------------------ |:----:|---------------:| :----------------------------------------------------------------: | -| `xdecoder-tiny_zeroshot_open-vocab-semseg_coco`\* | 61.8 | 62.1 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py) | +| :------------------------------------------------ | :--: | -------------: | :----------------------------------------------------------------: | +| `xdecoder-tiny_zeroshot_open-vocab-semseg_coco`\* | 62.1 | 62.1 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py) | ### Instance segmentation on COCO2017 @@ -77,9 +158,9 @@ Prepare your dataset according to the [docs](https://mmdetection.readthedocs.io/ ./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py xdecoder_focalt_last_novg.pt 8 ``` -| Model | mAP | mAP(official) | Config | -| :-------------------------------------------------- | :--: | ------------: | :------------------------------------------------------------------: | -| `xdecoder-tiny_zeroshot_open-vocab-instance_coco`\* | 39.7 | 39.7 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py) | +| Model | Mask mAP | Mask mAP(official) | Config | +| :-------------------------------------------------- | :------: | -----------------: | :------------------------------------------------------------------: | +| `xdecoder-tiny_zeroshot_open-vocab-instance_coco`\* | 39.8 | 39.7 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py) | ### Panoptic segmentation on COCO2017 @@ -111,7 +192,7 @@ Before testing, you need to install jdk 1.8, otherwise it will prompt that java | Model | BLEU-4 | CIDER | Config | | :------------------------------------------ | :----: | :----: | :----------------------------------------------------------: | -| `xdecoder-tiny_zeroshot_caption_coco2014`\* | 35.14 | 116.62 | [config](configs/xdecoder-tiny_zeroshot_caption_coco2014.py) | +| `xdecoder-tiny_zeroshot_caption_coco2014`\* | 35.26 | 116.81 | [config](configs/xdecoder-tiny_zeroshot_caption_coco2014.py) | ## Citation diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_caption_coco2014.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_caption_coco2014.py index 7af7080782b..963c7c61e09 100644 --- a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_caption_coco2014.py +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_caption_coco2014.py @@ -7,13 +7,7 @@ type='LoadImageFromFile', imdecode_backend='pillow', backend_args=_base_.backend_args), - dict( - type='FixScaleResize', - scale=224, - keep_ratio=True, - short_side_mode=True, - backend='pillow', - interpolation='bicubic'), + dict(type='ResizeShortestEdge', scale=224, backend='pillow'), dict( type='PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py index 0183ee5bc32..512a70824c8 100644 --- a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py @@ -8,7 +8,8 @@ type='LoadImageFromFile', imdecode_backend='pillow', backend_args=_base_.backend_args), - dict(type='Resize', scale=(1333, 800), backend='pillow', keep_ratio=True), + dict( + type='ResizeShortestEdge', scale=800, max_size=1333, backend='pillow'), dict( type='PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_coco.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_coco.py index 690bd4ba340..b0b2712a4ed 100644 --- a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_coco.py +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_coco.py @@ -8,7 +8,8 @@ type='LoadImageFromFile', imdecode_backend='pillow', backend_args=_base_.backend_args), - dict(type='Resize', scale=(1333, 800), backend='pillow', keep_ratio=True), + dict( + type='ResizeShortestEdge', scale=800, max_size=1333, backend='pillow'), dict( type='PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_ade20k.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_ade20k.py index 7c4e3367093..dc06c568722 100644 --- a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_ade20k.py +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_ade20k.py @@ -8,13 +8,7 @@ type='LoadImageFromFile', imdecode_backend='pillow', backend_args=_base_.backend_args), - dict( - type='FixScaleResize', - scale=640, - keep_ratio=True, - short_side_mode=True, - backend='pillow', - interpolation='bicubic'), + dict(type='ResizeShortestEdge', scale=640, backend='pillow'), dict( type='LoadAnnotations', with_bbox=False, diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py index 58458247939..cd9a7eccfe6 100644 --- a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py @@ -7,7 +7,8 @@ dict( type='LoadImageFromFile', imdecode_backend='pillow', backend_args=None), - dict(type='Resize', scale=(1333, 800), backend='pillow', keep_ratio=True), + dict( + type='ResizeShortestEdge', scale=800, max_size=1333, backend='pillow'), dict( type='LoadAnnotations', with_bbox=False, diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_ref-caption.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_ref-caption.py index 0220c10e304..fc81af198f9 100644 --- a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_ref-caption.py +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_ref-caption.py @@ -6,13 +6,7 @@ test_pipeline = [ dict(type='LoadImageFromFile', imdecode_backend='pillow'), - dict( - type='FixScaleResize', - scale=224, - keep_ratio=True, - short_side_mode=True, - backend='pillow', - interpolation='bicubic'), + dict(type='ResizeShortestEdge', scale=224, backend='pillow'), dict( type='PackDetInputs', meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_text-image-retrieval.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_text-image-retrieval.py index a77e63848fe..7523e045273 100644 --- a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_text-image-retrieval.py +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_text-image-retrieval.py @@ -10,10 +10,8 @@ imdecode_backend='pillow', backend_args=_base_.backend_args), dict( - type='FixScaleResize', + type='ResizeShortestEdge', scale=224, - keep_ratio=True, - short_side_mode=True, backend='pillow', interpolation='bicubic'), dict(