diff --git a/configs/_base_/datasets/refcoco+.py b/configs/_base_/datasets/refcoco+.py
index caa8369ba19..56db966decf 100644
--- a/configs/_base_/datasets/refcoco+.py
+++ b/configs/_base_/datasets/refcoco+.py
@@ -1,44 +1,24 @@
 # dataset settings
 dataset_type = 'RefCOCODataset'
-data_root = 'data/refcoco/'
+data_root = 'data/coco/'
 backend_args = None
 
-train_pipeline = [
-    dict(type='LoadImageFromFile'),
-    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
-    dict(type='RandomFlip', prob=0.5),
-    dict(
-        type='PackDetInputs',
-        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
-                   'scale_factor', 'text', 'image_id'))
-]
-
 test_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='Resize', scale=(1333, 800), keep_ratio=True),
+    dict(
+        type='LoadAnnotations',
+        with_mask=True,
+        with_bbox=False,
+        with_seg=False,
+        with_label=False),
     dict(
         type='PackDetInputs',
         meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
-                   'scale_factor', 'text', 'image_id'))
+                   'scale_factor', 'gt_masks', 'text'))
 ]
 
-train_dataloader = dict(
-    batch_size=2,
-    num_workers=2,
-    persistent_workers=True,
-    sampler=dict(type='DefaultSampler', shuffle=True),
-    batch_sampler=dict(type='AspectRatioBatchSampler'),
-    dataset=dict(
-        type=dataset_type,
-        data_root=data_root,
-        data_prefix=dict(img='train2014/'),
-        ann_file='refcoco+/instances.json',
-        split_file='refcoco+/refs(unc).p',
-        split='train',
-        pipeline=train_pipeline,
-        backend_args=backend_args))
-
 val_dataloader = dict(
     batch_size=1,
     num_workers=2,
@@ -52,8 +32,8 @@
         ann_file='refcoco+/instances.json',
         split_file='refcoco+/refs(unc).p',
         split='val',
-        pipeline=test_pipeline,
-        backend_args=backend_args))
+        text_mode='original',
+        pipeline=test_pipeline))
 
 test_dataloader = dict(
     batch_size=1,
@@ -68,7 +48,8 @@
         ann_file='refcoco+/instances.json',
         split_file='refcoco+/refs(unc).p',
         split='testA',  # or 'testB'
-        pipeline=test_pipeline,
-        backend_args=backend_args))
+        text_mode='original',
+        pipeline=test_pipeline))
 
-# TODO: set the metrics
+val_evaluator = dict(type='RefSegMetric', metrics=['cIoU', 'mIoU'])
+test_evaluator = val_evaluator
diff --git a/configs/_base_/datasets/refcoco.py b/configs/_base_/datasets/refcoco.py
index c98ee8017d4..518c652bcbb 100644
--- a/configs/_base_/datasets/refcoco.py
+++ b/configs/_base_/datasets/refcoco.py
@@ -1,44 +1,24 @@
 # dataset settings
 dataset_type = 'RefCOCODataset'
-data_root = 'data/refcoco/'
+data_root = 'data/coco/'
 backend_args = None
 
-train_pipeline = [
-    dict(type='LoadImageFromFile'),
-    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
-    dict(type='RandomFlip', prob=0.5),
-    dict(
-        type='PackDetInputs',
-        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
-                   'scale_factor', 'text', 'image_id'))
-]
-
 test_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='Resize', scale=(1333, 800), keep_ratio=True),
+    dict(
+        type='LoadAnnotations',
+        with_mask=True,
+        with_bbox=False,
+        with_seg=False,
+        with_label=False),
     dict(
         type='PackDetInputs',
         meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
-                   'scale_factor', 'text', 'image_id'))
+                   'scale_factor', 'gt_masks', 'text'))
 ]
 
-train_dataloader = dict(
-    batch_size=2,
-    num_workers=2,
-    persistent_workers=True,
-    sampler=dict(type='DefaultSampler', shuffle=True),
-    batch_sampler=dict(type='AspectRatioBatchSampler'),
-    dataset=dict(
-        type=dataset_type,
-        data_root=data_root,
-        data_prefix=dict(img='train2014/'),
-        ann_file='refcoco/instances.json',
-        split_file='refcoco/refs(unc).p',
-        split='train',
-        pipeline=train_pipeline,
-        backend_args=backend_args))
-
 val_dataloader = dict(
     batch_size=1,
     num_workers=2,
@@ -52,8 +32,8 @@
         ann_file='refcoco/instances.json',
         split_file='refcoco/refs(unc).p',
         split='val',
-        pipeline=test_pipeline,
-        backend_args=backend_args))
+        text_mode='original',
+        pipeline=test_pipeline))
 
 test_dataloader = dict(
     batch_size=1,
@@ -68,7 +48,8 @@
         ann_file='refcoco/instances.json',
         split_file='refcoco/refs(unc).p',
         split='testA',  # or 'testB'
-        pipeline=test_pipeline,
-        backend_args=backend_args))
+        text_mode='original',
+        pipeline=test_pipeline))
 
-# TODO: set the metrics
+val_evaluator = dict(type='RefSegMetric', metrics=['cIoU', 'mIoU'])
+test_evaluator = val_evaluator
diff --git a/configs/_base_/datasets/refcocog.py b/configs/_base_/datasets/refcocog.py
index 9a2a45ff8a6..03b40add316 100644
--- a/configs/_base_/datasets/refcocog.py
+++ b/configs/_base_/datasets/refcocog.py
@@ -1,44 +1,24 @@
 # dataset settings
 dataset_type = 'RefCOCODataset'
-data_root = 'data/refcoco/'
+data_root = 'data/coco/'
 backend_args = None
 
-train_pipeline = [
-    dict(type='LoadImageFromFile'),
-    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
-    dict(type='RandomFlip', prob=0.5),
-    dict(
-        type='PackDetInputs',
-        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
-                   'scale_factor', 'text', 'image_id'))
-]
-
 test_pipeline = [
     dict(type='LoadImageFromFile'),
     dict(type='Resize', scale=(1333, 800), keep_ratio=True),
+    dict(
+        type='LoadAnnotations',
+        with_mask=True,
+        with_bbox=False,
+        with_seg=False,
+        with_label=False),
     dict(
         type='PackDetInputs',
         meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
-                   'scale_factor', 'text', 'image_id'))
+                   'scale_factor', 'gt_masks', 'text'))
 ]
 
-train_dataloader = dict(
-    batch_size=2,
-    num_workers=2,
-    persistent_workers=True,
-    sampler=dict(type='DefaultSampler', shuffle=True),
-    batch_sampler=dict(type='AspectRatioBatchSampler'),
-    dataset=dict(
-        type=dataset_type,
-        data_root=data_root,
-        data_prefix=dict(img='train2014/'),
-        ann_file='refcocog/instances.json',
-        split_file='refcocog/refs(umd).p',
-        split='train',
-        pipeline=train_pipeline,
-        backend_args=backend_args))
-
 val_dataloader = dict(
     batch_size=1,
     num_workers=2,
@@ -48,12 +28,12 @@
     dataset=dict(
         type=dataset_type,
         data_root=data_root,
-        data_prefix=dict(img='train2014/'),
+        data_prefix=dict(img_path='train2014/'),
         ann_file='refcocog/instances.json',
         split_file='refcocog/refs(umd).p',
         split='val',
-        pipeline=test_pipeline,
-        backend_args=backend_args))
+        text_mode='original',
+        pipeline=test_pipeline))
 
 test_dataloader = dict(
     batch_size=1,
@@ -64,11 +44,12 @@
     dataset=dict(
         type=dataset_type,
        data_root=data_root,
-        data_prefix=dict(img='train2014/'),
+        data_prefix=dict(img_path='train2014/'),
         ann_file='refcocog/instances.json',
         split_file='refcocog/refs(umd).p',
         split='test',
-        pipeline=test_pipeline,
-        backend_args=backend_args))
+        text_mode='original',
+        pipeline=test_pipeline))
 
-# TODO: set the metrics
+val_evaluator = dict(type='RefSegMetric', metrics=['cIoU', 'mIoU'])
+test_evaluator = val_evaluator
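Note: the three dataset configs above now share the standard COCO image root. They expect `data/coco/train2014/` for images, with `refcoco/instances.json` and `refcoco/refs(unc).p` next to it (and likewise `refcoco+/...`, plus `refcocog/...` with `refs(umd).p`). The sketch below is not part of the patch; it builds the refcocog validation set to verify that layout, assuming mmdet is installed from this branch and the annotation files are in place.

```python
from mmengine.config import Config

from mmdet.registry import DATASETS
from mmdet.utils import register_all_modules

register_all_modules()  # registers RefCOCODataset and the test pipeline
cfg = Config.fromfile('configs/_base_/datasets/refcocog.py')
dataset = DATASETS.build(cfg.val_dataloader.dataset)

info = dataset.get_data_info(0)  # raw data info, before the pipeline runs
print(len(dataset), info['img_path'])
print(info['text'])  # all referring expressions, since text_mode='original'
```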
diff --git a/mmdet/datasets/refcoco.py b/mmdet/datasets/refcoco.py
index ce95e04e171..fb178ce1716 100644
--- a/mmdet/datasets/refcoco.py
+++ b/mmdet/datasets/refcoco.py
@@ -1,11 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import collections
 import os.path as osp
+import random
 from typing import List
 
 import mmengine
-import numpy as np
 from mmengine.dataset import BaseDataset
-from pycocotools.coco import COCO
 
 from mmdet.registry import DATASETS
 
@@ -38,10 +38,17 @@ def __init__(self,
                  data_prefix,
                  split_file,
                  split='train',
+                 text_mode='random',
                  **kwargs):
         self.split_file = split_file
         self.split = split
 
+        assert text_mode in ['original', 'random', 'concat', 'select_first']
+        self.text_mode = text_mode
+
+        self._init_refs(
+            osp.join(data_root, ann_file), osp.join(data_root, split_file))
+
         super().__init__(
             data_root=data_root,
             data_prefix=data_prefix,
@@ -55,36 +62,98 @@ def _join_prefix(self):
         return super()._join_prefix()
 
+    def _init_refs(self, ann_file, split_file):
+        """Initialize the refs for RefCOCO."""
+        self.instances = mmengine.load(ann_file, file_format='json')
+        splits = mmengine.load(split_file, file_format='pkl')
+
+        anns, imgs = {}, {}
+        for ann in self.instances['annotations']:
+            anns[ann['id']] = ann
+        for img in self.instances['images']:
+            imgs[img['id']] = img
+
+        refs, ref_to_ann = {}, {}
+        for ref in splits:
+            # ids
+            ref_id = ref['ref_id']
+            ann_id = ref['ann_id']
+            # add mapping related to ref
+            refs[ref_id] = ref
+            ref_to_ann[ref_id] = anns[ann_id]
+
+        self.refs = refs
+        self.ref_to_ann = ref_to_ann
+
     def load_data_list(self) -> List[dict]:
         """Load data list."""
-        with mmengine.get_local_path(self.ann_file) as ann_file:
-            coco = COCO(ann_file)
-
         splits = mmengine.load(self.split_file, file_format='pkl')
         img_prefix = self.data_prefix['img_path']
 
+        ref_ids = [
+            ref['ref_id'] for ref in splits if ref['split'] == self.split
+        ]
+        full_anno = []
+        for ref_id in ref_ids:
+            ref = self.refs[ref_id]
+            ann = self.ref_to_ann[ref_id]
+            ann.update(ref)
+            full_anno.append(ann)
+
+        image_id_list = []
+        final_anno = {}
+        for anno in full_anno:
+            image_id_list.append(anno['image_id'])
+            final_anno[anno['ann_id']] = anno
+        annotations = [value for key, value in final_anno.items()]
+
+        coco_train_id = []
+        image_annot = {}
+        for i in range(len(self.instances['images'])):
+            coco_train_id.append(self.instances['images'][i]['id'])
+            image_annot[self.instances['images'][i]
+                        ['id']] = self.instances['images'][i]
+
+        images = []
+        for image_id in list(set(image_id_list)):
+            images += [image_annot[image_id]]
+
         data_list = []
+
+        grounding_dict = collections.defaultdict(list)
+        for anno in annotations:
+            image_id = int(anno['image_id'])
+            grounding_dict[image_id].append(anno)
+
         join_path = mmengine.fileio.get_file_backend(img_prefix).join_path
-        for refer in splits:
-            if refer['split'] != self.split:
-                continue
-
-            ann = coco.anns[refer['ann_id']]
-            img = coco.imgs[ann['image_id']]
-            sentences = refer['sentences']
-            bbox = np.array(ann['bbox'], dtype=np.float32)
-            bbox[2:4] = bbox[0:2] + bbox[2:4]  # XYWH -> XYXY
-            mask = np.array(ann['segmentation'], dtype=np.float32)
-
-            for sent in sentences:
-                data_info = {
-                    'img_path': join_path(img_prefix, img['file_name']),
-                    'image_id': ann['image_id'],
-                    'ann_id': ann['id'],
-                    'text': sent['sent'],
-                    'gt_bboxes': bbox[None, :],
-                    'gt_masks': mask[None, :],
-                }
-                data_list.append(data_info)
+        for image in images:
+            img_id = image['id']
+            grounding_anno = grounding_dict[img_id][0]
+            texts = [x['raw'].lower() for x in grounding_anno['sentences']]
+            # randomly select one text
+            if self.text_mode == 'random':
+                idx = random.randint(0, len(texts) - 1)
+                text = [texts[idx]]
+            # concat all texts
+            elif self.text_mode == 'concat':
+                text = [''.join(texts)]
+            # select the first text
+            elif self.text_mode == 'select_first':
+                text = [texts[0]]
+            # use all texts
+            elif self.text_mode == 'original':
+                text = texts
+            else:
+                raise ValueError(f'Invalid text mode "{self.text_mode}".')
+            data_info = {
+                'img_path': join_path(img_prefix, image['file_name']),
+                'img_id': img_id,
+                'instances': [{
+                    'mask': grounding_anno['segmentation'],
+                    'ignore_flag': 0
+                }],
+                'text': text
+            }
+            data_list.append(data_info)
 
         if len(data_list) == 0:
             raise ValueError(f'No sample in split "{self.split}".')
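For readers unfamiliar with the REFER annotation format: `_init_refs()` and `load_data_list()` above only rely on the fields shown in the sketch below. The field names are taken from the accesses in the code; all values are invented for illustration.

```python
# Invented REFER-style records; only the fields the loader touches are shown.
ref = {
    'ref_id': 3,  # key into self.refs
    'ann_id': 17,  # links the ref to a COCO-style annotation
    'image_id': 42,
    'split': 'val',  # 'train' / 'val' / 'testA' / 'testB' / 'test'
    'sentences': [{'raw': 'The man on the left'}, {'raw': 'LEFT guy in red'}],
}
ann = {'id': 17, 'image_id': 42, 'segmentation': [[10., 10., 60., 10., 35., 80.]]}
img = {'id': 42, 'file_name': 'COCO_train2014_000000000042.jpg'}

# What each text_mode turns the sentences into:
texts = [s['raw'].lower() for s in ref['sentences']]
# 'random'       -> [texts[i]] for one randomly chosen i
# 'select_first' -> [texts[0]]
# 'concat'       -> [''.join(texts)]  (note: joined without a separator)
# 'original'     -> texts, i.e. every expression (used for evaluation)
```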
diff --git a/mmdet/evaluation/metrics/__init__.py b/mmdet/evaluation/metrics/__init__.py
index df73bb329dc..e1ec0e46250 100644
--- a/mmdet/evaluation/metrics/__init__.py
+++ b/mmdet/evaluation/metrics/__init__.py
@@ -12,6 +12,7 @@
 from .lvis_metric import LVISMetric
 from .mot_challenge_metric import MOTChallengeMetric
 from .openimages_metric import OpenImagesMetric
+from .refseg_metric import RefSegMetric
 from .reid_metric import ReIDMetrics
 from .semseg_metric import SemSegMetric
 from .voc_metric import VOCMetric
@@ -22,5 +23,5 @@
     'VOCMetric', 'LVISMetric', 'CrowdHumanMetric', 'DumpProposals',
     'CocoOccludedSeparatedMetric', 'DumpDetResults', 'BaseVideoMetric',
     'MOTChallengeMetric', 'CocoVideoMetric', 'ReIDMetrics', 'YouTubeVISMetric',
-    'COCOCaptionMetric', 'SemSegMetric'
+    'COCOCaptionMetric', 'SemSegMetric', 'RefSegMetric'
 ]
diff --git a/mmdet/evaluation/metrics/refseg_metric.py b/mmdet/evaluation/metrics/refseg_metric.py
new file mode 100644
index 00000000000..a59fd3db163
--- /dev/null
+++ b/mmdet/evaluation/metrics/refseg_metric.py
@@ -0,0 +1,69 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Sequence
+
+import torch
+from mmengine.evaluator import BaseMetric
+
+from mmdet.registry import METRICS
+
+
+@METRICS.register_module()
+class RefSegMetric(BaseMetric):
+    """Referring expression segmentation metric.
+
+    cIoU accumulates intersection and union over the whole dataset;
+    mIoU averages the per-sample IoUs.
+    """
+
+    def __init__(self,
+                 metrics: list = ['cIoU', 'mIoU'],
+                 eval_first_text: bool = False,
+                 **kwargs):
+        super().__init__(**kwargs)
+        assert set(metrics).issubset(['cIoU', 'mIoU']), \
+            f'Only support cIoU and mIoU, but got {metrics}'
+        assert len(metrics) > 0, 'metrics should not be empty'
+        self.metrics = metrics
+        self.eval_first_text = eval_first_text
+
+    def compute_iou(self, pred_seg, gt_seg):
+        i = pred_seg & gt_seg
+        u = pred_seg | gt_seg
+        return i, u
+
+    def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:
+        """Process one batch of data and data_samples.
+
+        The processed results should be stored in ``self.results``, which will
+        be used to compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch (dict): A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from the model.
+        """
+        for data_sample in data_samples:
+            pred_label = data_sample['pred_instances']['masks'].bool()
+            label = data_sample['gt_masks'].to_tensor(
+                pred_label.dtype, pred_label.device).bool()
+            if self.eval_first_text:
+                pred_label = pred_label[0:1]
+            else:
+                label = label.repeat(pred_label.shape[0], 1, 1)
+
+            # calculate iou
+            i, u = self.compute_iou(pred_label, label)
+
+            bsi = len(pred_label)
+            iou = i.reshape(bsi, -1).sum(-1) * 1.0 / u.reshape(bsi, -1).sum(-1)
+            iou = torch.nan_to_num_(iou, nan=0.0)
+            self.results.append((i.sum(), u.sum(), iou.sum(), bsi))
+
+    def compute_metrics(self, results: list) -> dict:
+        results = tuple(zip(*results))
+        assert len(results) == 4
+        cum_i = sum(results[0])
+        cum_u = sum(results[1])
+        iou = sum(results[2])
+        seg_total = sum(results[3])
+
+        metrics = {}
+        if 'cIoU' in self.metrics:
+            metrics['cIoU'] = cum_i * 100 / cum_u
+        if 'mIoU' in self.metrics:
+            metrics['mIoU'] = iou * 100 / seg_total
+        return metrics
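A self-contained sanity check for the metric, using toy data that is not shipped with the patch: a 2x2 predicted square against a 2x4 ground-truth rectangle intersects in 4 pixels and unions to 8, so both cIoU and mIoU come out at 50.

```python
import numpy as np
import torch

from mmdet.evaluation.metrics import RefSegMetric
from mmdet.structures.mask import BitmapMasks

metric = RefSegMetric(metrics=['cIoU', 'mIoU'])

pred = torch.zeros(1, 4, 4, dtype=torch.bool)
pred[0, :2, :2] = True  # predicted 2x2 square
gt = BitmapMasks(np.zeros((1, 4, 4), dtype=np.uint8), height=4, width=4)
gt.masks[0, :2, :4] = 1  # ground-truth 2x4 rectangle

metric.process(
    data_batch={},
    data_samples=[{'pred_instances': {'masks': pred}, 'gt_masks': gt}])
print(metric.compute_metrics(metric.results))
# {'cIoU': tensor(50.), 'mIoU': tensor(50.)}
```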
+ """ + for data_sample in data_samples: + pred_label = data_sample['pred_instances']['masks'].bool() + label = data_sample['gt_masks'].to_tensor( + pred_label.dtype, pred_label.device).bool() + if self.eval_first_text: + pred_label = pred_label[0:1] + else: + label = label.repeat(pred_label.shape[0], 1, 1) + + # calculate iou + i, u = self.compute_iou(pred_label, label) + + bsi = len(pred_label) + iou = i.reshape(bsi, -1).sum(-1) * 1.0 / u.reshape(bsi, -1).sum(-1) + iou = torch.nan_to_num_(iou, nan=0.0) + self.results.append((i.sum(), u.sum(), iou.sum(), bsi)) + + def compute_metrics(self, results: list) -> dict: + results = tuple(zip(*results)) + assert len(results) == 4 + cum_i = sum(results[0]) + cum_u = sum(results[1]) + iou = sum(results[2]) + seg_total = sum(results[3]) + + metrics = {} + if 'cIoU' in self.metrics: + metrics['cIoU'] = cum_i * 100 / cum_u + if 'mIoU' in self.metrics: + metrics['mIoU'] = iou * 100 / seg_total + return metrics diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-ref-semseg_refcocog.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-ref-semseg_refcocog.py new file mode 100644 index 00000000000..ed4bfb1bfea --- /dev/null +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-ref-semseg_refcocog.py @@ -0,0 +1,5 @@ +_base_ = [ + '_base_/xdecoder-tiny_ref-semseg.py', 'mmdet::_base_/datasets/refcocog.py' +] + +test_dataloader = dict(dataset=dict(split='val'))