Skip to content

Commit

Permalink
refactor config (#10458)
Browse files Browse the repository at this point in the history
  • Loading branch information
hhaAndroid committed Jun 7, 2023
1 parent f18f02c commit a667312
Show file tree
Hide file tree
Showing 17 changed files with 190 additions and 262 deletions.
26 changes: 22 additions & 4 deletions configs/_base_/datasets/ade20k_semseg.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,32 @@
dataset_type = 'ADE20KSemsegDataset'
data_root = 'data/ade/ADEChallengeData2016'

# Example to use different file client
# Method 1: simply set the data root and let the file I/O module infer the
# backend automatically from its prefix (LMDB and Memcache are not
# supported yet)

# data_root = 's3://openmmlab/datasets/detection/coco/'

# Method 2: Use `backend_args` (named `file_client_args` in versions before 3.0.0rc6)
# backend_args = dict(
# backend='petrel',
# path_mapping=dict({
# './data/': 's3://openmmlab/datasets/detection/',
# 'data/': 's3://openmmlab/datasets/detection/'
# }))
backend_args = None

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadImageFromFile', backend_args=backend_args),
dict(type='Resize', scale=(2048, 512), keep_ratio=True),
dict(type='LoadSemSegAnnotations'),
dict(
type='LoadAnnotations',
with_bbox=False,
with_mask=False,
with_seg=True),
dict(
type='PackDetInputs',
meta_keys=('img_path', 'ori_shape', 'img_shape', 'seg_map_path', 'img',
'gt_seg_map', 'text'))
meta_keys=('img_path', 'ori_shape', 'img_shape', 'text'))
]

val_dataloader = dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,5 @@
]

val_dataloader = dict(
dataset=dict(pipeline=test_pipeline, return_caption=True))
dataset=dict(pipeline=test_pipeline, return_classes=True))
test_dataloader = val_dataloader
6 changes: 2 additions & 4 deletions mmdet/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from .loading import (FilterAnnotations, InferencerLoader, LoadAnnotations,
LoadEmptyAnnotations, LoadImageFromNDArray,
LoadMultiChannelImageFromFiles, LoadPanopticAnnotations,
LoadProposals, LoadSemSegAnnotations,
LoadTrackAnnotations)
LoadProposals, LoadTrackAnnotations)
from .transforms import (Albu, CachedMixUp, CachedMosaic, CopyPaste, CutOut,
Expand, FixScaleResize, FixShapeResize,
MinIoURandomCrop, MixUp, Mosaic, Pad,
Expand All @@ -38,6 +37,5 @@
'LoadEmptyAnnotations', 'RandomOrder', 'CachedMosaic', 'CachedMixUp',
'FixShapeResize', 'ProposalBroadcaster', 'InferencerLoader',
'LoadTrackAnnotations', 'BaseFrameSample', 'UniformRefFrameSample',
'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize',
'LoadSemSegAnnotations'
'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize'
]
94 changes: 28 additions & 66 deletions mmdet/datasets/transforms/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,34 @@ def _load_masks(self, results: dict) -> None:
gt_masks = PolygonMasks([mask for mask in gt_masks], h, w)
results['gt_masks'] = gt_masks

def _load_seg_map(self, results: dict) -> None:
    """Load the semantic segmentation map into ``results`` in place.

    The map is read from ``results['seg_map_path']`` and stored under
    ``results['gt_seg_map']``. When ``results`` carries a ``label_map``,
    every pixel whose value equals one of its keys is rewritten to the
    mapped value.

    Args:
        results (dict): Result dict of the sample being processed. Reads
            ``seg_map_path`` and, optionally, ``label_map``.

    Returns:
        None: ``results`` is updated in place. Nothing is loaded (and
        ``gt_seg_map`` is not set) when ``seg_map_path`` is missing or
        ``None``.
    """
    seg_map_path = results.get('seg_map_path')
    if seg_map_path is None:
        # No segmentation annotation for this sample: leave it untouched.
        return

    img_bytes = get(seg_map_path, backend_args=self.backend_args)
    gt_semantic_seg = mmcv.imfrombytes(
        img_bytes, flag='unchanged',
        backend=self.imdecode_backend).squeeze()

    # Remap label ids if custom classes are used.
    label_map = results.get('label_map')
    if label_map is not None:
        # Build every mask from an untouched copy so that an id written by
        # an earlier iteration is never matched (and replaced again) by a
        # later one. Reported in
        # https://github.com/open-mmlab/mmsegmentation/pull/1445/
        gt_semantic_seg_copy = gt_semantic_seg.copy()
        for old_id, new_id in label_map.items():
            gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
    results['gt_seg_map'] = gt_semantic_seg

def transform(self, results: dict) -> dict:
"""Function to load multiple types annotations.
Expand Down Expand Up @@ -600,72 +628,6 @@ def transform(self, results: dict) -> dict:
return results


@TRANSFORMS.register_module()
class LoadSemSegAnnotations(LoadAnnotations):
    """Load annotations for semantic segmentation provided by dataset.

    The annotation format is as the following:

    .. code-block:: python

        {
            # Filename of semantic segmentation ground truth file.
            'seg_map_path': 'a/b/c'
        }

    After this module, the annotation has been changed to the format below:

    .. code-block:: python

        {
            # In uint8 type.
            'gt_seg_map': np.ndarray (H, W)
        }

    Required Keys:

    - seg_map_path (str): Path of semantic segmentation ground truth file.

    Added Keys:

    - gt_seg_map (np.uint8)
    """

    def __init__(self, **kwargs) -> None:
        # Pin the parent's switches so that only the segmentation map is
        # requested; any remaining options are forwarded unchanged.
        super().__init__(
            with_bbox=False,
            with_label=False,
            with_seg=True,
            with_keypoints=False,
            **kwargs)

    def _load_seg_map(self, results: dict) -> None:
        """Load the semantic segmentation map into ``results`` in place.

        The map is read from ``results['seg_map_path']``, cast to
        ``np.uint8`` and stored under ``results['gt_seg_map']``. When
        ``results`` carries a ``label_map``, every pixel whose value equals
        one of its keys is rewritten to the mapped value.

        Args:
            results (dict): Result dict of the sample being processed. Must
                contain ``seg_map_path``; ``label_map`` is optional.

        Returns:
            None: ``results`` is updated in place.
        """
        img_bytes = get(
            results['seg_map_path'], backend_args=self.backend_args)
        gt_semantic_seg = mmcv.imfrombytes(
            img_bytes, flag='unchanged',
            backend=self.imdecode_backend).squeeze().astype(np.uint8)

        # Remap label ids if custom classes are used.
        label_map = results.get('label_map')
        if label_map is not None:
            # Build every mask from an untouched copy so that an id written
            # by an earlier iteration is never matched (and replaced again)
            # by a later one. Reported in
            # https://github.com/open-mmlab/mmsegmentation/pull/1445/
            gt_semantic_seg_copy = gt_semantic_seg.copy()
            for old_id, new_id in label_map.items():
                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
        results['gt_seg_map'] = gt_semantic_seg


@TRANSFORMS.register_module()
class LoadProposals(BaseTransform):
"""Load proposal pipeline.
Expand Down
9 changes: 3 additions & 6 deletions mmdet/evaluation/metrics/semseg_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,13 @@ def _compute_pred_stats(self, pred_label: torch.tensor,
torch.Tensor: The ground truth histogram on all classes.
"""
assert pred_label.shape == label.shape
# 0 is background
mask = label != 0
pred_label = pred_label * mask
intersect = pred_label[pred_label == label]
area_intersect = torch.histc(
intersect.float(), bins=(num_classes), min=1, max=num_classes)
intersect.float(), bins=num_classes, min=0, max=num_classes-1)
area_pred_label = torch.histc(
pred_label.float(), bins=(num_classes), min=1, max=num_classes)
pred_label.float(), bins=num_classes, min=0, max=num_classes-1)
area_label = torch.histc(
label.float(), bins=(num_classes), min=1, max=num_classes)
label.float(), bins=num_classes, min=0, max=num_classes-1)
area_union = area_pred_label + area_label - area_intersect
result = dict(
area_intersect=area_intersect,
Expand Down
27 changes: 9 additions & 18 deletions mmdet/visualization/local_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,18 +255,11 @@ def _draw_panoptic_seg(self, image: np.ndarray,

if 'label_names' in panoptic_seg:
# open set panoptic segmentation
label_names = panoptic_seg.metainfo['label_names']
classes = panoptic_seg.metainfo['label_names']
ids = np.unique(panoptic_seg_data)
if {0: 'background'} in label_names:
label_names.remove({0: 'background'})
# 0 = background
ids = ids[ids != 0]
label_names_values = []
for id in ids:
for name in label_names:
if id in name.keys():
label_names_values.append(name[id])
break
# for VOID label
bg_index = panoptic_seg.metainfo.get('bg_index', 255)
ids = ids[ids != bg_index]
else:
ids = np.unique(panoptic_seg_data)[::-1]
legal_indices = ids != num_classes # for VOID label
Expand Down Expand Up @@ -305,10 +298,7 @@ def _draw_panoptic_seg(self, image: np.ndarray,
text_colors = [text_palette[label] for label in labels]

for i, (pos, label) in enumerate(zip(positions, labels)):
if 'label_names' in panoptic_seg:
label_text = label_names_values[i]
else:
label_text = classes[label]
label_text = classes[label]

self.draw_texts(
label_text,
Expand Down Expand Up @@ -352,20 +342,21 @@ def _draw_sem_seg(self, image: np.ndarray, sem_seg: PixelData,

# 0 ~ num_class, the value 0 means background
ids = np.unique(sem_seg_data)
ids = ids[ids != 0]

if 'label_names' in sem_seg:
# open set semseg
label_names = sem_seg.metainfo['label_names']
bg_index = sem_seg.metainfo.get('bg_index', 255)
ids = ids[ids != bg_index]

labels = np.array(ids, dtype=np.int64) - 1
labels = np.array(ids, dtype=np.int64)
colors = [palette[label] for label in labels]

self.set_image(image)

# draw semantic masks
for i, (label, color) in enumerate(zip(labels, colors)):
masks = sem_seg_data == (label + 1)
masks = sem_seg_data == label
self.draw_binary_masks(masks, colors=[color], alphas=self.alpha)
if 'label_names' in sem_seg:
label_text = label_names[label]
Expand Down
36 changes: 33 additions & 3 deletions projects/XDecoder/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,18 @@ We present X-Decoder, a generalized decoding model that can predict pixel-level

![img](https://raw.githubusercontent.com/microsoft/X-Decoder/main/images/teaser_new.png)

## Installation

```shell
# if mmdet is installed from source
pip install -r requirements/multimodal.txt

# if mmdet is installed from a wheel
mim install mmdet[multimodal]
```

## How to use it?

## Models and results

### Semantic segmentation on ADE20K
Expand All @@ -22,12 +34,30 @@ Prepare your dataset according to the [docs](https://mmsegmentation.readthedocs.

Since semantic segmentation is a pixel-level task, we don't need to use a threshold to filter out low-confidence predictions. So we set `model.test_cfg.use_thr_for_mc=False` in the test command.

````shell

```shell
./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_semseg.py xdecoder_focalt_best_openseg.pt 8 --cfg-options model.test_cfg.use_thr_for_mc=False
````
```

| Model | mIoU | Config | Download |
| :---------------------------------- | :---: | :------------------------------------------------: | :---------------------------------------------------------------------------------------------: |
| `xdecoder_focalt_best_openseg.pt`\* | 25.13 | [config](configs/xdecoder-tiny_zeroshot_semseg.py) | [model](https://huggingface.co/xdecoder/X-Decoder/resolve/main/xdecoder_focalt_best_openseg.pt) |

### Instance segmentation on COCO2017

```shell
./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py xdecoder_focalt_last_novg.pt 8
```

| Model | mAP | Config | Download |
| :-------------------------------------------------- | :--: | :------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------: |
| `xdecoder-tiny_zeroshot_open-vocab-instance_coco`\* | 39.7 | [config](configs/xdecoder-tiny_zeroshot_open-vocab-instance_coco.py) | [model](https://huggingface.co/xdecoder/X-Decoder/resolve/main/xdecoder_focalt_last_novg.pt) |

### Image Caption on COCO2014

```shell
./tools/dist_test.sh projects/XDecoder/configs/xdecoder-tiny_zeroshot_caption_coco2014.py xdecoder_focalt_last_novg.pt 8
```

| Model | BLEU-4 | CIDER | Config | Download |
| :------------------------------------------ | :----: | :----: | :----------------------------------------------------------: | :------------------------------------------------------------------------------------------: |
| `xdecoder-tiny_zeroshot_caption_coco2014`\* | 35.14 | 116.62 | [config](configs/xdecoder-tiny_zeroshot_caption_coco2014.py) | [model](https://huggingface.co/xdecoder/X-Decoder/resolve/main/xdecoder_focalt_last_novg.pt) |
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_base_ = 'xdecoder-tiny_open-vocab-semseg.py'

model = dict(
head=dict(task='panoptic'), test_cfg=dict(mask_thr=0.8, overlap_thr=0.5))
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# use_thr_for_mc=True means use threshold for multi-class
# This parameter is only used in semantic segmentation task and
# referring semantic segmentation task.
test_cfg=dict(mask_thr=0.5, use_thr_for_mc=True),
test_cfg=dict(mask_thr=0.5, use_thr_for_mc=True, bg_index=255),
)

val_cfg = dict(type='ValLoop')
Expand Down
3 changes: 3 additions & 0 deletions projects/XDecoder/configs/_base_/xdecoder-tiny_ref-semseg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = 'xdecoder-tiny_open-vocab-semseg.py'

model = dict(head=dict(task='ref-semseg'))
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
_base_ = [
'_base_/xdecoder-tiny_caption.py', 'mmdet:_base_/datasets/coco_caption.py'
'_base_/xdecoder-tiny_caption.py', 'mmdet::_base_/datasets/coco_caption.py'
]

test_pipeline = [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
_base_ = [
'_base_/xdecoder-tiny_open-vocab-instance.py',
'mmdet:_base_/datasets/coco_instance.py'
'mmdet::_base_/datasets/coco_instance.py'
]

test_pipeline = [
Expand All @@ -21,5 +21,11 @@
'scale_factor', 'text'))
]

val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
val_dataloader = dict(
dataset=dict(pipeline=test_pipeline, return_classes=True))
test_dataloader = val_dataloader

val_evaluator = dict(metric='segm')
test_evaluator = val_evaluator

train_dataloader = None
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
_base_ = 'xdecoder-tiny_zeroshot_open-vocab-semseg.py'

model = dict(
head=dict(task='panoptic'), test_cfg=dict(mask_thr=0.8, overlap_thr=0.5))
_base_ = [
'_base_/xdecoder-tiny_open-vocab-panoptic.py',
'mmdet::_base_/datasets/coco_panoptic.py'
]

test_pipeline = [
dict(
Expand Down
Loading

0 comments on commit a667312

Please sign in to comment.