From 88884b40b0f55dbeb5ac31cdb4bd32d1c963836e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?= <1286304229@qq.com> Date: Fri, 9 Jun 2023 13:53:08 +0800 Subject: [PATCH] Fix ref seg (#10478) --- configs/_base_/datasets/coco_semantic.py | 2 +- docs/zh_cn/user_guides/dataset_prepare.md | 113 +++++++++++ mmdet/datasets/__init__.py | 4 +- mmdet/datasets/ade20k.py | 186 ++++++++++-------- mmdet/datasets/coco_caption.py | 12 +- mmdet/datasets/coco_semantic.py | 7 +- mmdet/datasets/transforms/loading.py | 2 + .../evaluation/metrics/coco_caption_metric.py | 2 +- mmdet/evaluation/metrics/semseg_metric.py | 11 +- mmdet/visualization/local_visualizer.py | 14 +- projects/XDecoder/README.md | 84 ++++++-- .../_base_/xdecoder-tiny_ref-semseg.py | 2 +- ...er-tiny_zeroshot_open-vocab-semseg_coco.py | 2 +- .../configs/xdecoder-tiny_zeroshot_ref-seg.py | 3 + .../xdecoder-tiny_zeroshot_ref-semseg.py | 3 - projects/XDecoder/demo.py | 2 +- .../xdecoder/inference/image_caption.py | 6 +- .../texttoimage_regionretrieval_inferencer.py | 6 +- .../XDecoder/xdecoder/transformer_decoder.py | 13 +- projects/XDecoder/xdecoder/unified_head.py | 105 +++++----- ...coco_semantic_annos_from_panoptic_annos.py | 2 + 21 files changed, 393 insertions(+), 188 deletions(-) create mode 100644 projects/XDecoder/configs/xdecoder-tiny_zeroshot_ref-seg.py delete mode 100644 projects/XDecoder/configs/xdecoder-tiny_zeroshot_ref-semseg.py diff --git a/configs/_base_/datasets/coco_semantic.py b/configs/_base_/datasets/coco_semantic.py index 92295832689..23d02079724 100644 --- a/configs/_base_/datasets/coco_semantic.py +++ b/configs/_base_/datasets/coco_semantic.py @@ -1,5 +1,5 @@ # dataset settings -dataset_type = 'CocoSegDaset' +dataset_type = 'CocoSegDataset' data_root = 'data/coco/' # Example to use different file client diff --git a/docs/zh_cn/user_guides/dataset_prepare.md b/docs/zh_cn/user_guides/dataset_prepare.md index b33ec3bd309..f26a022b7d7 100644 --- a/docs/zh_cn/user_guides/dataset_prepare.md +++ b/docs/zh_cn/user_guides/dataset_prepare.md @@ -1,5 +1,7 @@ ## 数据集准备 +### 基础检测数据集准备 + MMDetection 支持多个公共数据集,包括 [COCO](https://cocodataset.org/), [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC), [Cityscapes](https://www.cityscapes-dataset.com/) 和 [其他更多数据集](https://github.com/open-mmlab/mmdetection/tree/main/configs/_base_/datasets)。 一些公共数据集,比如 Pascal VOC 及其镜像数据集,或者 COCO 等数据集都可以从官方网站或者镜像网站获取。注意:在检测任务中,Pascal VOC 2012 是 Pascal VOC 2007 的无交集扩展,我们通常将两者一起使用。 我们建议将数据集下载,然后解压到项目外部的某个文件夹内,然后通过符号链接的方式,将数据集根目录链接到 `$MMDETECTION/data` 文件夹下, 如果你的文件夹结构和下方不同的话,你需要在配置文件中改变对应的路径。 @@ -71,3 +73,114 @@ python tools/dataset_converters/cityscapes.py \ --nproc 8 \ --out-dir ./data/cityscapes/annotations ``` + +### COCO Caption 数据集准备 + +COCO Caption 采用的是 COCO2014 数据集作为图片,并且使用了 karpathy 的标注, + +首先你需要下载 COCO2014 数据集 + +```shell +python tools/misc/download_dataset.py --dataset-name coco2014 --unzip +``` + +数据集会下载到当前路径的 `data/coco` 下。然后下载 karpathy 的标注 + +```shell +cd data/coco/annotations +wget https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json +wget https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json +wget https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json +wget https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json +wget https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json +``` + +最终直接可用于训练和测试的数据集文件夹结构如下: + 
+```text +mmdetection +├── data +│ ├── coco +│ │ ├── annotations +│ │ │ ├── coco_karpathy_train.json +│ │ │ ├── coco_karpathy_test.json +│ │ │ ├── coco_karpathy_val.json +│ │ │ ├── coco_karpathy_val_gt.json +│ │ │ ├── coco_karpathy_test_gt.json +│ │ ├── train2014 +│ │ ├── val2014 +│ │ ├── test2014 +``` + +### COCO semantic 数据集准备 + +COCO 语义分割有两种类型标注,主要差别在于类别名定义不一样,因此处理方式也有两种,第一种是直接使用 stuffthingmaps 数据集,第二种是使用 panoptic 数据集。 + +**(1) 使用 stuffthingmaps 数据集** + +该数据集的下载地址为 [stuffthingmaps_trainval2017](http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip),请下载后解压防止到 `data/coco` 文件夹下。 + +```text +mmdetection +├── data +│ ├── coco +│ │ ├── annotations +│ │ ├── train2017 +│ │ ├── val2017 +│ │ ├── test2017 +│ │ ├── stuffthingmaps +``` + +该数据集不同于标准的 COCO 类别标注,其包括 172 个类: 80 thing 类、91 stuff 类和 1 个 'unlabeled',其每个类别的说明见 https://github.com/nightrome/cocostuff/blob/master/labels.md + +虽然只标注了 172 个类别,但是 stuffthingmaps 中最大标签 id 是 182,中间有些类别是没有标注的,并且第 0 类的 `unlabeled` 类别被移除。因此最终的 stuffthingmaps 图片中每个位置的值对应的类别关系见 https://github.com/kazuto1011/deeplab-pytorch/blob/master/data/datasets/cocostuff/labels.txt + +考虑到训练高效和方便用户,在开启训练或者评估前,我们需要将没有标注的 12 个类移除,这 12 个类的名字为: street sign、hat、shoe、eye glasses、plate、mirror、window、desk、door、blender、hair brush,最终可用于训练和评估的类别信息见 `mmdet/datasets/coco_semantic.py` + +你可以使用 `tools/dataset_converters/coco_stuff164k.py` 来完成将下载的 stuffthingmaps 转换为直接可以训练和评估的数据集,转换后的数据集文件夹结构如下: + +```text +mmdetection +├── data +│ ├── coco +│ │ ├── annotations +│ │ ├── train2017 +│ │ ├── val2017 +│ │ ├── test2017 +│ │ ├── stuffthingmaps +│ │ ├── stuffthingmaps_semseg +``` + +stuffthingmaps_semseg 即为新生成的可以直接训练和测试的 COCO 语义分割数据集。 + +**(2) 使用 panoptic 数据集** + +通过 panoptic 标注生成的语义分割数据集类别数相比使用 stuffthingmaps 数据集生成的会少一些。首先你需要准备全景分割标注,然后使用如下脚本完成转换 + +```shell +python tools/dataset_converters/prepare_coco_semantic_annos_from_panoptic_annos.py data/coco +``` + +转换后的数据集文件夹结构如下: + +```text +mmdetection +├── data +│ ├── coco +│ │ ├── annotations +│ │ │ ├── panoptic_train2017.json +│ │ │ ├── panoptic_train2017 +│ │ │ ├── panoptic_val2017.json +│ │ │ ├── panoptic_val2017 +│ │ │ ├── panoptic_semseg_train2017 +│ │ │ ├── panoptic_semseg_val2017 +│ │ ├── train2017 +│ │ ├── val2017 +│ │ ├── test2017 +``` + +panoptic_semseg_train2017 和 panoptic_semseg_val2017 即为新生成的可以直接训练和测试的 COCO 语义分割数据集。注意其类别信息就是 COCO 全景分割的类别信息,包括 thing 和 stuff。 + +### RefCOCO 数据集准备 + +### ADE20K 数据集准备 diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py index 54a4b5ef3ee..3e14849262b 100644 --- a/mmdet/datasets/__init__.py +++ b/mmdet/datasets/__init__.py @@ -7,7 +7,7 @@ from .coco import CocoDataset from .coco_caption import COCOCaptionDataset from .coco_panoptic import CocoPanopticDataset -from .coco_semantic import CocoSegDaset +from .coco_semantic import CocoSegDataset from .crowdhuman import CrowdHumanDataset from .dataset_wrappers import MultiImageMixDataset from .deepfashion import DeepFashionDataset @@ -38,5 +38,5 @@ 'BaseVideoDataset', 'MOTChallengeDataset', 'TrackImgSampler', 'ReIDDataset', 'YouTubeVISDataset', 'TrackAspectRatioBatchSampler', 'ADE20KPanopticDataset', 'COCOCaptionDataset', 'RefCOCODataset', - 'BaseSegDataset', 'ADE20KSegDataset', 'CocoSegDaset' + 'BaseSegDataset', 'ADE20KSegDataset', 'CocoSegDataset' ] diff --git a/mmdet/datasets/ade20k.py b/mmdet/datasets/ade20k.py index 00766c0fefb..d765d8e54c6 100644 --- a/mmdet/datasets/ade20k.py +++ b/mmdet/datasets/ade20k.py @@ -92,47 +92,71 @@ class ADE20KPanopticDataset(CocoPanopticDataset): 'conveyor belt, conveyor belt, 
conveyor, conveyor, transporter', 'canopy', 'pool', 'falls', 'tent', 'cradle', 'tank, storage tank', 'lake', 'blanket, cover', 'pier', 'crt screen', 'shower'), - 'palette': [[120, 120, 120], [180, 120, 120], [6, 230, 230], - [80, 50, 50], [4, 200, 3], [120, 120, 80], [140, 140, 140], - [204, 5, 255], [230, 230, 230], [4, 250, 7], [224, 5, 255], - [235, 255, 7], [150, 5, 61], [120, 120, 70], [8, 255, 51], - [255, 6, 82], [143, 255, 140], [204, 255, 4], [255, 51, 7], - [204, 70, 3], [0, 102, 200], [61, 230, 250], [255, 6, 51], - [11, 102, 255], [255, 7, 71], [255, 9, 224], [9, 7, 230], - [220, 220, 220], [255, 9, 92], - [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6], - [10, 255, 71], [255, 41, 10], [7, 255, 255], [224, 255, 8], - [102, 8, 255], [255, 61, 6], [255, 194, 7], [255, 122, 8], - [0, 255, 20], [255, 8, 41], [255, 5, 153], [6, 51, 255], - [235, 12, 255], [160, 150, 20], [0, 163, 255], - [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], - [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], - [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], - [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], - [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], - [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], - [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], - [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], - [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], - [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], - [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], - [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], - [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], - [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], - [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], - [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], - [92, 255, 0], [0, 224, 255], [112, 224, - 255], [70, 184, 160], - [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], - [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], - [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], - [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], - [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], - [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], - [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], - [184, 255, 0], [0, 133, 255], [255, 214, - 0], [25, 194, 194], - [102, 255, 0], [92, 0, 255]] + 'palette': + ((120, 120, 120), (180, 120, 120), (6, 230, 230), (80, 50, 50), + (4, 200, 3), (120, 120, 80), (140, 140, 140), (204, 5, 255), + (230, 230, 230), (4, 250, 7), (224, 5, 255), (235, 255, 7), + (150, 5, 61), (120, 120, 70), (8, 255, 51), (255, 6, 82), + (143, 255, 140), (204, 255, 4), (255, 51, 7), (204, 70, 3), + (0, 102, 200), (61, 230, 250), (255, 6, 51), (11, 102, 255), + (255, 7, 71), (255, 9, 224), (9, 7, 230), (220, 220, 220), + (255, 9, 92), (112, 9, 255), (8, 255, 214), (7, 255, 224), + (255, 184, 6), (10, 255, 71), (255, 41, 10), (7, 255, 255), + (224, 255, 8), (102, 8, 255), (255, 61, 6), (255, 194, 7), (255, 122, + 8), + (0, 255, 20), (255, 8, 41), (255, 5, 153), (6, 51, 255), (235, 12, + 255), + (160, 150, 20), (0, 163, 255), (140, 140, 140), (250, 10, + 15), (20, 255, 0), + (31, 255, 0), (255, 31, 0), (255, 224, 0), (153, 255, 0), (0, 0, 255), + (255, 71, 0), (0, 235, 255), (0, 173, 255), (31, 0, 255), + (11, 200, + 200), (255, 82, + 0), (0, 255, 245), (0, 61, 255), (0, 255, 112), (0, 255, 133), + (255, 0, 0), (255, 163, 0), (255, 102, 
0), (194, 255, 0), (0, 143, + 255), + (51, 255, 0), (0, 82, 255), (0, 255, 41), (0, 255, 173), (10, 0, 255), + (173, 255, + 0), (0, 255, 153), (255, 92, 0), (255, 0, 255), (255, 0, + 245), (255, 0, 102), + (255, 173, 0), (255, 0, 20), (255, 184, + 184), (0, 31, 255), (0, 255, + 61), (0, 71, 255), + (255, 0, 204), (0, 255, 194), (0, 255, + 82), (0, 10, 255), (0, 112, + 255), (51, 0, 255), + (0, 194, 255), (0, 122, 255), (0, 255, 163), (255, 153, + 0), (0, 255, + 10), (255, 112, 0), + (143, 255, 0), (82, 0, 255), (163, 255, 0), (255, 235, + 0), (8, 184, + 170), (133, 0, 255), + (0, 255, 92), (184, 0, 255), (255, 0, 31), (0, 184, 255), (0, 214, + 255), + (255, 0, 112), (92, 255, + 0), (0, 224, 255), (112, 224, + 255), (70, 184, + 160), (163, 0, + 255), (153, 0, 255), + (71, 255, 0), (255, 0, 163), (255, 204, + 0), (255, 0, 143), (0, 255, + 235), (133, 255, 0), + (255, 0, 235), (245, 0, 255), (255, 0, 122), (255, 245, + 0), (10, 190, + 212), (214, 255, + 0), (0, 204, + 255), + (20, 0, 255), (255, 255, + 0), (0, 153, 255), (0, 41, 255), (0, 255, 204), (41, 0, + 255), + (41, 255, + 0), (173, 0, 255), (0, 245, 255), (71, 0, 255), (122, 0, + 255), (0, 255, 184), + (0, 92, 255), (184, 255, 0), (0, 133, 255), (255, 214, + 0), (25, 194, + 194), (102, 255, + 0), (92, 0, + 255)) } @@ -173,44 +197,44 @@ class ADE20KSegDataset(BaseSegDataset): 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag'), - palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], - [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], - [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], - [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], - [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], - [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], - [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], - [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], - [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], - [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], - [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], - [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], - [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], - [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], - [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], - [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], - [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], - [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], - [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], - [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], - [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], - [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], - [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], - [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], - [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], - [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], - [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], - [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], - [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], - [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], - [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], - [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], - [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], - [255, 255, 0], [0, 153, 255], [0, 
41, 255], [0, 255, 204], - [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], - [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], - [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], - [102, 255, 0], [92, 0, 255]]) + palette=[(120, 120, 120), (180, 120, 120), (6, 230, 230), (80, 50, 50), + (4, 200, 3), (120, 120, 80), (140, 140, 140), (204, 5, 255), + (230, 230, 230), (4, 250, 7), (224, 5, 255), (235, 255, 7), + (150, 5, 61), (120, 120, 70), (8, 255, 51), (255, 6, 82), + (143, 255, 140), (204, 255, 4), (255, 51, 7), (204, 70, 3), + (0, 102, 200), (61, 230, 250), (255, 6, 51), (11, 102, 255), + (255, 7, 71), (255, 9, 224), (9, 7, 230), (220, 220, 220), + (255, 9, 92), (112, 9, 255), (8, 255, 214), (7, 255, 224), + (255, 184, 6), (10, 255, 71), (255, 41, 10), (7, 255, 255), + (224, 255, 8), (102, 8, 255), (255, 61, 6), (255, 194, 7), + (255, 122, 8), (0, 255, 20), (255, 8, 41), (255, 5, 153), + (6, 51, 255), (235, 12, 255), (160, 150, 20), (0, 163, 255), + (140, 140, 140), (250, 10, 15), (20, 255, 0), (31, 255, 0), + (255, 31, 0), (255, 224, 0), (153, 255, 0), (0, 0, 255), + (255, 71, 0), (0, 235, 255), (0, 173, 255), (31, 0, 255), + (11, 200, 200), (255, 82, 0), (0, 255, 245), (0, 61, 255), + (0, 255, 112), (0, 255, 133), (255, 0, 0), (255, 163, 0), + (255, 102, 0), (194, 255, 0), (0, 143, 255), (51, 255, 0), + (0, 82, 255), (0, 255, 41), (0, 255, 173), (10, 0, 255), + (173, 255, 0), (0, 255, 153), (255, 92, 0), (255, 0, 255), + (255, 0, 245), (255, 0, 102), (255, 173, 0), (255, 0, 20), + (255, 184, 184), (0, 31, 255), (0, 255, 61), (0, 71, 255), + (255, 0, 204), (0, 255, 194), (0, 255, 82), (0, 10, 255), + (0, 112, 255), (51, 0, 255), (0, 194, 255), (0, 122, 255), + (0, 255, 163), (255, 153, 0), (0, 255, 10), (255, 112, 0), + (143, 255, 0), (82, 0, 255), (163, 255, 0), (255, 235, 0), + (8, 184, 170), (133, 0, 255), (0, 255, 92), (184, 0, 255), + (255, 0, 31), (0, 184, 255), (0, 214, 255), (255, 0, 112), + (92, 255, 0), (0, 224, 255), (112, 224, 255), (70, 184, 160), + (163, 0, 255), (153, 0, 255), (71, 255, 0), (255, 0, 163), + (255, 204, 0), (255, 0, 143), (0, 255, 235), (133, 255, 0), + (255, 0, 235), (245, 0, 255), (255, 0, 122), (255, 245, 0), + (10, 190, 212), (214, 255, 0), (0, 204, 255), (20, 0, 255), + (255, 255, 0), (0, 153, 255), (0, 41, 255), (0, 255, 204), + (41, 0, 255), (41, 255, 0), (173, 0, 255), (0, 245, 255), + (71, 0, 255), (122, 0, 255), (0, 255, 184), (0, 92, 255), + (184, 255, 0), (0, 133, 255), (255, 214, 0), (25, 194, 194), + (102, 255, 0), (92, 0, 255)]) def __init__(self, img_suffix='.jpg', @@ -225,7 +249,7 @@ def load_data_list(self) -> List[dict]: """Load annotation from directory or annotation file. Returns: - list[dict]: All data info of dataset. + list(dict): All data info of dataset. """ data_list = [] img_dir = self.data_prefix.get('img_path', None) @@ -242,6 +266,6 @@ def load_data_list(self) -> List[dict]: data_info['seg_map_path'] = osp.join(ann_dir, seg_map) data_info['label_map'] = self.label_map if self.return_classes: - data_info['text'] = list(self._metainfo['classes']) + data_info['text'] = list(self._metainfo('classes')) data_list.append(data_info) return data_list diff --git a/mmdet/datasets/coco_caption.py b/mmdet/datasets/coco_caption.py index e5af1ec59a6..bcc3ebdbc86 100644 --- a/mmdet/datasets/coco_caption.py +++ b/mmdet/datasets/coco_caption.py @@ -11,17 +11,7 @@ @DATASETS.register_module() class COCOCaptionDataset(BaseDataset): - """COCO Caption dataset. 
- - Args: - data_root (str): The root directory for ``data_prefix`` and - ``ann_file``.. - ann_file (str): Annotation file path. - data_prefix (dict): Prefix for data field. Defaults to - ``dict(img_path='')``. - pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. - **kwargs: Other keyword arguments in :class:`BaseDataset`. - """ + """COCO2014 Caption dataset.""" def load_data_list(self) -> List[dict]: """Load data list.""" diff --git a/mmdet/datasets/coco_semantic.py b/mmdet/datasets/coco_semantic.py index aadd44cbda4..8fc26ce8aa0 100644 --- a/mmdet/datasets/coco_semantic.py +++ b/mmdet/datasets/coco_semantic.py @@ -4,12 +4,11 @@ @DATASETS.register_module() -class CocoSegDaset(ADE20KSegDataset): +class CocoSegDataset(ADE20KSegDataset): """COCO dataset. - In segmentation map annotation for COCO, 0 stands for background, which is - not included in 150 categories. The ``img_suffix`` is fixed to '.jpg', and - ``seg_map_suffix`` is fixed to '.png'. + In segmentation map annotation for COCO. The ``img_suffix`` is fixed to + '.jpg', and ``seg_map_suffix`` is fixed to '.png'. """ METAINFO = dict( diff --git a/mmdet/datasets/transforms/loading.py b/mmdet/datasets/transforms/loading.py index bf933e3b3d8..fd3c05ce2cd 100644 --- a/mmdet/datasets/transforms/loading.py +++ b/mmdet/datasets/transforms/loading.py @@ -242,6 +242,8 @@ class LoadAnnotations(MMCV_LoadAnnotations): reduce_zero_label (bool): Whether reduce all label value by 1. Usually used for datasets where 0 is background label. Defaults to False. + ignore_index (int): The label index to be ignored. + Valid only if reduce_zero_label is true. Defaults is 255. imdecode_backend (str): The image decoding backend type. The backend argument for :func:``mmcv.imfrombytes``. See :fun:``mmcv.imfrombytes`` for details. diff --git a/mmdet/evaluation/metrics/coco_caption_metric.py b/mmdet/evaluation/metrics/coco_caption_metric.py index f8821301ae9..d8c7350150f 100644 --- a/mmdet/evaluation/metrics/coco_caption_metric.py +++ b/mmdet/evaluation/metrics/coco_caption_metric.py @@ -85,7 +85,7 @@ def compute_metrics(self, results: List): eval_result_file = save_result( result=results, result_dir=temp_dir, - filename='m4-caption_pred', + filename='caption_pred', remove_duplicate='image_id', ) diff --git a/mmdet/evaluation/metrics/semseg_metric.py b/mmdet/evaluation/metrics/semseg_metric.py index f951947f883..bc760860fbf 100644 --- a/mmdet/evaluation/metrics/semseg_metric.py +++ b/mmdet/evaluation/metrics/semseg_metric.py @@ -53,14 +53,14 @@ def __init__(self, output_dir: Optional[str] = None, format_only: bool = False, backend_args: dict = None, - prefix: Optional[str] = None, - **kwargs) -> None: + prefix: Optional[str] = None) -> None: super().__init__(collect_device=collect_device, prefix=prefix) if isinstance(iou_metrics, str): iou_metrics = [iou_metrics] if not set(iou_metrics).issubset(set(['mIoU', 'mDice', 'mFscore'])): - raise KeyError(f'metrics {iou_metrics} is not supported') + raise KeyError(f'metrics {iou_metrics} is not supported. 
' + f'Only supports mIoU/mDice/mFscore.') self.metrics = iou_metrics self.beta = beta self.output_dir = output_dir @@ -86,7 +86,8 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: if not self.format_only: label = data_sample['gt_sem_seg']['sem_seg'].squeeze().to( pred_label) - ignore_index = data_sample['pred_sem_seg']['ignore_index'] + ignore_index = data_sample['pred_sem_seg'].get( + 'ignore_index', 255) self.results.append( self._compute_pred_stats(pred_label, label, num_classes, ignore_index)) @@ -153,7 +154,7 @@ def _compute_pred_stats(self, pred_label: torch.tensor, histogram on all classes. torch.Tensor: The union of prediction and ground truth histogram on all classes. - torch.Tens6or: The prediction histogram on all classes. + torch.Tensor: The prediction histogram on all classes. torch.Tensor: The ground truth histogram on all classes. """ assert pred_label.shape == label.shape diff --git a/mmdet/visualization/local_visualizer.py b/mmdet/visualization/local_visualizer.py index 45743d65994..d4cb3ee7bb9 100644 --- a/mmdet/visualization/local_visualizer.py +++ b/mmdet/visualization/local_visualizer.py @@ -123,7 +123,7 @@ def _draw_instances(self, image: np.ndarray, instances: ['InstanceData'], """ self.set_image(image) - if 'bboxes' in instances: + if 'bboxes' in instances and instances.bboxes.sum() > 0: bboxes = instances.bboxes labels = instances.labels @@ -253,17 +253,17 @@ def _draw_panoptic_seg(self, image: np.ndarray, panoptic_seg_data = panoptic_seg.sem_seg[0] + ids = np.unique(panoptic_seg_data)[::-1] + if 'label_names' in panoptic_seg: # open set panoptic segmentation classes = panoptic_seg.metainfo['label_names'] - ids = np.unique(panoptic_seg_data) - # for VOID label - ignore_index = panoptic_seg.metainfo.get('ignore_index', 255) + ignore_index = panoptic_seg.metainfo.get('ignore_index', + len(classes)) ids = ids[ids != ignore_index] else: - ids = np.unique(panoptic_seg_data)[::-1] - legal_indices = ids != num_classes # for VOID label - ids = ids[legal_indices] + # for VOID label + ids = ids[ids != num_classes] labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64) segms = (panoptic_seg_data[None] == ids[:, None, None]) diff --git a/projects/XDecoder/README.md b/projects/XDecoder/README.md index 28882e07bb6..73da624d3e7 100644 --- a/projects/XDecoder/README.md +++ b/projects/XDecoder/README.md @@ -24,9 +24,16 @@ mim install mmdet[multimodal] ## Models and results -### Semantic segmentation on ADE20K +For convenience, you can download the weights to the `mmdetection` root dir -**Prepare dataset** +```shell +wget https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_best_openseg.pt +wget https://download.openmmlab.com/mmdetection/v3.0/xdecoder/xdecoder_focalt_last_novg.pt +``` + +The above two weights are directly copied from the official website without any modification. The specific source is https://github.com/microsoft/X-Decoder + +### Semantic segmentation on ADE20K Prepare your dataset according to the [docs](https://mmsegmentation.readthedocs.io/en/latest/user_guides/2_dataset_prepare.html#ade20k). 
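
For a quick single-image sanity check before running the full ADE20K evaluation, the same config can also be driven through `DetInferencer` (the project's `demo.py` builds on the same inferencer classes). The snippet below is a minimal sketch, not part of this patch: the image path and output directory are placeholders, and it assumes the `xdecoder_focalt_best_openseg.pt` weight downloaded above sits in the `mmdetection` root dir.

```python
from mmdet.apis import DetInferencer

# Build the inferencer from the zero-shot semantic segmentation config and the
# officially released X-Decoder weight downloaded above.
inferencer = DetInferencer(
    model='projects/XDecoder/configs/xdecoder-tiny_zeroshot_semseg.py',
    weights='xdecoder_focalt_best_openseg.pt',
    device='cuda:0')

# 'demo/demo.jpg' is a placeholder image; predictions and visualizations are
# written to the chosen output directory.
inferencer('demo/demo.jpg', out_dir='outputs/', no_save_pred=False)
```

For open-vocabulary prompts on arbitrary categories, `projects/XDecoder/demo.py` (which dispatches each task to the matching inferencer) is the reference entry point.
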
@@ -38,26 +45,81 @@ Since semantic segmentation is a pixel-level task, we don't need to use a thresh ./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_semseg.py xdecoder_focalt_best_openseg.pt 8 --cfg-options model.test_cfg.use_thr_for_mc=False ``` -| Model | mIoU | Config | Download | -| :---------------------------------- | :---: | :------------------------------------------------: | :---------------------------------------------------------------------------------------------: | -| `xdecoder_focalt_best_openseg.pt`\* | 25.13 | [config](configs/xdecoder-tiny_zeroshot_semseg.py) | [model](https://huggingface.co/xdecoder/X-Decoder/resolve/main/xdecoder_focalt_best_openseg.pt) | +| Model | mIoU | Config | +| :---------------------------------- | :---: | :------------------------------------------------: | +| `xdecoder_focalt_best_openseg.pt`\* | 25.13 | [config](configs/xdecoder-tiny_zeroshot_semseg.py) | + +### Instance segmentation on ADE20K + +### Panoptic segmentation on ADE20K + +### Semantic segmentation on COCO2017 + +Prepare your dataset according to the [docs](https://mmdetection.readthedocs.io/en/latest/user_guides/dataset_prepare.html#coco). + +**Test Command** + +```shell +./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py xdecoder_focalt_last_novg.pt 8 +``` + +| Model | mIOU | Config | +| :------------------------------------------------ | :--: | :----------------------------------------------------------------: | +| `xdecoder-tiny_zeroshot_open-vocab-semseg_coco`\* | | [config](configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py) | ### Instance segmentation on COCO2017 +Prepare your dataset according to the [docs](https://mmdetection.readthedocs.io/en/latest/user_guides/dataset_prepare.html#coco). + +**Test Command** + ```shell ./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py xdecoder_focalt_last_novg.pt 8 ``` -| Model | mAP | Config | Download | -| :-------------------------------------------------- | :--: | :------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------: | -| `xdecoder-tiny_zeroshot_open-vocab-instance_coco`\* | 39.7 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py) | [model](https://huggingface.co/xdecoder/X-Decoder/resolve/main/xdecoder_focalt_best_openseg.pt) | +| Model | mAP | Config | +| :-------------------------------------------------- | :--: | :------------------------------------------------------------------: | +| `xdecoder-tiny_zeroshot_open-vocab-instance_coco`\* | 39.7 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py) | + +### Panoptic segmentation on COCO2017 + +Prepare your dataset according to the [docs](https://mmdetection.readthedocs.io/en/latest/user_guides/dataset_prepare.html#coco). 
+ +**Test Command** + +```shell +./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_coco.py xdecoder_focalt_last_novg.pt 8 +``` + +| Model | mIOU | Config | +| :-------------------------------------------------- | :--: | :------------------------------------------------------------------: | +| `xdecoder-tiny_zeroshot_open-vocab-panoptic_coco`\* | | [config](configs/xdecoder-tiny_zeroshot_open-vocab-panoptic_coco.py) | + +### Referring segmentation on RefCOCO ### Image Caption on COCO2014 +Prepare your dataset according to the [docs](https://mmdetection.readthedocs.io/en/latest/user_guides/dataset_prepare.html#coco_caption). + +Before testing, you need to install jdk 1.8, otherwise it will prompt that java does not exist during the evaluation process + +**Test Command** + ```shell ./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_caption_coco2014.py xdecoder_focalt_last_novg.pt 8 ``` -| Model | BLEU-4 | CIDER | Config | Download | -| :------------------------------------------ | :----: | :----: | :----------------------------------------------------------: | :------------------------------------------------------------------------------------------: | -| `xdecoder-tiny_zeroshot_caption_coco2014`\* | 35.14 | 116.62 | [config](configs/xdecoder-tiny_zeroshot_caption_coco2014.py) | [model](https://huggingface.co/xdecoder/X-Decoder/resolve/main/xdecoder_focalt_last_novg.pt) | +| Model | BLEU-4 | CIDER | Config | +| :------------------------------------------ | :----: | :----: | :----------------------------------------------------------: | +| `xdecoder-tiny_zeroshot_caption_coco2014`\* | 35.14 | 116.62 | [config](configs/xdecoder-tiny_zeroshot_caption_coco2014.py) | + +## Citation + +```latex +@article{zou2022xdecoder, + author = {Zou*, Xueyan and Dou*, Zi-Yi and Yang*, Jianwei and Gan, Zhe and Li, Linjie and Li, Chunyuan and Dai, Xiyang and Wang, Jianfeng and Yuan, Lu and Peng, Nanyun and Wang, Lijuan and Lee*, Yong Jae and Gao*, Jianfeng}, + title = {Generalized Decoding for Pixel, Image and Language}, + publisher = {arXiv}, + year = {2022}, +} +``` diff --git a/projects/XDecoder/configs/_base_/xdecoder-tiny_ref-semseg.py b/projects/XDecoder/configs/_base_/xdecoder-tiny_ref-semseg.py index 595affa8ce0..6101474b8e1 100644 --- a/projects/XDecoder/configs/_base_/xdecoder-tiny_ref-semseg.py +++ b/projects/XDecoder/configs/_base_/xdecoder-tiny_ref-semseg.py @@ -1,3 +1,3 @@ _base_ = 'xdecoder-tiny_open-vocab-semseg.py' -model = dict(head=dict(task='ref-semseg')) +model = dict(head=dict(task='ref-seg')) diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py index 863b01036be..0ea725a5b2b 100644 --- a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_coco.py @@ -1,6 +1,6 @@ _base_ = '_base_/xdecoder-tiny_open-vocab-semseg.py' -dataset_type = 'CocoSegDaset' +dataset_type = 'CocoSegDataset' data_root = 'data/coco/' test_pipeline = [ diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_ref-seg.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_ref-seg.py new file mode 100644 index 00000000000..4742ba353a5 --- /dev/null +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_ref-seg.py @@ -0,0 +1,3 @@ +_base_ = 'xdecoder-tiny_zeroshot_open-vocab-semseg_ade20k.py' + +model = dict(head=dict(task='ref-seg')) diff --git 
a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_ref-semseg.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_ref-semseg.py deleted file mode 100644 index a9f01c8215b..00000000000 --- a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_ref-semseg.py +++ /dev/null @@ -1,3 +0,0 @@ -_base_ = 'xdecoder-tiny_zeroshot_open-vocab-semseg.py' - -model = dict(head=dict(task='ref-semseg')) diff --git a/projects/XDecoder/demo.py b/projects/XDecoder/demo.py index 21cf957d245..fb281c85f1e 100644 --- a/projects/XDecoder/demo.py +++ b/projects/XDecoder/demo.py @@ -11,7 +11,7 @@ TASKINFOS = { 'semseg': DetInferencer, - 'ref-semseg': DetInferencer, + 'ref-seg': DetInferencer, 'instance': DetInferencer, 'panoptic': DetInferencer, 'caption': ImageCaptionInferencer, diff --git a/projects/XDecoder/xdecoder/inference/image_caption.py b/projects/XDecoder/xdecoder/inference/image_caption.py index be496fb1597..ee8f859873d 100644 --- a/projects/XDecoder/xdecoder/inference/image_caption.py +++ b/projects/XDecoder/xdecoder/inference/image_caption.py @@ -252,13 +252,13 @@ def __call__( for ori_inputs, grounding_data, caption_data in track( inputs, description='Inference'): - self.model.sem_seg_head.task = 'ref-semseg' - self.model.sem_seg_head.predictor.task = 'ref-semseg' + self.model.sem_seg_head.task = 'ref-seg' + self.model.sem_seg_head.predictor.task = 'ref-seg' preds = self.forward(grounding_data, **forward_kwargs) for data_sample, pred_datasmaple in zip( caption_data['data_samples'], preds): - data_sample.pred_sem_seg = pred_datasmaple.pred_sem_seg + data_sample.pred_instances = pred_datasmaple.pred_instances data_sample.set_metainfo({ 'grounding_img_shape': pred_datasmaple.metainfo['img_shape'] diff --git a/projects/XDecoder/xdecoder/inference/texttoimage_regionretrieval_inferencer.py b/projects/XDecoder/xdecoder/inference/texttoimage_regionretrieval_inferencer.py index ec254c6cf21..0aa091bbb24 100644 --- a/projects/XDecoder/xdecoder/inference/texttoimage_regionretrieval_inferencer.py +++ b/projects/XDecoder/xdecoder/inference/texttoimage_regionretrieval_inferencer.py @@ -191,9 +191,9 @@ def __call__( batch_size=1, **preprocess_kwargs) - self.model.task = 'ref-semseg' - self.model.sem_seg_head.task = 'ref-semseg' - self.model.sem_seg_head.predictor.task = 'ref-semseg' + self.model.task = 'ref-seg' + self.model.sem_seg_head.task = 'ref-seg' + self.model.sem_seg_head.predictor.task = 'ref-seg' ori_inputs, grounding_data = next(inputs) diff --git a/projects/XDecoder/xdecoder/transformer_decoder.py b/projects/XDecoder/xdecoder/transformer_decoder.py index 06d83b97fa7..4c1165b0e6e 100644 --- a/projects/XDecoder/xdecoder/transformer_decoder.py +++ b/projects/XDecoder/xdecoder/transformer_decoder.py @@ -147,7 +147,7 @@ def forward(self, x, mask_features, extra=None): predictions_mask = [] predictions_class_embed = [] - if self.task == 'ref-semseg': + if self.task == 'ref-seg': self_tgt_mask = self.self_attn_mask[:, :self.num_queries, :self. 
num_queries].repeat( output.shape[1] * @@ -196,7 +196,7 @@ def forward(self, x, mask_features, extra=None): pos=pos[level_index], query_pos=query_embed) - if self.task == 'ref-semseg': + if self.task == 'ref-seg': output = torch.cat((output, _grounding_tokens), dim=0) query_embed = torch.cat((query_embed, grounding_tokens), dim=0) @@ -208,7 +208,7 @@ def forward(self, x, mask_features, extra=None): output = self.transformer_ffn_layers[i](output) - if self.task == 'ref-semseg': + if self.task == 'ref-seg': _grounding_tokens = output[-len(_grounding_tokens):] output = output[:-len(_grounding_tokens)] query_embed = query_embed[:-len(_grounding_tokens)] @@ -227,8 +227,9 @@ def forward(self, x, mask_features, extra=None): 'pred_class_embed': predictions_class_embed[-1], } - if self.task == 'ref-semseg': + if self.task == 'ref-seg': mask_pred_results = [] + outputs_class = [] for idx in range(mask_features.shape[0]): # batch size pred_gmasks = out['pred_masks'][idx, self.num_queries:2 * self.num_queries - 1] @@ -244,7 +245,9 @@ def forward(self, x, mask_features, extra=None): matched_id = out_prob.max(0)[1] mask_pred_results += [pred_gmasks[matched_id, :, :]] + outputs_class += [out_prob[matched_id, :]] out['pred_masks'] = mask_pred_results + out['pred_logits'] = outputs_class elif self.task == 'retrieval': t_emb = extra['class_emb'] temperature = self.lang_encoder.logit_scale @@ -387,7 +390,7 @@ def forward_prediction_heads(self, output, mask_features, cls_token = (sim * decoder_output[:, :self.num_queries - 1]).sum( dim=1, keepdim=True) - if self.task == 'ref-semseg': + if self.task == 'ref-seg': decoder_output = torch.cat( (decoder_output[:, :self.num_queries - 1], cls_token, decoder_output[:, self.num_queries:2 * self.num_queries - 1]), diff --git a/projects/XDecoder/xdecoder/unified_head.py b/projects/XDecoder/xdecoder/unified_head.py index 1a45a1c7c33..ec852b1d0df 100644 --- a/projects/XDecoder/xdecoder/unified_head.py +++ b/projects/XDecoder/xdecoder/unified_head.py @@ -35,7 +35,7 @@ def __init__(self, self.return_inter_mask = False if self.task == 'ref-caption': - # ref-caption = ref-semseg + caption, + # ref-caption = ref-seg + caption, # so we need to return the intermediate mask self.return_inter_mask = True @@ -83,7 +83,7 @@ def pre_process(self, batch_data_samples, device): if self.task in ['semseg', 'instance', 'panoptic']: self.predictor.lang_encoder.get_mean_embeds( all_text_prompts + ['background']) - elif self.task == 'ref-semseg': + elif self.task == 'ref-seg': token_info = self.predictor.lang_encoder.get_text_embeds( all_text_prompts, norm=False) token_emb = token_info['token_emb'] @@ -134,7 +134,7 @@ def post_process(self, predictions, batch_data_samples, all_text_prompts, batch_data_samples): data_samples.pred_caption = text - if 'pred_sem_seg' in batch_data_samples[0]: + if 'pred_instances' in batch_data_samples[0]: for img_metas, data_samples in zip(batch_img_metas, batch_data_samples): original_caption = data_samples.text.split('.') @@ -145,12 +145,20 @@ def post_process(self, predictions, batch_data_samples, all_text_prompts, width = img_metas['ori_shape'][1] image_size = img_metas['grounding_img_shape'][:2] + mask_pred_result = data_samples.pred_instances.masks.float( + ) + mask_cls_result = data_samples.pred_instances.scores.float( + ) + mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( - data_samples.pred_sem_seg.sem_seg.float(), image_size, - height, width) - pred_sem_seg = retry_if_cuda_oom(self._semantic_inference)( - None, mask_pred_result, text_prompts) 
- data_samples.pred_sem_seg = pred_sem_seg + mask_pred_result, image_size, height, width) + + pred_instances = retry_if_cuda_oom( + self._instance_inference)(mask_cls_result, + mask_pred_result, + text_prompts) + data_samples.pred_instances = pred_instances + elif self.task in ['semseg', 'instance', 'panoptic']: mask_pred_results = predictions['pred_masks'] mask_cls_results = predictions['pred_logits'] @@ -168,22 +176,15 @@ def post_process(self, predictions, batch_data_samples, all_text_prompts, align_corners=False, antialias=True) - # used for ref-caption - if self.return_inter_mask: - sem_seg = mask_pred_results[0] > 0 - pred_sem_seg = PixelData(sem_seg=sem_seg[None]) - batch_data_samples[0].pred_sem_seg = pred_sem_seg - return batch_data_samples - # for batch for mask_cls_result, \ mask_pred_result, \ img_metas, \ data_samples in zip( - mask_cls_results, - mask_pred_results, - batch_img_metas, - batch_data_samples): + mask_cls_results, + mask_pred_results, + batch_img_metas, + batch_data_samples): height = img_metas['ori_shape'][0] width = img_metas['ori_shape'][1] image_size = img_metas['img_shape'][:2] @@ -208,28 +209,33 @@ def post_process(self, predictions, batch_data_samples, all_text_prompts, all_text_prompts, num_thing_class) data_samples.pred_panoptic_seg = pred_panoptic_seg - elif self.task == 'ref-semseg': + elif self.task == 'ref-seg': mask_pred_results = predictions['pred_masks'] - for mask_pred_result, img_metas, data_samples in zip( - mask_pred_results, batch_img_metas, batch_data_samples): + mask_cls_results = predictions['pred_logits'] + results_ = zip(mask_pred_results, mask_cls_results, + batch_img_metas, batch_data_samples) + for mask_pred_result, mask_cls_result, \ + img_metas, data_samples in results_: if is_lower_torch_version(): mask_pred_result = F.interpolate( - mask_pred_result[None, ], + mask_pred_result[None], size=(batch_input_shape[-2], batch_input_shape[-1]), mode='bicubic', align_corners=False)[0] else: mask_pred_result = F.interpolate( - mask_pred_result[None, ], + mask_pred_result[None], size=(batch_input_shape[-2], batch_input_shape[-1]), mode='bicubic', align_corners=False, antialias=True)[0] if self.return_inter_mask: - sem_seg = mask_pred_result > 0 - pred_sem_seg = PixelData(sem_seg=sem_seg) - data_samples.pred_sem_seg = pred_sem_seg + mask = mask_pred_result > 0 + pred_instances = InstanceData() + pred_instances.masks = mask + pred_instances.scores = mask_cls_result + data_samples.pred_instances = pred_instances continue height = img_metas['ori_shape'][0] @@ -238,9 +244,9 @@ def post_process(self, predictions, batch_data_samples, all_text_prompts, mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( mask_pred_result, image_size, height, width) - pred_sem_seg = retry_if_cuda_oom(self._semantic_inference)( - None, mask_pred_result, all_text_prompts) - data_samples.pred_sem_seg = pred_sem_seg + pred_instances = retry_if_cuda_oom(self._instance_inference)( + mask_cls_result, mask_pred_result, all_text_prompts) + data_samples.pred_instances = pred_instances elif self.task == 'retrieval': batch_data_samples[0].pred_score = predictions['pred_logits'] return batch_data_samples @@ -248,24 +254,30 @@ def post_process(self, predictions, batch_data_samples, all_text_prompts, def _instance_inference(self, mask_cls, mask_pred, text_prompts): num_class = len(text_prompts) - scores = F.softmax(mask_cls, dim=-1)[:, :-1] + if self.task in ['ref-seg', 'caption']: + scores = F.softmax(mask_cls, dim=-1) + scores_per_image = scores.max(dim=-1)[0] + 
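+            # For 'ref-seg' (and caption grounding) the decoder has already
+            # matched one query to each text prompt, so each mask keeps its
+            # highest softmax probability as its score and the label below is
+            # simply the index of its text prompt.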
labels_per_image = torch.arange(num_class) + else: + scores = F.softmax(mask_cls, dim=-1)[:, :-1] - labels = torch.arange( - num_class, - device=scores.device).unsqueeze(0).repeat(scores.shape[0], - 1).flatten(0, 1) - scores_per_image, topk_indices = scores.flatten(0, 1).topk( - self.test_cfg.max_per_img, sorted=False) + labels = torch.arange( + num_class, + device=scores.device).unsqueeze(0).repeat(scores.shape[0], + 1).flatten(0, 1) + scores_per_image, topk_indices = scores.flatten(0, 1).topk( + self.test_cfg.get('max_per_img', 100), sorted=False) - labels_per_image = labels[topk_indices] - topk_indices = (topk_indices // num_class) - mask_pred = mask_pred[topk_indices] + labels_per_image = labels[topk_indices] + topk_indices = (topk_indices // num_class) + mask_pred = mask_pred[topk_indices] result = InstanceData() - result.masks = (mask_pred > 0).float() + mask_pred = mask_pred.sigmoid() + result.masks = (mask_pred > self.test_cfg.mask_thr).float() # calculate average mask prob - mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * + mask_scores_per_image = (mask_pred.flatten(1) * result.masks.flatten(1)).sum(1) / ( result.masks.flatten(1).sum(1) + 1e-6) result.scores = scores_per_image * mask_scores_per_image @@ -277,12 +289,9 @@ def _instance_inference(self, mask_cls, mask_pred, text_prompts): return result def _semantic_inference(self, mask_cls, mask_pred, text_prompts): - if mask_cls is None: - sem_seg = mask_pred.sigmoid() - else: - mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1] - mask_pred = mask_pred.sigmoid() - sem_seg = torch.einsum('qc,qhw->chw', mask_cls, mask_pred) + mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1] + mask_pred = mask_pred.sigmoid() + sem_seg = torch.einsum('qc,qhw->chw', mask_cls, mask_pred) if sem_seg.shape[0] == 1: # 0 is foreground, ignore_index is background diff --git a/tools/dataset_converters/prepare_coco_semantic_annos_from_panoptic_annos.py b/tools/dataset_converters/prepare_coco_semantic_annos_from_panoptic_annos.py index bf963c8725c..ac1f2dc4ae3 100644 --- a/tools/dataset_converters/prepare_coco_semantic_annos_from_panoptic_annos.py +++ b/tools/dataset_converters/prepare_coco_semantic_annos_from_panoptic_annos.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +# Modified from https://github.com/facebookresearch/Mask2Former/blob/main/datasets/prepare_coco_semantic_annos_from_panoptic_annos.py # noqa + import argparse import functools import json