From efadf9ced8ab8fd2b028ea29ee42cb1d6b55c323 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E6=98=95=E8=BE=B0?= Date: Thu, 8 Jun 2023 11:21:35 +0800 Subject: [PATCH] Fix semseg eval (#10466) --- configs/_base_/datasets/ade20k_semseg.py | 3 ++- mmdet/datasets/transforms/loading.py | 11 +++++++++++ mmdet/evaluation/metrics/semseg_metric.py | 17 ++++++++++++----- ...er-tiny_zeroshot_open-vocab-semseg_ade20k.py | 3 ++- projects/XDecoder/xdecoder/unified_head.py | 9 +-------- 5 files changed, 28 insertions(+), 15 deletions(-) diff --git a/configs/_base_/datasets/ade20k_semseg.py b/configs/_base_/datasets/ade20k_semseg.py index 7e779369a23..0e57e357bcd 100644 --- a/configs/_base_/datasets/ade20k_semseg.py +++ b/configs/_base_/datasets/ade20k_semseg.py @@ -23,7 +23,8 @@ type='LoadAnnotations', with_bbox=False, with_mask=False, - with_seg=True), + with_seg=True, + reduce_zero_label=True), dict( type='PackDetInputs', meta_keys=('img_path', 'ori_shape', 'img_shape', 'text')) diff --git a/mmdet/datasets/transforms/loading.py b/mmdet/datasets/transforms/loading.py index ad45d8d46d0..bd3ca236e49 100644 --- a/mmdet/datasets/transforms/loading.py +++ b/mmdet/datasets/transforms/loading.py @@ -239,6 +239,9 @@ class LoadAnnotations(MMCV_LoadAnnotations): poly2mask (bool): Whether to convert mask to bitmap. Default: True. box_type (str): The box type used to wrap the bboxes. If ``box_type`` is None, gt_bboxes will keep being np.ndarray. Defaults to 'hbox'. + reduce_zero_label (bool): Whether reduce all label value + by 1. Usually used for datasets where 0 is background label. + Defaults to False. imdecode_backend (str): The image decoding backend type. The backend argument for :func:``mmcv.imfrombytes``. See :fun:``mmcv.imfrombytes`` for details. @@ -251,11 +254,13 @@ def __init__(self, with_mask: bool = False, poly2mask: bool = True, box_type: str = 'hbox', + reduce_zero_label: bool = False, **kwargs) -> None: super(LoadAnnotations, self).__init__(**kwargs) self.with_mask = with_mask self.poly2mask = poly2mask self.box_type = box_type + self.reduce_zero_label = reduce_zero_label def _load_bboxes(self, results: dict) -> None: """Private function to load bounding box annotations. @@ -399,6 +404,12 @@ def _load_seg_map(self, results: dict) -> None: img_bytes, flag='unchanged', backend=self.imdecode_backend).squeeze() + if self.reduce_zero_label: + # avoid using underflow conversion + gt_semantic_seg[gt_semantic_seg == 0] = 255 + gt_semantic_seg = gt_semantic_seg - 1 + gt_semantic_seg[gt_semantic_seg == 254] = 255 + # modify if custom classes if results.get('label_map', None) is not None: # Add deep copy to solve bug of repeatedly diff --git a/mmdet/evaluation/metrics/semseg_metric.py b/mmdet/evaluation/metrics/semseg_metric.py index b59e17444bc..144f13be010 100644 --- a/mmdet/evaluation/metrics/semseg_metric.py +++ b/mmdet/evaluation/metrics/semseg_metric.py @@ -86,8 +86,11 @@ def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: if not self.format_only: label = data_sample['gt_sem_seg']['sem_seg'].squeeze().to( pred_label) + bg_index = data_sample['pred_sem_seg']['bg_index'] self.results.append( - self._compute_pred_stats(pred_label, label, num_classes)) + self._compute_pred_stats(pred_label, label, num_classes, + bg_index)) + # format_result if self.output_dir is not None: basename = osp.splitext(osp.basename( @@ -134,7 +137,8 @@ def compute_metrics(self, results: list) -> Dict[str, float]: return metrics def _compute_pred_stats(self, pred_label: torch.tensor, - label: torch.tensor, num_classes: int): + label: torch.tensor, num_classes: int, + bg_index: int): """Parse semantic segmentation predictions. Args: @@ -153,13 +157,16 @@ def _compute_pred_stats(self, pred_label: torch.tensor, torch.Tensor: The ground truth histogram on all classes. """ assert pred_label.shape == label.shape + mask = label != bg_index + label, pred_label = label[mask], pred_label[mask] + intersect = pred_label[pred_label == label] area_intersect = torch.histc( - intersect.float(), bins=num_classes, min=0, max=num_classes-1) + intersect.float(), bins=num_classes, min=0, max=num_classes - 1) area_pred_label = torch.histc( - pred_label.float(), bins=num_classes, min=0, max=num_classes-1) + pred_label.float(), bins=num_classes, min=0, max=num_classes - 1) area_label = torch.histc( - label.float(), bins=num_classes, min=0, max=num_classes-1) + label.float(), bins=num_classes, min=0, max=num_classes - 1) area_union = area_pred_label + area_label - area_intersect result = dict( area_intersect=area_intersect, diff --git a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_ade20k.py b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_ade20k.py index 5fff76f1838..00d29d417d9 100644 --- a/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_ade20k.py +++ b/projects/XDecoder/configs/xdecoder-tiny_zeroshot_open-vocab-semseg_ade20k.py @@ -19,7 +19,8 @@ type='LoadAnnotations', with_bbox=False, with_mask=False, - with_seg=True), + with_seg=True, + reduce_zero_label=True), dict( type='PackDetInputs', meta_keys=('img_path', 'ori_shape', 'img_shape', 'text')) diff --git a/projects/XDecoder/xdecoder/unified_head.py b/projects/XDecoder/xdecoder/unified_head.py index 7a15be4994c..a8b9d91f310 100644 --- a/projects/XDecoder/xdecoder/unified_head.py +++ b/projects/XDecoder/xdecoder/unified_head.py @@ -288,26 +288,19 @@ def _semantic_inference(self, mask_cls, mask_pred, text_prompts): # 0 is foreground, bg_index is background sem_seg = (sem_seg.squeeze(0) <= self.test_cfg.mask_thr).int() sem_seg[sem_seg == 1] = self.test_cfg.get('bg_index', 255) - label_names = text_prompts # for visualization else: # 0 is foreground, bg_index is background if self.test_cfg.use_thr_for_mc: foreground_flag = sem_seg > self.test_cfg.mask_thr sem_seg = sem_seg.max(0)[1] - label_names = [ - text_prompts[id] for id in torch.unique(sem_seg) - ] sem_seg[foreground_flag.sum(0) == 0] = self.test_cfg.get( 'bg_index', 255) else: sem_seg = sem_seg.max(0)[1] - label_names = [ - text_prompts[id] for id in torch.unique(sem_seg) - ] pred_sem_seg = PixelData( sem_seg=sem_seg, metainfo={ - 'label_names': label_names, + 'label_names': text_prompts, 'bg_index': self.test_cfg.get('bg_index', 255) }) return pred_sem_seg