diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py index 8e099940f10..78074823d6f 100644 --- a/mmdet/datasets/__init__.py +++ b/mmdet/datasets/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .ade20k import ADE20KPanopticDataset +from .ade20k import ADE20KDataset, ADE20KPanopticDataset from .base_det_dataset import BaseDetDataset +from .base_semseg_dataset import BaseSegDataset from .base_video_dataset import BaseVideoDataset from .cityscapes import CityscapesDataset from .coco import CocoDataset @@ -35,5 +36,6 @@ 'Objects365V1Dataset', 'Objects365V2Dataset', 'DSDLDetDataset', 'BaseVideoDataset', 'MOTChallengeDataset', 'TrackImgSampler', 'ReIDDataset', 'YouTubeVISDataset', 'TrackAspectRatioBatchSampler', - 'ADE20KPanopticDataset', 'COCOCaptionDataset', 'RefCOCODataset' + 'ADE20KPanopticDataset', 'COCOCaptionDataset', 'RefCOCODataset', + 'BaseSegDataset', 'ADE20KDataset' ] diff --git a/mmdet/datasets/ade20k.py b/mmdet/datasets/ade20k.py index ac0138f97c3..4831baaf5cd 100644 --- a/mmdet/datasets/ade20k.py +++ b/mmdet/datasets/ade20k.py @@ -1,5 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +from mmengine import fileio + from mmdet.registry import DATASETS +from .base_semseg_dataset import BaseSegDataset from .coco_panoptic import CocoPanopticDataset @@ -128,3 +133,115 @@ class ADE20KPanopticDataset(CocoPanopticDataset): 0], [25, 194, 194], [102, 255, 0], [92, 0, 255]] } + + +@DATASETS.register_module() +class ADE20KDataset(BaseSegDataset): + """ADE20K dataset. + + In segmentation map annotation for ADE20K, 0 stands for background, which + is not included in 150 categories. The ``img_suffix`` is fixed to '.jpg', + and ``seg_map_suffix`` is fixed to '.png'. + """ + METAINFO = dict( + classes=('wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', + 'bed ', 'windowpane', 'grass', 'cabinet', 'sidewalk', + 'person', 'earth', 'door', 'table', 'mountain', 'plant', + 'curtain', 'chair', 'car', 'water', 'painting', 'sofa', + 'shelf', 'house', 'sea', 'mirror', 'rug', 'field', 'armchair', + 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', + 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', + 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', + 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', + 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', + 'screen door', 'stairway', 'river', 'bridge', 'bookcase', + 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', + 'bench', 'countertop', 'stove', 'palm', 'kitchen island', + 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', + 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', + 'chandelier', 'awning', 'streetlight', 'booth', + 'television receiver', 'airplane', 'dirt track', 'apparel', + 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', + 'buffet', 'poster', 'stage', 'van', 'ship', 'fountain', + 'conveyer belt', 'canopy', 'washer', 'plaything', + 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', + 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', + 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', + 'bicycle', 'lake', 'dishwasher', 'screen', 'blanket', + 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', + 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 'plate', + 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', + 'clock', 'flag'), + palette=[[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], + [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], + [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], + [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], + [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], + [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], + [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], + [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], + [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], + [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], + [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], + [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], + [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], + [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], + [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], + [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], + [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], + [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], + [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], + [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], + [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], + [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], + [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], + [102, 255, 0], [92, 0, 255]]) + + def __init__(self, + img_suffix='.jpg', + seg_map_suffix='.png', + return_classes=False, + **kwargs) -> None: + self.return_classes = return_classes + super().__init__( + img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs) + + def load_data_list(self) -> list[dict]: + """Load annotation from directory or annotation file. + + Returns: + list[dict]: All data info of dataset. + """ + data_list = [] + img_dir = self.data_prefix.get('img_path', None) + ann_dir = self.data_prefix.get('seg_map_path', None) + for img in fileio.list_dir_or_file( + dir_path=img_dir, + list_dir=False, + suffix=self.img_suffix, + recursive=True, + backend_args=self.backend_args): + data_info = dict(img_path=osp.join(img_dir, img)) + if ann_dir is not None: + seg_map = img.replace(self.img_suffix, self.seg_map_suffix) + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + data_info['label_map'] = self.label_map + data_info['seg_fields'] = [] + if self.return_classes: + data_info['text'] = list(self._metainfo['classes']) + data_list.append(data_info) + return data_list diff --git a/mmdet/datasets/base_semseg_dataset.py b/mmdet/datasets/base_semseg_dataset.py new file mode 100644 index 00000000000..e0ef56f043d --- /dev/null +++ b/mmdet/datasets/base_semseg_dataset.py @@ -0,0 +1,258 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os.path as osp +from typing import Callable, Dict, List, Optional, Sequence, Union + +import mmengine +import mmengine.fileio as fileio +import numpy as np +from mmengine.dataset import BaseDataset, Compose + +from mmdet.registry import DATASETS + + +@DATASETS.register_module() +class BaseSegDataset(BaseDataset): + """Custom dataset for semantic segmentation. An example of file structure + is as followed. + + .. code-block:: none + + ├── data + │ ├── my_dataset + │ │ ├── img_dir + │ │ │ ├── train + │ │ │ │ ├── xxx{img_suffix} + │ │ │ │ ├── yyy{img_suffix} + │ │ │ │ ├── zzz{img_suffix} + │ │ │ ├── val + │ │ ├── ann_dir + │ │ │ ├── train + │ │ │ │ ├── xxx{seg_map_suffix} + │ │ │ │ ├── yyy{seg_map_suffix} + │ │ │ │ ├── zzz{seg_map_suffix} + │ │ │ ├── val + + The img/gt_semantic_seg pair of BaseSegDataset should be of the same + except suffix. A valid img/gt_semantic_seg filename pair should be like + ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included + in the suffix). If split is given, then ``xxx`` is specified in txt file. + Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded. + Please refer to ``docs/en/tutorials/new_dataset.md`` for more details. + + + Args: + ann_file (str): Annotation file path. Defaults to ''. + metainfo (dict, optional): Meta information for dataset, such as + specify classes to load. Defaults to None. + data_root (str, optional): The root directory for ``data_prefix`` and + ``ann_file``. Defaults to None. + data_prefix (dict, optional): Prefix for training data. Defaults to + dict(img_path=None, seg_map_path=None). + img_suffix (str): Suffix of images. Default: '.jpg' + seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' + filter_cfg (dict, optional): Config for filter data. Defaults to None. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Defaults to None which means using all ``data_infos``. + serialize_data (bool, optional): Whether to hold memory using + serialized objects, when enabled, data loader workers can use + shared RAM from master process instead of making a copy. Defaults + to True. + pipeline (list, optional): Processing pipeline. Defaults to []. + test_mode (bool, optional): ``test_mode=True`` means in test phase. + Defaults to False. + lazy_init (bool, optional): Whether to load annotation during + instantiation. In some cases, such as visualization, only the meta + information of the dataset is needed, which is not necessary to + load annotation file. ``Basedataset`` can skip load annotations to + save time by set ``lazy_init=True``. Defaults to False. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Defaults to 1000. + backend_args (dict, Optional): Arguments to instantiate a file backend. + See https://mmengine.readthedocs.io/en/latest/api/fileio.htm + for details. Defaults to None. + Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required. + """ + METAINFO: dict = dict() + + def __init__(self, + ann_file: str = '', + img_suffix='.jpg', + seg_map_suffix='.png', + metainfo: Optional[dict] = None, + data_root: Optional[str] = None, + data_prefix: dict = dict(img_path='', seg_map_path=''), + filter_cfg: Optional[dict] = None, + indices: Optional[Union[int, Sequence[int]]] = None, + serialize_data: bool = True, + pipeline: List[Union[dict, Callable]] = [], + test_mode: bool = False, + lazy_init: bool = False, + max_refetch: int = 1000, + backend_args: Optional[dict] = None) -> None: + + self.img_suffix = img_suffix + self.seg_map_suffix = seg_map_suffix + self.backend_args = backend_args.copy() if backend_args else None + + self.data_root = data_root + self.data_prefix = copy.copy(data_prefix) + self.ann_file = ann_file + self.filter_cfg = copy.deepcopy(filter_cfg) + self._indices = indices + self.serialize_data = serialize_data + self.test_mode = test_mode + self.max_refetch = max_refetch + self.data_list: List[dict] = [] + self.data_bytes: np.ndarray + + # Set meta information. + self._metainfo = self._load_metainfo(copy.deepcopy(metainfo)) + + # Get label map for custom classes + new_classes = self._metainfo.get('classes', None) + self.label_map = self.get_label_map(new_classes) + self._metainfo.update(dict(label_map=self.label_map)) + + # Update palette based on label map or generate palette + # if it is not defined + updated_palette = self._update_palette() + self._metainfo.update(dict(palette=updated_palette)) + + # Join paths. + if self.data_root is not None: + self._join_prefix() + + # Build pipeline. + self.pipeline = Compose(pipeline) + # Full initialize the dataset. + if not lazy_init: + self.full_init() + + if test_mode: + assert self._metainfo.get('classes') is not None, \ + 'dataset metainfo `classes` should be specified when testing' + + @classmethod + def get_label_map(cls, + new_classes: Optional[Sequence] = None + ) -> Union[Dict, None]: + """Require label mapping. + + The ``label_map`` is a dictionary, its keys are the old label ids and + its values are the new label ids, and is used for changing pixel + labels in load_annotations. If and only if old classes in cls.METAINFO + is not equal to new classes in self._metainfo and nether of them is not + None, `label_map` is not None. + + Args: + new_classes (list, tuple, optional): The new classes name from + metainfo. Default to None. + + + Returns: + dict, optional: The mapping from old classes in cls.METAINFO to + new classes in self._metainfo + """ + old_classes = cls.METAINFO.get('classes', None) + if (new_classes is not None and old_classes is not None + and list(new_classes) != list(old_classes)): + + label_map = {} + if not set(new_classes).issubset(cls.METAINFO['classes']): + raise ValueError( + f'new classes {new_classes} is not a ' + f'subset of classes {old_classes} in METAINFO.') + for i, c in enumerate(old_classes): + if c not in new_classes: + # 0 is background + label_map[i] = 0 + else: + label_map[i] = new_classes.index(c) + return label_map + else: + return None + + def _update_palette(self) -> list: + """Update palette after loading metainfo. + + If length of palette is equal to classes, just return the palette. + If palette is not defined, it will randomly generate a palette. + If classes is updated by customer, it will return the subset of + palette. + + Returns: + Sequence: Palette for current dataset. + """ + palette = self._metainfo.get('palette', []) + classes = self._metainfo.get('classes', []) + # palette does match classes + if len(palette) == len(classes): + return palette + + if len(palette) == 0: + # Get random state before set seed, and restore + # random state later. + # It will prevent loss of randomness, as the palette + # may be different in each iteration if not specified. + # See: https://github.com/open-mmlab/mmdetection/issues/5844 + state = np.random.get_state() + np.random.seed(42) + # random palette + new_palette = np.random.randint( + 0, 255, size=(len(classes), 3)).tolist() + np.random.set_state(state) + elif len(palette) >= len(classes) and self.label_map is not None: + new_palette = [] + # return subset of palette + for old_id, new_id in sorted( + self.label_map.items(), key=lambda x: x[1]): + # 0 is background + if new_id != 0: + new_palette.append(palette[old_id]) + new_palette = type(palette)(new_palette) + else: + raise ValueError('palette does not match classes ' + f'as metainfo is {self._metainfo}.') + return new_palette + + def load_data_list(self) -> List[dict]: + """Load annotation from directory or annotation file. + + Returns: + list[dict]: All data info of dataset. + """ + data_list = [] + img_dir = self.data_prefix.get('img_path', None) + ann_dir = self.data_prefix.get('seg_map_path', None) + if not osp.isdir(self.ann_file) and self.ann_file: + assert osp.isfile(self.ann_file), \ + f'Failed to load `ann_file` {self.ann_file}' + lines = mmengine.list_from_file( + self.ann_file, backend_args=self.backend_args) + for line in lines: + img_name = line.strip() + data_info = dict( + img_path=osp.join(img_dir, img_name + self.img_suffix)) + if ann_dir is not None: + seg_map = img_name + self.seg_map_suffix + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + data_info['label_map'] = self.label_map + data_list.append(data_info) + else: + for img in fileio.list_dir_or_file( + dir_path=img_dir, + list_dir=False, + suffix=self.img_suffix, + recursive=True, + backend_args=self.backend_args): + data_info = dict(img_path=osp.join(img_dir, img)) + if ann_dir is not None: + seg_map = img.replace(self.img_suffix, self.seg_map_suffix) + data_info['seg_map_path'] = osp.join(ann_dir, seg_map) + data_info['label_map'] = self.label_map + data_list.append(data_info) + data_list = sorted(data_list, key=lambda x: x['img_path']) + return data_list diff --git a/mmdet/datasets/transforms/__init__.py b/mmdet/datasets/transforms/__init__.py index c8c40f3660c..9892f61891f 100644 --- a/mmdet/datasets/transforms/__init__.py +++ b/mmdet/datasets/transforms/__init__.py @@ -12,7 +12,8 @@ from .loading import (FilterAnnotations, InferencerLoader, LoadAnnotations, LoadEmptyAnnotations, LoadImageFromNDArray, LoadMultiChannelImageFromFiles, LoadPanopticAnnotations, - LoadProposals, LoadTrackAnnotations) + LoadProposals, LoadSemSegAnnotations, + LoadTrackAnnotations) from .transforms import (Albu, CachedMixUp, CachedMosaic, CopyPaste, CutOut, Expand, FixScaleResize, FixShapeResize, MinIoURandomCrop, MixUp, Mosaic, Pad, @@ -37,5 +38,6 @@ 'LoadEmptyAnnotations', 'RandomOrder', 'CachedMosaic', 'CachedMixUp', 'FixShapeResize', 'ProposalBroadcaster', 'InferencerLoader', 'LoadTrackAnnotations', 'BaseFrameSample', 'UniformRefFrameSample', - 'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize' + 'PackTrackInputs', 'PackReIDInputs', 'FixScaleResize', + 'LoadSemSegAnnotations' ] diff --git a/mmdet/datasets/transforms/loading.py b/mmdet/datasets/transforms/loading.py index f7ea3128d9f..c7db404f1e3 100644 --- a/mmdet/datasets/transforms/loading.py +++ b/mmdet/datasets/transforms/loading.py @@ -600,6 +600,72 @@ def transform(self, results: dict) -> dict: return results +@TRANSFORMS.register_module() +class LoadSemSegAnnotations(LoadAnnotations): + """Load annotations for semantic segmentation provided by dataset. + + The annotation format is as the following: + + .. code-block:: python + + { + # Filename of semantic segmentation ground truth file. + 'seg_map_path': 'a/b/c' + } + + After this module, the annotation has been changed to the format below: + + .. code-block:: python + + { + # In uint8 type. + 'gt_seg_map': np.ndarray (H, W) + } + + Required Keys: + + - seg_map_path (str): Path of semantic segmentation ground truth file. + + Added Keys: + + - gt_seg_map (np.uint8) + """ + + def __init__(self, **kwargs) -> None: + super().__init__( + with_bbox=False, + with_label=False, + with_seg=True, + with_keypoints=False, + **kwargs) + + def _load_seg_map(self, results: dict) -> None: + """Private function to load semantic segmentation annotations. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + + Returns: + dict: The dict contains loaded semantic segmentation annotations. + """ + + img_bytes = get( + results['seg_map_path'], backend_args=self.backend_args) + gt_semantic_seg = mmcv.imfrombytes( + img_bytes, flag='unchanged', + backend=self.imdecode_backend).squeeze().astype(np.uint8) + + # modify if custom classes + if results.get('label_map', None) is not None: + # Add deep copy to solve bug of repeatedly + # replace `gt_semantic_seg`, which is reported in + # https://github.com/open-mmlab/mmsegmentation/pull/1445/ + gt_semantic_seg_copy = gt_semantic_seg.copy() + for old_id, new_id in results['label_map'].items(): + gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id + results['gt_seg_map'] = gt_semantic_seg + + @TRANSFORMS.register_module() class LoadProposals(BaseTransform): """Load proposal pipeline. diff --git a/mmdet/evaluation/metrics/__init__.py b/mmdet/evaluation/metrics/__init__.py index 8221c87e60e..df73bb329dc 100644 --- a/mmdet/evaluation/metrics/__init__.py +++ b/mmdet/evaluation/metrics/__init__.py @@ -13,6 +13,7 @@ from .mot_challenge_metric import MOTChallengeMetric from .openimages_metric import OpenImagesMetric from .reid_metric import ReIDMetrics +from .semseg_metric import SemSegMetric from .voc_metric import VOCMetric from .youtube_vis_metric import YouTubeVISMetric @@ -21,5 +22,5 @@ 'VOCMetric', 'LVISMetric', 'CrowdHumanMetric', 'DumpProposals', 'CocoOccludedSeparatedMetric', 'DumpDetResults', 'BaseVideoMetric', 'MOTChallengeMetric', 'CocoVideoMetric', 'ReIDMetrics', 'YouTubeVISMetric', - 'COCOCaptionMetric' + 'COCOCaptionMetric', 'SemSegMetric' ] diff --git a/mmdet/evaluation/metrics/semseg_metric.py b/mmdet/evaluation/metrics/semseg_metric.py new file mode 100644 index 00000000000..6b12d4a0b0b --- /dev/null +++ b/mmdet/evaluation/metrics/semseg_metric.py @@ -0,0 +1,274 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from collections import OrderedDict +from typing import Dict, List, Optional, Sequence, Union + +import numpy as np +import torch +from mmcv import imwrite +from mmengine.dist import is_main_process +from mmengine.evaluator import BaseMetric +from mmengine.logging import MMLogger, print_log +from mmengine.utils import mkdir_or_exist +from PIL import Image + +try: + from prettytable import PrettyTable +except ImportError: + PrettyTable = None + +from mmdet.registry import METRICS + + +@METRICS.register_module() +class SemSegMetric(BaseMetric): + """mIoU evaluation metric. + + Args: + iou_metrics (list[str] | str): Metrics to be calculated, the options + includes 'mIoU', 'mDice' and 'mFscore'. + beta (int): Determines the weight of recall in the combined score. + Default: 1. + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + output_dir (str): The directory for output prediction. Defaults to + None. + format_only (bool): Only format result for results commit without + perform evaluation. It is useful when you want to save the result + to a specific format and submit it to the test server. + Defaults to False. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Defaults to None. + """ + + def __init__(self, + iou_metrics: List[str] = ['mIoU'], + beta: int = 1, + collect_device: str = 'cpu', + output_dir: Optional[str] = None, + format_only: bool = False, + backend_args: dict = None, + prefix: Optional[str] = None, + **kwargs) -> None: + super().__init__(collect_device=collect_device, prefix=prefix) + + if isinstance(iou_metrics, str): + iou_metrics = [iou_metrics] + if not set(iou_metrics).issubset(set(['mIoU', 'mDice', 'mFscore'])): + raise KeyError(f'metrics {iou_metrics} is not supported') + self.metrics = iou_metrics + self.beta = beta + self.output_dir = output_dir + if self.output_dir and is_main_process(): + mkdir_or_exist(self.output_dir) + self.format_only = format_only + self.backend_args = backend_args + + def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: + """Process one batch of data and data_samples. + + The processed results should be stored in ``self.results``, which will + be used to compute the metrics when all batches have been processed. + + Args: + data_batch (dict): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. + """ + num_classes = len(self.dataset_meta['classes']) + for data_sample in data_samples: + pred_label = data_sample['pred_sem_seg']['sem_seg'].squeeze() + # format_only always for test dataset without ground truth + if not self.format_only: + label = data_sample['gt_sem_seg']['sem_seg'].squeeze().to( + pred_label) + self.results.append( + self._compute_pred_stats(pred_label, label, num_classes)) + # format_result + if self.output_dir is not None: + basename = osp.splitext(osp.basename( + data_sample['img_path']))[0] + png_filename = osp.abspath( + osp.join(self.output_dir, f'{basename}.png')) + output_mask = pred_label.cpu().numpy() + output = Image.fromarray(output_mask.astype(np.uint8)) + imwrite(output, png_filename, backend_args=self.backend_args) + + def compute_metrics(self, results: list) -> Dict[str, float]: + """Compute the metrics from processed results. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, float]: The computed metrics. The keys are the names of + the metrics, and the values are corresponding results. The key + mainly includes aAcc, mIoU, mAcc, mDice, mFscore, mPrecision, + mRecall. + """ + logger: MMLogger = MMLogger.get_current_instance() + if self.format_only: + logger.info(f'results are saved to {osp.dirname(self.output_dir)}') + return OrderedDict() + + ret_metrics = self.get_return_metrics(results) + + # summary table + ret_metrics_summary = OrderedDict({ + ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2) + for ret_metric, ret_metric_value in ret_metrics.items() + }) + metrics = dict() + for key, val in ret_metrics_summary.items(): + if key == 'aAcc': + metrics[key] = val + else: + metrics['m' + key] = val + + print_semantic_table(ret_metrics, self.dataset_meta['classes'], logger) + + return metrics + + def _compute_pred_stats(self, pred_label: torch.tensor, + label: torch.tensor, num_classes: int): + """Parse semantic segmentation predictions. + + Args: + pred_label (torch.tensor): Prediction segmentation map + or predict result filename. The shape is (H, W). + label (torch.tensor): Ground truth segmentation map + or label filename. The shape is (H, W). + num_classes (int): Number of categories. + + Returns: + torch.Tensor: The intersection of prediction and ground truth + histogram on all classes. + torch.Tensor: The union of prediction and ground truth histogram on + all classes. + torch.Tens6or: The prediction histogram on all classes. + torch.Tensor: The ground truth histogram on all classes. + """ + assert pred_label.shape == label.shape + # 0 is background + mask = label != 0 + pred_label = (pred_label + 1) * mask + intersect = pred_label[pred_label == label] + area_intersect = torch.histc( + intersect.float(), bins=(num_classes), min=1, max=num_classes) + area_pred_label = torch.histc( + pred_label.float(), bins=(num_classes), min=1, max=num_classes) + area_label = torch.histc( + label.float(), bins=(num_classes), min=1, max=num_classes) + area_union = area_pred_label + area_label - area_intersect + result = dict( + area_intersect=area_intersect, + area_union=area_union, + area_pred_label=area_pred_label, + area_label=area_label) + return result + + def get_return_metrics(self, results: list) -> dict: + """Calculate evaluation metrics. + + Args: + results (list): The processed results of each batch. + + Returns: + Dict[str, np.ndarray]: per category evaluation metrics, + shape (num_classes, ). + """ + + def f_score(precision, recall, beta=1): + """calculate the f-score value. + + Args: + precision (float | torch.Tensor): The precision value. + recall (float | torch.Tensor): The recall value. + beta (int): Determines the weight of recall in the combined + score. Default: 1. + + Returns: + [torch.tensor]: The f-score value. + """ + score = (1 + beta**2) * (precision * recall) / ( + (beta**2 * precision) + recall) + return score + + total_area_intersect = sum([r['area_intersect'] for r in results]) + total_area_union = sum([r['area_union'] for r in results]) + total_area_pred_label = sum([r['area_pred_label'] for r in results]) + total_area_label = sum([r['area_label'] for r in results]) + + all_acc = total_area_intersect / total_area_label + ret_metrics = OrderedDict({'aAcc': all_acc}) + for metric in self.metrics: + if metric == 'mIoU': + iou = total_area_intersect / total_area_union + acc = total_area_intersect / total_area_label + ret_metrics['IoU'] = iou + ret_metrics['Acc'] = acc + elif metric == 'mDice': + dice = 2 * total_area_intersect / ( + total_area_pred_label + total_area_label) + acc = total_area_intersect / total_area_label + ret_metrics['Dice'] = dice + ret_metrics['Acc'] = acc + elif metric == 'mFscore': + precision = total_area_intersect / total_area_pred_label + recall = total_area_intersect / total_area_label + f_value = torch.tensor([ + f_score(x[0], x[1], self.beta) + for x in zip(precision, recall) + ]) + ret_metrics['Fscore'] = f_value + ret_metrics['Precision'] = precision + ret_metrics['Recall'] = recall + + ret_metrics = { + metric: value.cpu().numpy() + for metric, value in ret_metrics.items() + } + + return ret_metrics + + +def print_semantic_table( + results: dict, + class_names: list, + logger: Optional[Union['MMLogger', str]] = None) -> None: + """Print semantic segmentation evaluation results table. + + Args: + results (dict): The evaluation results. + class_names (list): Class names. + logger (MMLogger | str, optional): Logger used for printing. + Default: None. + """ + # each class table + results.pop('aAcc', None) + ret_metrics_class = OrderedDict({ + ret_metric: np.round(ret_metric_value * 100, 2) + for ret_metric, ret_metric_value in results.items() + }) + + print_log('per class results:', logger) + if PrettyTable: + class_table_data = PrettyTable() + ret_metrics_class.update({'Class': class_names}) + ret_metrics_class.move_to_end('Class', last=False) + for key, val in ret_metrics_class.items(): + class_table_data.add_column(key, val) + print_log('\n' + class_table_data.get_string(), logger=logger) + else: + logger.warning( + '`prettytable` is not installed, for better table format, ' + 'please consider installing it with "pip install prettytable"') + print_result = {} + for class_name, iou, acc in zip(class_names, ret_metrics_class['IoU'], + ret_metrics_class['Acc']): + print_result[class_name] = {'IoU': iou, 'Acc': acc} + print_log(print_result, logger) diff --git a/requirements/tests.txt b/requirements/tests.txt index b382c031e66..6de5e44f508 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -14,6 +14,7 @@ nltk onnx==1.7.0 onnxruntime>=1.8.0 parameterized +prettytable protobuf<=3.20.1 psutil pytest