Skip to content

Commit

Permalink
Fix cfg (#10503)
Browse files Browse the repository at this point in the history
  • Loading branch information
hhaAndroid committed Jun 14, 2023
1 parent c7e171c commit 429c2a4
Show file tree
Hide file tree
Showing 10 changed files with 185 additions and 67 deletions.
6 changes: 3 additions & 3 deletions mmdet/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
MinIoURandomCrop, MixUp, Mosaic, Pad,
PhotoMetricDistortion, RandomAffine,
RandomCenterCropPad, RandomCrop, RandomErasing,
RandomFlip, RandomShift, Resize, SegRescale,
YOLOXHSVRandomAug)
RandomFlip, RandomShift, Resize, ResizeShortestEdge,
SegRescale, YOLOXHSVRandomAug)
from .wrappers import MultiBranch, ProposalBroadcaster, RandomOrder

__all__ = [
Expand All @@ -37,5 +37,5 @@
'LoadEmptyAnnotations', 'RandomOrder', 'CachedMosaic', 'CachedMixUp',
'FixShapeResize', 'ProposalBroadcaster', 'InferencerLoader',
'LoadTrackAnnotations', 'BaseFrameSample', 'UniformRefFrameSample',
'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize'
'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize', 'ResizeShortestEdge'
]
108 changes: 81 additions & 27 deletions mmdet/datasets/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ def _fixed_scale_size(

def rescale_size(old_size: tuple,
scale: Union[float, int, tuple],
return_scale: bool = False,
short_side_mode: bool = False) -> tuple:
return_scale: bool = False) -> tuple:
"""Calculate the new size to be rescaled to.
Args:
Expand All @@ -84,12 +83,8 @@ def rescale_size(old_size: tuple,
elif isinstance(scale, tuple):
max_long_edge = max(scale)
max_short_edge = min(scale)
if short_side_mode:
short, long = (w, h) if w <= h else (h, w)
scale_factor = max_short_edge / short
else:
scale_factor = min(max_long_edge / max(h, w),
max_short_edge / min(h, w))
scale_factor = min(max_long_edge / max(h, w),
max_short_edge / min(h, w))
else:
raise TypeError(
f'Scale must be a number or tuple of int, but got {type(scale)}')
Expand All @@ -107,8 +102,7 @@ def imrescale(
scale: Union[float, Tuple[int, int]],
return_scale: bool = False,
interpolation: str = 'bilinear',
backend: Optional[str] = None,
short_side_mode: bool = False,
backend: Optional[str] = None
) -> Union[np.ndarray, Tuple[np.ndarray, float]]:
"""Resize image while keeping the aspect ratio.
Expand All @@ -127,10 +121,7 @@ def imrescale(
ndarray: The rescaled image.
"""
h, w = img.shape[:2]
new_size, scale_factor = rescale_size((w, h),
scale,
return_scale=True,
short_side_mode=short_side_mode)
new_size, scale_factor = rescale_size((w, h), scale, return_scale=True)
rescaled_img = imresize(
img, new_size, interpolation=interpolation, backend=backend)
if return_scale:
Expand Down Expand Up @@ -259,17 +250,6 @@ class FixScaleResize(Resize):
"""Compared to Resize, FixScaleResize fixes the scaling issue when
`keep_ratio=true`."""

def __init__(
self,
*args,
short_side_mode: bool = False,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.short_side_mode = short_side_mode
if short_side_mode is True:
assert self.scale and self.keep_ratio is True

def _resize_img(self, results):
"""Resize images with ``results['scale']``."""
if results.get('img', None) is not None:
Expand All @@ -279,8 +259,7 @@ def _resize_img(self, results):
results['scale'],
interpolation=self.interpolation,
return_scale=True,
backend=self.backend,
short_side_mode=self.short_side_mode)
backend=self.backend)
new_h, new_w = img.shape[:2]
h, w = results['img'].shape[:2]
w_scale = new_w / w
Expand All @@ -298,6 +277,81 @@ def _resize_img(self, results):
results['keep_ratio'] = self.keep_ratio


@TRANSFORMS.register_module()
class ResizeShortestEdge(BaseTransform):
"""Resize the image and mask while keeping the aspect ratio unchanged.
Modified from https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/transforms/augmentation_impl.py#L130 # noqa:E501
This transform attempts to scale the shorter edge to the given
`scale`, as long as the longer edge does not exceed `max_size`.
If `max_size` is reached, then downscale so that the longer
edge does not exceed `max_size`.
Required Keys:
- img
- gt_seg_map (optional)
Modified Keys:
- img
- img_shape
- gt_seg_map (optional))
Added Keys:
- scale
- scale_factor
- keep_ratio
Args:
scale (Union[int, Tuple[int, int]]): The target short edge length.
If it's tuple, will select the min value as the short edge length.
max_size (int): The maximum allowed longest edge length.
"""

def __init__(self,
scale: Union[int, Tuple[int, int]],
max_size: Optional[int] = None,
resize_type: str = 'Resize',
**resize_kwargs) -> None:
super().__init__()
self.scale = scale
self.max_size = max_size

self.resize_cfg = dict(type=resize_type, **resize_kwargs)
self.resize = TRANSFORMS.build({'scale': 0, **self.resize_cfg})

def _get_output_shape(self, img, short_edge_length) -> Tuple[int, int]:
"""Compute the target image shape with the given `short_edge_length`.
Args:
img (np.ndarray): The input image.
short_edge_length (Union[int, Tuple[int, int]]): The target short
edge length. If it's tuple, will select the min value as the
short edge length.
"""
h, w = img.shape[:2]
if isinstance(short_edge_length, int):
size = short_edge_length * 1.0
elif isinstance(short_edge_length, tuple):
size = min(short_edge_length) * 1.0
scale = size / min(h, w)
if h < w:
new_h, new_w = size, scale * w
else:
new_h, new_w = scale * h, size

if self.max_size and max(new_h, new_w) > self.max_size:
scale = self.max_size * 1.0 / max(new_h, new_w)
new_h *= scale
new_w *= scale

new_h = int(new_h + 0.5)
new_w = int(new_w + 0.5)
return (new_w, new_h)

def transform(self, results: dict) -> dict:
self.resize.scale = self._get_output_shape(results['img'], self.scale)
return self.resize(results)


@TRANSFORMS.register_module()
class FixShapeResize(Resize):
"""Resize images & bbox & seg to the specified size.
Expand Down
101 changes: 91 additions & 10 deletions projects/XDecoder/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

We present X-Decoder, a generalized decoding model that can predict pixel-level segmentation and language tokens seamlessly. X-Decodert takes as input two types of queries: (i) generic non-semantic queries and (ii) semantic queries induced from text inputs, to decode different pixel-level and token-level outputs in the same semantic space. With such a novel design, X-Decoder is the first work that provides a unified way to support all types of image segmentation and a variety of vision-language (VL) tasks. Further, our design enables seamless interactions across tasks at different granularities and brings mutual benefits by learning a common and rich pixel-level visual-semantic understanding space, without any pseudo-labeling. After pretraining on a mixed set of a limited amount of segmentation data and millions of image-text pairs, X-Decoder exhibits strong transferability to a wide range of downstream tasks in both zero-shot and finetuning settings. Notably, it achieves (1) state-of-the-art results on open-vocabulary segmentation and referring segmentation on eight datasets; (2) better or competitive finetuned performance to other generalist and specialist models on segmentation and VL tasks; and (3) flexibility for efficient finetuning and novel task composition (e.g., referring captioning and image editing).

![img](https://raw.githubusercontent.com/microsoft/X-Decoder/main/images/teaser_new.png)
<div align=center>
<img src="https://github.com/open-mmlab/mmdetection/assets/17425982/cb126615-9402-4c19-8ea9-133722d7519c" width="70%"/>
</div>

## Installation

Expand All @@ -22,17 +24,96 @@ mim install mmdet[multimodal]

## How to use it?

## Models and results

For convenience, you can download the weights to the `mmdetection` root dir

```shell
wget https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_best_openseg.pt
wget https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_last_novg.pt
wget https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_best_openseg.pt
```

The above two weights are directly copied from the official website without any modification. The specific source is https://github.com/microsoft/X-Decoder

For convenience of demonstration, please download [the folder](https://github.com/microsoft/X-Decoder/tree/main/images) and place it in the root directory of mmdetection.

**(1) Open Vocabulary Semantic Segmentation**

```shell
cd projects/XDecoder
python demo.py ../../images/animals.png configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py --weights ../../xdecoder_focalt_last_novg.pt --texts zebra.giraffe
```

<div align=center>
<img src="https://github.com/open-mmlab/mmdetection/assets/17425982/c397c0ed-859a-4004-8725-78a591742bc8" width="70%"/>
</div>

**(2) Open Vocabulary Instance Segmentation**

```shell
cd projects/XDecoder
python demo.py ../../images/owls.jpeg configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py --weights ../../xdecoder_focalt_last_novg.pt --texts owl
```

<div align=center>
<img src="https://github.com/open-mmlab/mmdetection/assets/17425982/494b0b1c-4a42-4019-97ae-d33ee68af3d2" width="70%"/>
</div>

**(3) Open Vocabulary Panoptic Segmentation**

```shell
cd projects/XDecoder
python demo.py ../../images/street.jpg configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_coco.py --weights ../../xdecoder_focalt_last_novg.pt --text car.person --stuff-text tree.sky
```

<div align=center>
<img src="https://github.com/open-mmlab/mmdetection/assets/17425982/9ad1e0f4-75ce-4e37-a5cc-83e0e8a722ed" width="70%"/>
</div>

**(4) Referring Expression Segmentation**

```shell
cd projects/XDecoder
python demo.py ../../images/fruit.jpg configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcocog.py --weights ../../xdecoder_focalt_last_novg.pt --text "The larger watermelon. The front white flower. White tea pot."
```

**(5) Image Caption**

```shell
cd projects/XDecoder
python demo.py ../../images/penguin.jpeg configs/xdecoder-tiny_zeroshot_caption_coco2014.py --weights ../../xdecoder_focalt_last_novg.pt
```

<div align=center>
<img src="https://github.com/open-mmlab/mmdetection/assets/17425982/7690ab79-791e-4011-ab0c-01f46c4a3d80" width="70%"/>
</div>

**(6) Referring Expression Image Caption**

```shell
cd projects/XDecoder
python demo.py ../../images/fruit.jpg configs/xdecoder-tiny_zeroshot_ref-caption.py --weights ../../xdecoder_focalt_last_novg.pt --text 'White tea pot'
```

<div align=center>
<img src="https://github.com/open-mmlab/mmdetection/assets/17425982/bae2fdba-0172-4fc8-8ad1-73b54c64ec30" width="70%"/>
</div>

**(7) Text Image Region Retrieval**

```shell
cd projects/XDecoder
python demo.py ../../images/coco configs/xdecoder-tiny_zeroshot_text-image-retrieval.py --weights ../../xdecoder_focalt_last_novg.pt --text 'pizza on the plate'
```

```text
The image that best matches the given text is ../../images/coco/000.jpg and probability is 0.998
```

<div align=center>
<img src="https://github.com/open-mmlab/mmdetection/assets/17425982/479de6b2-88e7-41f0-8228-4b9a48f52954" width="70%"/>
</div>

## Models and results

### Semantic segmentation on ADE20K

Prepare your dataset according to the [docs](https://mmsegmentation.readthedocs.io/en/latest/user_guides/2_dataset_prepare.html#ade20k).
Expand Down Expand Up @@ -64,8 +145,8 @@ Prepare your dataset according to the [docs](https://mmdetection.readthedocs.io/
```

| Model | mIOU | mIOU(official) | Config |
| :------------------------------------------------ |:----:|---------------:| :----------------------------------------------------------------: |
| `xdecoder-tiny_zeroshot_open-vocab-semseg_coco`\* | 61.8 | 62.1 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py) |
| :------------------------------------------------ | :--: | -------------: | :----------------------------------------------------------------: |
| `xdecoder-tiny_zeroshot_open-vocab-semseg_coco`\* | 62.1 | 62.1 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py) |

### Instance segmentation on COCO2017

Expand All @@ -77,9 +158,9 @@ Prepare your dataset according to the [docs](https://mmdetection.readthedocs.io/
./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py xdecoder_focalt_last_novg.pt 8
```

| Model | mAP | mAP(official) | Config |
| :-------------------------------------------------- | :--: | ------------: | :------------------------------------------------------------------: |
| `xdecoder-tiny_zeroshot_open-vocab-instance_coco`\* | 39.7 | 39.7 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py) |
| Model | Mask mAP | Mask mAP(official) | Config |
| :-------------------------------------------------- | :------: | -----------------: | :------------------------------------------------------------------: |
| `xdecoder-tiny_zeroshot_open-vocab-instance_coco`\* | 39.8 | 39.7 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py) |

### Panoptic segmentation on COCO2017

Expand Down Expand Up @@ -111,7 +192,7 @@ Before testing, you need to install jdk 1.8, otherwise it will prompt that java

| Model | BLEU-4 | CIDER | Config |
| :------------------------------------------ | :----: | :----: | :----------------------------------------------------------: |
| `xdecoder-tiny_zeroshot_caption_coco2014`\* | 35.14 | 116.62 | [config](configs/xdecoder-tiny_zeroshot_caption_coco2014.py) |
| `xdecoder-tiny_zeroshot_caption_coco2014`\* | 35.26 | 116.81 | [config](configs/xdecoder-tiny_zeroshot_caption_coco2014.py) |

## Citation

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,7 @@
type='LoadImageFromFile',
imdecode_backend='pillow',
backend_args=_base_.backend_args),
dict(
type='FixScaleResize',
scale=224,
keep_ratio=True,
short_side_mode=True,
backend='pillow',
interpolation='bicubic'),
dict(type='ResizeShortestEdge', scale=224, backend='pillow'),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
type='LoadImageFromFile',
imdecode_backend='pillow',
backend_args=_base_.backend_args),
dict(type='Resize', scale=(1333, 800), backend='pillow', keep_ratio=True),
dict(
type='ResizeShortestEdge', scale=800, max_size=1333, backend='pillow'),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
type='LoadImageFromFile',
imdecode_backend='pillow',
backend_args=_base_.backend_args),
dict(type='Resize', scale=(1333, 800), backend='pillow', keep_ratio=True),
dict(
type='ResizeShortestEdge', scale=800, max_size=1333, backend='pillow'),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,7 @@
type='LoadImageFromFile',
imdecode_backend='pillow',
backend_args=_base_.backend_args),
dict(
type='FixScaleResize',
scale=640,
keep_ratio=True,
short_side_mode=True,
backend='pillow',
interpolation='bicubic'),
dict(type='ResizeShortestEdge', scale=640, backend='pillow'),
dict(
type='LoadAnnotations',
with_bbox=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
dict(
type='LoadImageFromFile', imdecode_backend='pillow',
backend_args=None),
dict(type='Resize', scale=(1333, 800), backend='pillow', keep_ratio=True),
dict(
type='ResizeShortestEdge', scale=800, max_size=1333, backend='pillow'),
dict(
type='LoadAnnotations',
with_bbox=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,7 @@

test_pipeline = [
dict(type='LoadImageFromFile', imdecode_backend='pillow'),
dict(
type='FixScaleResize',
scale=224,
keep_ratio=True,
short_side_mode=True,
backend='pillow',
interpolation='bicubic'),
dict(type='ResizeShortestEdge', scale=224, backend='pillow'),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
Expand Down
Loading

0 comments on commit 429c2a4

Please sign in to comment.