Skip to content

Commit

Permalink
[Feature] Support refcoco datasets (#10418)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiexinch committed May 31, 2023
1 parent 6145ae2 commit 78c4805
Show file tree
Hide file tree
Showing 7 changed files with 351 additions and 1 deletion.
74 changes: 74 additions & 0 deletions configs/_base_/datasets/refcoco+.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# dataset settings
dataset_type = 'RefCOCODataset'
data_root = 'data/refcoco/'

backend_args = None

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', prob=0.5),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'text', 'image_id'))
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(1333, 800), keep_ratio=True),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'text', 'image_id'))
]

train_dataloader = dict(
batch_size=2,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
batch_sampler=dict(type='AspectRatioBatchSampler'),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(img='train2014/'),
ann_file='refcoco+/instances.json',
split_file='refcoco+/refs(unc).p',
split='train',
pipeline=train_pipeline,
backend_args=backend_args))

val_dataloader = dict(
batch_size=1,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(img='train2014/'),
ann_file='refcoco+/instances.json',
split_file='refcoco+/refs(unc).p',
split='val',
pipeline=test_pipeline,
backend_args=backend_args))

test_dataloader = dict(
batch_size=1,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(img='train2014/'),
ann_file='refcoco+/instances.json',
split_file='refcoco+/refs(unc).p',
split='testA', # or 'testB'
pipeline=test_pipeline,
backend_args=backend_args))

# TODO: set the metrics
74 changes: 74 additions & 0 deletions configs/_base_/datasets/refcoco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# dataset settings
dataset_type = 'RefCOCODataset'
data_root = 'data/refcoco/'

backend_args = None

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', prob=0.5),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'text', 'image_id'))
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(1333, 800), keep_ratio=True),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'text', 'image_id'))
]

train_dataloader = dict(
batch_size=2,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
batch_sampler=dict(type='AspectRatioBatchSampler'),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(img='train2014/'),
ann_file='refcoco/instances.json',
split_file='refcoco/refs(unc).p',
split='train',
pipeline=train_pipeline,
backend_args=backend_args))

val_dataloader = dict(
batch_size=1,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(img='train2014/'),
ann_file='refcoco/instances.json',
split_file='refcoco/refs(unc).p',
split='val',
pipeline=test_pipeline,
backend_args=backend_args))

test_dataloader = dict(
batch_size=1,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(img='train2014/'),
ann_file='refcoco/instances.json',
split_file='refcoco/refs(unc).p',
split='testA', # or 'testB'
pipeline=test_pipeline,
backend_args=backend_args))

# TODO: set the metrics
74 changes: 74 additions & 0 deletions configs/_base_/datasets/refcocog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# dataset settings
dataset_type = 'RefCOCODataset'
data_root = 'data/refcoco/'

backend_args = None

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', prob=0.5),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'text', 'image_id'))
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(1333, 800), keep_ratio=True),
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'text', 'image_id'))
]

train_dataloader = dict(
batch_size=2,
num_workers=2,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
batch_sampler=dict(type='AspectRatioBatchSampler'),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(img='train2014/'),
ann_file='refcocog/instances.json',
split_file='refcocog/refs(umd).p',
split='train',
pipeline=train_pipeline,
backend_args=backend_args))

val_dataloader = dict(
batch_size=1,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(img='train2014/'),
ann_file='refcocog/instances.json',
split_file='refcocog/refs(umd).p',
split='val',
pipeline=test_pipeline,
backend_args=backend_args))

test_dataloader = dict(
batch_size=1,
num_workers=2,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(img='train2014/'),
ann_file='refcocog/instances.json',
split_file='refcocog/refs(umd).p',
split='test',
pipeline=test_pipeline,
backend_args=backend_args))

# TODO: set the metrics
25 changes: 25 additions & 0 deletions docs/en/user_guides/dataset_prepare.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,28 @@ python tools/dataset_converters/cityscapes.py \
--nproc 8 \
--out-dir ./data/cityscapes/annotations
```

The images and annotations of [RefCOCO](https://github.com/lichengunc/refer) series datasets can be download by running `tools/misc/download_dataset.py`:

```shell
python tools/misc/download_dataset.py --dataset-name refcoco --save-dir data/refcoco --unzip
```

Then the directory should be like this.

```text
data
├── refcoco
│   ├── refcoco
│   │   ├── instances.json
│   │   ├── refs(google).p
│   │   └── refs(unc).p
│   ├── refcoco+
│   │   ├── instances.json
│   │   └── refs(unc).p
│   ├── refcocog
│   │   ├── instances.json
│   │   ├── refs(google).p
│   │   └── refs(umd).p
| |── train2014
```
3 changes: 2 additions & 1 deletion mmdet/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .mot_challenge_dataset import MOTChallengeDataset
from .objects365 import Objects365V1Dataset, Objects365V2Dataset
from .openimages import OpenImagesChallengeDataset, OpenImagesDataset
from .refcoco import RefCOCODataset
from .reid_dataset import ReIDDataset
from .samplers import (AspectRatioBatchSampler, ClassAwareSampler,
GroupMultiSourceSampler, MultiSourceSampler,
Expand All @@ -34,5 +35,5 @@
'Objects365V1Dataset', 'Objects365V2Dataset', 'DSDLDetDataset',
'BaseVideoDataset', 'MOTChallengeDataset', 'TrackImgSampler',
'ReIDDataset', 'YouTubeVISDataset', 'TrackAspectRatioBatchSampler',
'ADE20KPanopticDataset', 'COCOCaptionDataset'
'ADE20KPanopticDataset', 'COCOCaptionDataset', 'RefCOCODataset'
]
92 changes: 92 additions & 0 deletions mmdet/datasets/refcoco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List

import mmengine
import numpy as np
from mmengine.dataset import BaseDataset
from pycocotools.coco import COCO

from mmdet.registry import DATASETS


@DATASETS.register_module()
class RefCOCODataset(BaseDataset):
"""RefCOCO dataset.
The `Refcoco` and `Refcoco+` dataset is based on
`ReferItGame: Referring to Objects in Photographs of Natural Scenes
<http://tamaraberg.com/papers/referit.pdf>`_.
The `Refcocog` dataset is based on
`Generation and Comprehension of Unambiguous Object Descriptions
<https://arxiv.org/abs/1511.02283>`_.
Args:
ann_file (str): Annotation file path.
data_root (str): The root directory for ``data_prefix`` and
``ann_file``. Defaults to ''.
data_prefix (str): Prefix for training data.
split_file (str): Split file path.
split (str): Split name. Defaults to 'train'.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""

def __init__(self,
data_root,
ann_file,
data_prefix,
split_file,
split='train',
**kwargs):
self.split_file = split_file
self.split = split

super().__init__(
data_root=data_root,
data_prefix=data_prefix,
ann_file=ann_file,
**kwargs,
)

def _join_prefix(self):
if not mmengine.is_abs(self.split_file) and self.split_file:
self.split_file = osp.join(self.data_root, self.split_file)

return super()._join_prefix()

def load_data_list(self) -> List[dict]:
"""Load data list."""
with mmengine.get_local_path(self.ann_file) as ann_file:
coco = COCO(ann_file)
splits = mmengine.load(self.split_file, file_format='pkl')
img_prefix = self.data_prefix['img_path']

data_list = []
join_path = mmengine.fileio.get_file_backend(img_prefix).join_path
for refer in splits:
if refer['split'] != self.split:
continue

ann = coco.anns[refer['ann_id']]
img = coco.imgs[ann['image_id']]
sentences = refer['sentences']
bbox = np.array(ann['bbox'], dtype=np.float32)
bbox[2:4] = bbox[0:2] + bbox[2:4] # XYWH -> XYXY
mask = np.array(ann['segmentation'], dtype=np.float32)

for sent in sentences:
data_info = {
'img_path': join_path(img_prefix, img['file_name']),
'image_id': ann['image_id'],
'ann_id': ann['id'],
'text': sent['sent'],
'gt_bboxes': bbox[None, :],
'gt_masks': mask[None, :],
}
data_list.append(data_info)

if len(data_list) == 0:
raise ValueError(f'No sample in split "{self.split}".')

return data_list
10 changes: 10 additions & 0 deletions tools/misc/download_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,16 @@ def main():
'https://raw.githubusercontent.com/CSAILVision/placeschallenge/master/instancesegmentation/imgCatIds.json', # noqa
# category mapping
'https://raw.githubusercontent.com/CSAILVision/placeschallenge/master/instancesegmentation/categoryMapping.txt' # noqa
],
refcoco=[
# images
'http://images.cocodataset.org/zips/train2014.zip',
# refcoco annotations
'https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco.zip',
# refcoco+ annotations
'https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco+.zip',
# refcocog annotations
'https://bvisionweb1.cs.unc.edu/licheng/referit/data/refcocog.zip'
])
url = data2url.get(args.dataset_name, None)
if url is None:
Expand Down

0 comments on commit 78c4805

Please sign in to comment.