Skip to content

Commit

Permalink
Refcoco results (#10516)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiexinch committed Jun 16, 2023
1 parent 47f0f3f commit 81d5089
Show file tree
Hide file tree
Showing 12 changed files with 173 additions and 95 deletions.
10 changes: 5 additions & 5 deletions configs/_base_/datasets/refcoco+.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(img='train2014/'),
data_prefix=dict(img_path='train2014/'),
ann_file='refcoco+/instances.json',
split_file='refcoco+/refs(unc).p',
split='val',
text_mode='original',
text_mode='select_first',
pipeline=test_pipeline))

test_dataloader = dict(
Expand All @@ -44,12 +44,12 @@
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(img='train2014/'),
data_prefix=dict(img_path='train2014/'),
ann_file='refcoco+/instances.json',
split_file='refcoco+/refs(unc).p',
split='testA', # or 'testB'
text_mode='original',
text_mode='select_first',
pipeline=test_pipeline))

val_evaluator = dict(type='RefSegMetric', iou_metrics=['cIoU', 'mIoU'])
val_evaluator = dict(type='RefSegMetric', metric=['cIoU', 'mIoU'])
test_evaluator = val_evaluator
10 changes: 5 additions & 5 deletions configs/_base_/datasets/refcoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(img='train2014/'),
data_prefix=dict(img_path='train2014/'),
ann_file='refcoco/instances.json',
split_file='refcoco/refs(unc).p',
split='val',
text_mode='original',
text_mode='select_first',
pipeline=test_pipeline))

test_dataloader = dict(
Expand All @@ -44,12 +44,12 @@
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(img='train2014/'),
data_prefix=dict(img_path='train2014/'),
ann_file='refcoco/instances.json',
split_file='refcoco/refs(unc).p',
split='testA', # or 'testB'
text_mode='original',
text_mode='select_first',
pipeline=test_pipeline))

val_evaluator = dict(type='RefSegMetric', iou_metrics=['cIoU', 'mIoU'])
val_evaluator = dict(type='RefSegMetric', metric=['cIoU', 'mIoU'])
test_evaluator = val_evaluator
6 changes: 3 additions & 3 deletions configs/_base_/datasets/refcocog.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
ann_file='refcocog/instances.json',
split_file='refcocog/refs(umd).p',
split='val',
text_mode='original',
text_mode='select_first',
pipeline=test_pipeline))

test_dataloader = dict(
Expand All @@ -48,8 +48,8 @@
ann_file='refcocog/instances.json',
split_file='refcocog/refs(umd).p',
split='test',
text_mode='original',
text_mode='select_first',
pipeline=test_pipeline))

val_evaluator = dict(type='RefSegMetric', iou_metrics=['cIoU', 'mIoU'])
val_evaluator = dict(type='RefSegMetric', metric=['cIoU', 'mIoU'])
test_evaluator = val_evaluator
8 changes: 4 additions & 4 deletions docs/en/user_guides/dataset_prepare.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,15 +189,15 @@ mmdetection
The images and annotations of [RefCOCO](https://github.com/lichengunc/refer) series datasets can be download by running `tools/misc/download_dataset.py`:

```shell
python tools/misc/download_dataset.py --dataset-name refcoco --save-dir data/refcoco --unzip
python tools/misc/download_dataset.py --dataset-name refcoco --save-dir data/coco --unzip
```

Then the directory should be like this.
Then the directory should be like this:

```text
data
├── refcoco
   ├── refcoco
├── coco
├── refcoco
│   │   ├── instances.json
│   │   ├── refs(google).p
│   │   └── refs(unc).p
Expand Down
93 changes: 93 additions & 0 deletions docs/zh_cn/user_guides/dataset_prepare.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,97 @@ mmdetection

### RefCOCO 数据集准备

[RefCOCO](https://github.com/lichengunc/refer)系列数据集的图像和注释可以通过运行 `tools/misc/download_dataset.py` 下载:

```shell
python tools/misc/download_dataset.py --dataset-name refcoco --save-dir data/coco --unzip
```

然后,目录应该是这样的:

```text
data
├── coco
│ ├── refcoco
│   │   ├── instances.json
│   │   ├── refs(google).p
│   │   └── refs(unc).p
│   ├── refcoco+
│   │   ├── instances.json
│   │   └── refs(unc).p
│   ├── refcocog
│   │   ├── instances.json
│   │   ├── refs(google).p
│   │   └── refs(umd).p
| |── train2014
```

### ADE20K 数据集准备

[ADE20K](http://groups.csail.mit.edu/vision/datasets/ADE20K/)数据集的图像和注释可以通过运行 `tools/misc/download_dataset.py` 下载:

```shell
python tools/misc/download_dataset.py --dataset-name ade20k_2016 --save-dir data --unzip
```

然后将注释移至`data/ADEChallengeData2016`目录,并运行预处理脚本以产生coco格式注释:

```shell
mv data/annotations_instance data/ADEChallengeData2016/
mv data/categoryMapping.txt data/ADEChallengeData2016/
mv data/imgCatIds.json data/ADEChallengeData2016/
python tools/dataset_converters/ade20k2coco.py data/ADEChallengeData2016 --task panoptic
python tools/dataset_converters/ade20k2coco.py data/ADEChallengeData2016 --task instance
```

然后,目录应该是这样的:

```text
data
├── ADEChallengeData2016
│   ├── ade20k_instance_train.json
│   ├── ade20k_instance_val.json
│   ├── ade20k_panoptic_train
| | ├── ADE_train_00000001.png
| | ├── ADE_train_00000002.png
| | ├── ...
│   ├── ade20k_panoptic_train.json
│   ├── ade20k_panoptic_val
| | ├── ADE_val_00000001.png
| | ├── ADE_val_00000002.png
| | ├── ...
│   ├── ade20k_panoptic_val.json
│   ├── annotations
| | ├── training
| | | ├── ADE_train_00000001.png
| | | ├── ADE_train_00000002.png
| | | ├── ...
| | ├── validation
| | | ├── ADE_val_00000001.png
| | | ├── ADE_val_00000002.png
| | | ├── ...
│   ├── annotations_instance
| | ├── training
| | | ├── ADE_train_00000001.png
| | | ├── ADE_train_00000002.png
| | | ├── ...
| | ├── validation
| | | ├── ADE_val_00000001.png
| | | ├── ADE_val_00000002.png
| | | ├── ...
│   ├── categoryMapping.txt
│   ├── images
│   | ├── training
| | | ├── ADE_train_00000001.jpg
| | | ├── ADE_train_00000002.jpg
| | | ├── ...
| | ├── validation
| | | ├── ADE_val_00000001.jpg
| | | ├── ADE_val_00000002.jpg
| | | ├── ...
│   ├── imgCatIds.json
│   ├── objectInfo150.txt
| |── sceneCategories.txt
```

上述文件夹包括ADE20K的语义分割、实例分割和泛在分割的所有数据。
77 changes: 37 additions & 40 deletions mmdet/datasets/refcoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import collections
import os.path as osp
import random
from typing import List
from typing import Dict, List

import mmengine
from mmengine.dataset import BaseDataset
Expand All @@ -29,26 +29,23 @@ class RefCOCODataset(BaseDataset):
data_prefix (str): Prefix for training data.
split_file (str): Split file path.
split (str): Split name. Defaults to 'train'.
text_mode (str): Text mode. Defaults to 'random'.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""

def __init__(self,
data_root,
ann_file,
data_prefix,
split_file,
split='train',
text_mode='random',
data_root: str,
ann_file: str,
split_file: str,
data_prefix: Dict,
split: str = 'train',
text_mode: str = 'random',
**kwargs):
self.split_file = split_file
self.split = split

assert text_mode in ['original', 'random', 'concat', 'select_first']
self.text_mode = text_mode

self._init_refs(
osp.join(data_root, ann_file), osp.join(data_root, split_file))

super().__init__(
data_root=data_root,
data_prefix=data_prefix,
Expand All @@ -62,19 +59,16 @@ def _join_prefix(self):

return super()._join_prefix()

def _init_refs(self, ann_file, split_file):
def _init_refs(self):
"""Initialize the refs for RefCOCO."""
self.instances = mmengine.load(ann_file, file_format='json')
splits = mmengine.load(split_file, file_format='pkl')

anns, imgs = {}, {}
for ann in self.instances['annotations']:
anns[ann['id']] = ann
for img in self.instances['images']:
imgs[img['id']] = img

refs, ref_to_ann = {}, {}
for ref in splits:
for ref in self.splits:
# ids
ref_id = ref['ref_id']
ann_id = ref['ann_id']
Expand All @@ -87,11 +81,13 @@ def _init_refs(self, ann_file, split_file):

def load_data_list(self) -> List[dict]:
"""Load data list."""
splits = mmengine.load(self.split_file, file_format='pkl')
self.splits = mmengine.load(self.split_file, file_format='pkl')
self.instances = mmengine.load(self.ann_file, file_format='json')
self._init_refs()
img_prefix = self.data_prefix['img_path']

ref_ids = [
ref['ref_id'] for ref in splits if ref['split'] == self.split
ref['ref_id'] for ref in self.splits if ref['split'] == self.split
]
full_anno = []
for ref_id in ref_ids:
Expand Down Expand Up @@ -128,30 +124,31 @@ def load_data_list(self) -> List[dict]:
join_path = mmengine.fileio.get_file_backend(img_prefix).join_path
for image in images:
img_id = image['id']
grounding_anno = grounding_dict[img_id][0]
texts = [x['raw'].lower() for x in grounding_anno['sentences']]
if self.text_mode == 'random':
idx = random.randint(0, len(texts) - 1)
text = texts[idx]
elif self.text_mode == 'concat':
text = [''.join(texts)]
elif self.text_mode == 'select_first':
text = [texts[0]]
elif self.text_mode == 'original':
text = texts
else:
raise ValueError(f'Invalid text mode "{self.text_mode}".')
data_info = {
'img_path':
join_path(img_prefix, image['file_name']),
'img_id':
img_id,
'instances': [{
instances = []
sentences = []
for grounding_anno in grounding_dict[img_id]:
texts = [x['raw'].lower() for x in grounding_anno['sentences']]
if self.text_mode == 'random':
idx = random.randint(0, len(texts) - 1)
text = [texts[idx]]
elif self.text_mode == 'concat':
text = [''.join(texts)]
elif self.text_mode == 'select_first':
text = texts[0]
elif self.text_mode == 'original':
text = texts
else:
raise ValueError(f'Invalid text mode "{self.text_mode}".')
instances.append({
'mask': grounding_anno['segmentation'],
'ignore_flag': 0
}],
'text':
text
})
sentences.append(text)
data_info = {
'img_path': join_path(img_prefix, image['file_name']),
'img_id': img_id,
'instances': instances,
'text': sentences
}
data_list.append(data_info)

Expand Down
19 changes: 5 additions & 14 deletions mmdet/evaluation/metrics/refseg_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,12 @@
@METRICS.register_module()
class RefSegMetric(BaseMetric):

def __init__(self,
iou_metrics: list = ['cIoU', 'mIoU'],
eval_first_text: bool = False,
**kwargs):
def __init__(self, metric: list = ['cIoU', 'mIoU'], **kwargs):
super().__init__(**kwargs)
assert set(iou_metrics).issubset(['cIoU', 'mIoU']), \
f'Only support cIoU and mIoU, but got {iou_metrics}'
assert len(iou_metrics) > 0, 'metrics should not be empty'
self.metrics = iou_metrics
self.eval_first_text = eval_first_text
assert set(metric).issubset(['cIoU', 'mIoU']), \
f'Only support cIoU and mIoU, but got {metric}'
assert len(metric) > 0, 'metrics should not be empty'
self.metrics = metric

def compute_iou(self, pred_seg, gt_seg):
i = pred_seg & gt_seg
Expand All @@ -40,11 +36,6 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
pred_label = data_sample['pred_instances']['masks'].bool()
label = data_sample['gt_masks'].to_tensor(
pred_label.dtype, pred_label.device).bool()
if self.eval_first_text:
pred_label = pred_label[0:1]
else:
label = label.repeat(pred_label.shape[0], 1, 1)

# calculate iou
i, u = self.compute_iou(pred_label, label)

Expand Down
12 changes: 11 additions & 1 deletion projects/XDecoder/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,19 @@ Prepare your dataset according to the [docs](https://mmdetection.readthedocs.io/

### Referring segmentation on RefCOCO

Prepare your dataset according to the [docs](https://mmdetection.readthedocs.io/en/latest/user_guides/dataset_prepare.html#refcoco-dataset-preparation).

**Test Command**

```shell
./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcocog.py xdecoder_focalt_last_novg.pt 8 --cfg-options test_dataloader.dataset.split='val'
```

| Model | cIoU | cIOU(official) | Config |
| :------------------------------- | :---: | :------------: | :---------------------------------------------------------------------: |
| `xdecoder_focalt_last_novg.pt`\* | 62.25 | 57.85 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcocog.py) |
| `xdecoder_focalt_last_novg.pt`\* | 58.85 | 57.85 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-ref-seg_refcocog.py) |

**Note:** If you set the scale of `Resize` to (1024, 512), the result will be `57.69`.

### Image Caption on COCO2014

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = [
'_base_/xdecoder-tiny_ref-seg.py', 'mmdet::_base_/datasets/refcoco+.py'
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = [
'_base_/xdecoder-tiny_ref-seg.py', 'mmdet::_base_/datasets/refcoco.py'
]
Loading

0 comments on commit 81d5089

Please sign in to comment.