Skip to content

Commit

Permalink
Add semseg eval result (#10452)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiexinch committed Jun 7, 2023
1 parent eb8e3b6 commit f18f02c
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 8 deletions.
29 changes: 29 additions & 0 deletions configs/_base_/datasets/ade20k_semseg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Base dataset settings for semantic-segmentation evaluation on the ADE20K
# validation split.
dataset_type = 'ADE20KSemsegDataset'
data_root = 'data/ade/ADEChallengeData2016'

# Test-time pipeline: load the image, resize it, load the segmentation
# annotations, then pack the sample together with the listed meta keys.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2048, 512), keep_ratio=True),
    dict(type='LoadSemSegAnnotations'),
    dict(
        type='PackDetInputs',
        meta_keys=(
            'img_path',
            'ori_shape',
            'img_shape',
            'seg_map_path',
            'img',
            'gt_seg_map',
            'text',
        ),
    ),
]

# Validation loader: one image per batch, no shuffling.
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        # Sub-directories are given relative to ``data_root``.
        data_prefix=dict(
            img_path='images/validation',
            seg_map_path='annotations/validation',
        ),
        pipeline=test_pipeline,
    ),
)
# Testing reuses the validation loader unchanged.
test_dataloader = val_dataloader

# mIoU is the only metric reported; testing shares the same evaluator.
val_evaluator = dict(type='SemSegMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
4 changes: 2 additions & 2 deletions mmdet/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ade20k import ADE20KDataset, ADE20KPanopticDataset
from .ade20k import ADE20KPanopticDataset, ADE20KSemsegDataset
from .base_det_dataset import BaseDetDataset
from .base_semseg_dataset import BaseSegDataset
from .base_video_dataset import BaseVideoDataset
Expand Down Expand Up @@ -37,5 +37,5 @@
'BaseVideoDataset', 'MOTChallengeDataset', 'TrackImgSampler',
'ReIDDataset', 'YouTubeVISDataset', 'TrackAspectRatioBatchSampler',
'ADE20KPanopticDataset', 'COCOCaptionDataset', 'RefCOCODataset',
'BaseSegDataset', 'ADE20KDataset'
'BaseSegDataset', 'ADE20KSemsegDataset'
]
2 changes: 1 addition & 1 deletion mmdet/datasets/ade20k.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class ADE20KPanopticDataset(CocoPanopticDataset):


@DATASETS.register_module()
class ADE20KDataset(BaseSegDataset):
class ADE20KSemsegDataset(BaseSegDataset):
"""ADE20K dataset.
In segmentation map annotation for ADE20K, 0 stands for background, which
Expand Down
4 changes: 3 additions & 1 deletion mmdet/datasets/base_semseg_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(self,
pipeline: List[Union[dict, Callable]] = [],
test_mode: bool = False,
lazy_init: bool = False,
use_label_map: bool = True,
max_refetch: int = 1000,
backend_args: Optional[dict] = None) -> None:

Expand All @@ -113,7 +114,8 @@ def __init__(self,

# Get label map for custom classes
new_classes = self._metainfo.get('classes', None)
self.label_map = self.get_label_map(new_classes)
self.label_map = self.get_label_map(
new_classes) if use_label_map else None
self._metainfo.update(dict(label_map=self.label_map))

# Update palette based on label map or generate palette
Expand Down
2 changes: 1 addition & 1 deletion mmdet/evaluation/metrics/semseg_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def _compute_pred_stats(self, pred_label: torch.tensor,
assert pred_label.shape == label.shape
# 0 is background
mask = label != 0
pred_label = (pred_label + 1) * mask
pred_label = pred_label * mask
intersect = pred_label[pred_label == label]
area_intersect = torch.histc(
intersect.float(), bins=(num_classes), min=1, max=num_classes)
Expand Down
33 changes: 33 additions & 0 deletions projects/XDecoder/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# X-Decoder

> [X-Decoder: Generalized Decoding for Pixel, Image, and Language](https://arxiv.org/pdf/2212.11270.pdf)
<!-- [ALGORITHM] -->

## Abstract

We present X-Decoder, a generalized decoding model that can predict pixel-level segmentation and language tokens seamlessly. X-Decoder takes as input two types of queries: (i) generic non-semantic queries and (ii) semantic queries induced from text inputs, to decode different pixel-level and token-level outputs in the same semantic space. With such a novel design, X-Decoder is the first work that provides a unified way to support all types of image segmentation and a variety of vision-language (VL) tasks. Further, our design enables seamless interactions across tasks at different granularities and brings mutual benefits by learning a common and rich pixel-level visual-semantic understanding space, without any pseudo-labeling. After pretraining on a mixed set of a limited amount of segmentation data and millions of image-text pairs, X-Decoder exhibits strong transferability to a wide range of downstream tasks in both zero-shot and finetuning settings. Notably, it achieves (1) state-of-the-art results on open-vocabulary segmentation and referring segmentation on eight datasets; (2) better or competitive finetuned performance to other generalist and specialist models on segmentation and VL tasks; and (3) flexibility for efficient finetuning and novel task composition (e.g., referring captioning and image editing).

![img](https://raw.githubusercontent.com/microsoft/X-Decoder/main/images/teaser_new.png)

## Models and results

### Semantic segmentation on ADE20K

**Prepare dataset**

Prepare your dataset according to the [docs](https://mmsegmentation.readthedocs.io/en/latest/user_guides/2_dataset_prepare.html#ade20k).

**Test Command**

Since semantic segmentation is a pixel-level task, we don't need to use a threshold to filter out low-confidence predictions. So we set `model.test_cfg.use_thr_for_mc=False` in the test command.

```shell
./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_semseg.py xdecoder_focalt_best_openseg.pt 8 --cfg-options model.test_cfg.use_thr_for_mc=False
```

| Model | mIoU | Config | Download |
| :---------------------------------- | :---: | :------------------------------------------------: | :---------------------------------------------------------------------------------------------: |
| `xdecoder_focalt_best_openseg.pt`\* | 25.13 | [config](configs/xdecoder-tiny_zeroshot_semseg.py) | [model](https://huggingface.co/xdecoder/X-Decoder/resolve/main/xdecoder_focalt_best_openseg.pt) |
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
head=dict(
type='XDecoderUnifiedhead',
in_channels=(96, 192, 384, 768),
task='semseg',
pixel_decoder=dict(type='XTransformerEncoderPixelDecoder'),
transformer_decoder=dict(type='XDecoderTransformerDecoder'),
task='semseg',
),
test_cfg=dict(mask_thr=0.5, use_thr_for_mc=True) # mc means multi-class
# use_thr_for_mc=True means use threshold for multi-class
# This parameter is only used in semantic segmentation task and
# referring semantic segmentation task.
test_cfg=dict(mask_thr=0.5, use_thr_for_mc=True),
)

val_cfg = dict(type='ValLoop')
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Zero-shot semantic-segmentation evaluation config for X-Decoder (tiny) on
# ADE20K. Inherits the default runtime, the open-vocabulary semseg model
# config and the ADE20K semseg dataset config.
_base_ = [
    'mmdet::_base_/default_runtime.py',
    '_base_/xdecoder-tiny_open-vocab-semseg.py',
    'mmdet::_base_/datasets/ade20k_semseg.py',
]

# Test-time pipeline: images are decoded and resized with the pillow backend
# (bicubic interpolation, scale=640 in short-side mode) before the
# segmentation annotations are loaded and the sample is packed.
test_pipeline = [
    dict(type='LoadImageFromFile', imdecode_backend='pillow'),
    dict(
        type='FixScaleResize',
        scale=640,
        keep_ratio=True,
        short_side_mode=True,
        backend='pillow',
        interpolation='bicubic',
    ),
    dict(type='LoadSemSegAnnotations'),
    dict(
        type='PackDetInputs',
        meta_keys=(
            'img_path',
            'ori_shape',
            'img_shape',
            'seg_map_path',
            'img',
            'gt_seg_map',
            'text',
        ),
    ),
]

# The 150 class names handed to the dataset as ``metainfo['classes']``.
x_decoder_ade20k_classes = (
    'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed',
    'window', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', 'door',
    'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 'water',
    'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 'field',
    'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', 'tub',
    'rail', 'cushion', 'base', 'box', 'column', 'signboard',
    'chest of drawers', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace',
    'refrigerator', 'grandstand', 'path', 'stairs', 'runway', 'case',
    'pool table', 'pillow', 'screen door', 'stairway', 'river', 'bridge',
    'bookcase', 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill',
    'bench', 'countertop', 'stove', 'palm', 'kitchen island', 'computer',
    'swivel chair', 'boat', 'bar', 'arcade machine', 'hovel', 'bus', 'towel',
    'light', 'truck', 'tower', 'chandelier', 'awning', 'street lamp', 'booth',
    'tv', 'airplane', 'dirt track', 'clothes', 'pole', 'land', 'bannister',
    'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
    'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
    'pool', 'stool', 'barrel', 'basket', 'falls', 'tent', 'bag', 'minibike',
    'cradle', 'oven', 'ball', 'food', 'step', 'tank', 'trade name',
    'microwave', 'pot', 'animal', 'bicycle', 'lake', 'dishwasher', 'screen',
    'blanket', 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', 'tray',
    'trash can', 'fan', 'pier', 'crt screen', 'plate', 'monitor',
    'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag',
)

# Dataset overrides on top of the inherited ADE20K loader: custom class
# names, ``return_classes=True``, label mapping switched off, and the
# pipeline defined above.
val_dataloader = dict(
    dataset=dict(
        metainfo=dict(classes=x_decoder_ade20k_classes),
        return_classes=True,
        use_label_map=False,
        pipeline=test_pipeline,
    ),
)
# Testing reuses the validation loader unchanged.
test_dataloader = val_dataloader

# mIoU is the only metric reported; testing shares the same evaluator.
val_evaluator = dict(type='SemSegMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
8 changes: 7 additions & 1 deletion projects/XDecoder/xdecoder/unified_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,13 @@ def pre_process(self, batch_data_samples, device):
all_text_prompts = []
stuff_text_prompts = []
for data_samples in batch_data_samples:
text = data_samples.text.split('.')
if isinstance(data_samples.text, str):
text = data_samples.text.split('.')
elif isinstance(data_samples.text, list):
text = data_samples.text
else:
raise TypeError(
'Type pf data_sample.text must be list or str')
text = list(filter(lambda x: len(x) > 0, text))
all_text_prompts.append(text)

Expand Down

0 comments on commit f18f02c

Please sign in to comment.