Use MMEval DOTAMAP #669

Draft · wants to merge 7 commits into base: dev-1.x
Changes from 4 commits
283 changes: 141 additions & 142 deletions mmrotate/evaluation/metrics/dota_metric.py
@@ -4,24 +4,25 @@
import os.path as osp
import re
import tempfile
import warnings
import zipfile
from collections import OrderedDict, defaultdict
from typing import List, Optional, Sequence, Union

import numpy as np
import torch
from mmcv.ops import nms_quadri, nms_rotated
from mmengine.evaluator import BaseMetric
from mmengine.fileio import dump
from mmengine.logging import MMLogger
from mmengine.logging import MMLogger, print_log
from mmeval import DOTAMeanAP
from terminaltables import AsciiTable

from mmrotate.evaluation import eval_rbbox_map
from mmrotate.registry import METRICS
from mmrotate.structures.bbox import rbox2qbox


@METRICS.register_module()
class DOTAMetric(BaseMetric):
class DOTAMetric(DOTAMeanAP):
"""DOTA evaluation metric.

Note: In addition to formatting the output results to JSON like CocoMetric,
@@ -30,81 +31,115 @@ class DOTAMetric(BaseMetric):
large images, which can be found at: ``tools/data/dota/split``.

Args:
iou_thrs (float or List[float]): IoU threshold. Defaults to 0.5.
scale_ranges (List[tuple], optional): Scale ranges for evaluating
mAP. If not specified, all bounding boxes would be included in
evaluation. Defaults to None.
metric (str | list[str]): Metrics to be evaluated. Only support
'mAP' now. If is list, the first setting in the list will
be used to evaluate metric.
predict_box_type (str): Box type of model results. If the QuadriBoxes
is used, you need to specify 'qbox'. Defaults to 'rbox'.
format_only (bool): Format the output results without perform
iou_thrs (float | List[float], optional): IoU threshold. Defaults to 0.5.
scale_ranges (List[tuple], optional): Scale ranges for evaluating mAP.
num_classes (int, optional): The number of classes. If None, it will be
obtained from the 'CLASSES' field in ``self.dataset_meta``.
eval_mode (str, optional): 'area' or '11points'. 'area' means
calculating the area under the precision-recall curve; '11points' means
calculating the average precision of recalls at [0, 0.1, ..., 1].
nproc (int, optional): Processes used for computing TP and FP. If nproc
is less than or equal to 1, multiprocessing will not be used.
drop_class_ap (bool, optional): Whether to drop the class without
ground truth when calculating the average precision for each class.
dist_backend (str, optional): The name of the mmeval distributed
communication backend, you can get all the backend names through
``mmeval.core.list_all_backends()``.
predict_box_type (str, optional): Box type of model results. If the
QuadriBoxes is used, you need to specify 'qbox'. Defaults to
'rbox'.
format_only (bool, optional): Format the output results without performing
evaluation. It is useful when you want to format the result
to a specific format. Defaults to False.
outfile_prefix (str, optional): The prefix of json/zip files. It
includes the file path and the prefix of filename, e.g.,
outfile_prefix (str, optional): The prefix of json/zip files.
It includes the file path and the prefix of filename, e.g.,
"a/b/prefix". If not specified, a temp file will be created.
Defaults to None.
merge_patches (bool): Generate the full image's results by merging
patches' results.
iou_thr (float): IoU threshold of ``nms_rotated`` used in merge
patches. Defaults to 0.1.
eval_mode (str): 'area' or '11points', 'area' means calculating the
area under precision-recall curve, '11points' means calculating
the average precision of recalls at [0, 0.1, ..., 1].
The PASCAL VOC2007 defaults to use '11points', while PASCAL
VOC2012 defaults to use 'area'. Defaults to '11points'.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Defaults to None.
merge_patches (bool, optional): Generate the full image's results by
merging patches' results. Defaults to False.
iou_thr (float, optional): IoU threshold of ``nms_rotated`` used when
merging patches. Defaults to 0.1.
"""

default_prefix: Optional[str] = 'dota'

def __init__(self,
iou_thrs: Union[float, List[float]] = 0.5,
scale_ranges: Optional[List[tuple]] = None,
metric: Union[str, List[str]] = 'mAP',
num_classes: Optional[int] = None,
eval_mode: str = '11points',
nproc: int = 4,
drop_class_ap: bool = True,
dist_backend: str = 'torch_cuda',
predict_box_type: str = 'rbox',
format_only: bool = False,
outfile_prefix: Optional[str] = None,
merge_patches: bool = False,
iou_thr: float = 0.1,
eval_mode: str = '11points',
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device=collect_device, prefix=prefix)
self.iou_thrs = [iou_thrs] if isinstance(iou_thrs, float) \
else iou_thrs
assert isinstance(self.iou_thrs, list)
self.scale_ranges = scale_ranges
# voc evaluation metrics
if not isinstance(metric, str):
assert len(metric) == 1
metric = metric[0]
allowed_metrics = ['mAP']
if metric not in allowed_metrics:
raise KeyError(f"metric should be one of 'mAP', but got {metric}.")
self.metric = metric
self.predict_box_type = predict_box_type
**kwargs) -> None:
metric = kwargs.pop('metric', None)
if metric is not None:
warnings.warn('DeprecationWarning: The `metric` parameter of '
'`DOTAMetric` is deprecated, only mAP is supported')
collect_device = kwargs.pop('collect_device', None)
if collect_device is not None:
warnings.warn(
'DeprecationWarning: The `collect_device` parameter of '
'`DOTAMetric` is deprecated, use `dist_backend` instead.')

super().__init__(
iou_thrs=iou_thrs,
scale_ranges=scale_ranges,
num_classes=num_classes,
eval_mode=eval_mode,
nproc=nproc,
drop_class_ap=drop_class_ap,
classwise=True,
predict_box_type=predict_box_type,
dist_backend=dist_backend,
**kwargs)

self.format_only = format_only
if self.format_only:
assert outfile_prefix is not None, 'outfile_prefix must be not'
'None when format_only is True, otherwise the result files will'
'be saved to a temp directory which will be cleaned up at the end.'
assert outfile_prefix is not None, 'outfile_prefix must not be ' \
'None when format_only is True, otherwise the result files ' \
'will be saved to a temp directory which will be cleaned ' \
'up at the end.'

self.outfile_prefix = outfile_prefix
self.merge_patches = merge_patches
self.iou_thr = iou_thr

self.use_07_metric = True if eval_mode == '11points' else False
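
For reference, a small sketch of how the deprecation shims above behave. The constructor call is hypothetical: passing dist_backend=None to fall back to mmeval's default backend is an assumption, and a usable instance still needs dataset meta before evaluation.

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    # Legacy arguments are popped from kwargs; they only warn and no
    # longer change behaviour.
    metric = DOTAMetric(metric='mAP', collect_device='cpu',
                        dist_backend=None)

print([str(w.message) for w in caught])  # two deprecation messages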

def process(self, data_batch: Sequence[dict],
data_samples: Sequence[dict]):
"""Process one batch of data samples and predictions. The function will
call self.add() to add predictions and groundtruths to self._results.

Args:
data_batch (dict): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of data samples that
contain annotations and predictions.
"""
predictions, groundtruths = [], []
for data_sample in data_samples:
gt = copy.deepcopy(data_sample)
gt_instances = gt['gt_instances']
gt_ignore_instances = gt['ignored_instances']
ann = dict(
labels=gt_instances['labels'].cpu().numpy(),
bboxes=gt_instances['bboxes'].cpu().numpy(),
bboxes_ignore=gt_ignore_instances['bboxes'].cpu().numpy(),
labels_ignore=gt_ignore_instances['labels'].cpu().numpy())
groundtruths.append(ann)

pred = data_sample['pred_instances']
# used for merge patches
pred['img_id'] = data_sample['img_id']
pred['bboxes'] = pred['bboxes'].cpu().numpy()
pred['scores'] = pred['scores'].cpu().numpy()
pred['labels'] = pred['labels'].cpu().numpy()
predictions.append(pred)
self.add(predictions, groundtruths)
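
As a sanity check, here is a sketch of the per-sample layout that `process` consumes. The shapes, the 5-column (cx, cy, w, h, theta) rbox convention, and the 15-class label range are assumptions based on the default `predict_box_type='rbox'` and DOTA-v1.0:

import torch

data_sample = dict(
    img_id='P0006__1024__0___0',          # hypothetical patch id
    gt_instances=dict(
        labels=torch.tensor([0]),
        bboxes=torch.rand(1, 5)),          # one ground-truth rbox
    ignored_instances=dict(
        labels=torch.zeros(0, dtype=torch.long),
        bboxes=torch.zeros((0, 5))),       # no ignored boxes
    pred_instances=dict(
        bboxes=torch.rand(3, 5),
        scores=torch.rand(3),
        labels=torch.randint(0, 15, (3,))))

# Tensors are moved to numpy and queued via self.add().
metric.process(data_batch={}, data_samples=[data_sample])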

def merge_results(self, results: Sequence[dict],
outfile_prefix: str) -> str:
"""Merge patches' predictions into full image's results and generate a
@@ -121,7 +156,6 @@ def merge_results(self, results: Sequence[dict],
"somepath/xxx/xxx.zip".
"""
collector = defaultdict(list)

for idx, result in enumerate(results):
img_id = result.get('img_id', idx)
splitname = img_id.split('__')
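
The split on '__' relies on the naming scheme of the DOTA split tools; a hypothetical patch id illustrating the convention this loop assumes:

img_id = 'P0006__1024__0___824'    # hypothetical id from tools/data/dota/split
splitname = img_id.split('__')      # ['P0006', '1024', '0___824']
oriname = splitname[0]              # full image the patch came from: 'P0006'
# The remaining fields encode the patch scale and its x___y offset, which
# the merge step uses to shift detections back into full-image coordinates.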
@@ -223,8 +257,8 @@ def results2json(self, results: Sequence[dict],
Args:
results (Sequence[dict]): Testing results of the
dataset.
outfile_prefix (str): The filename prefix of the json files. If the
prefix is "somepath/xxx", the json files will be named
outfile_prefix (str): The filename prefix of the json files. If
the prefix is "somepath/xxx", the json files will be named
"somepath/xxx.bbox.json", "somepath/xxx.segm.json",
"somepath/xxx.proposal.json".

@@ -253,59 +287,9 @@ def results2json(self, results: Sequence[dict],

return result_files

def process(self, data_batch: Sequence[dict],
data_samples: Sequence[dict]) -> None:
"""Process one batch of data samples and predictions. The processed
results should be stored in ``self.results``, which will be used to
compute the metrics when all batches have been processed.

Args:
data_batch (dict): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of data samples that
contain annotations and predictions.
"""
for data_sample in data_samples:
gt = copy.deepcopy(data_sample)
gt_instances = gt['gt_instances']
gt_ignore_instances = gt['ignored_instances']
if gt_instances == {}:
ann = dict()
else:
ann = dict(
labels=gt_instances['labels'].cpu().numpy(),
bboxes=gt_instances['bboxes'].cpu().numpy(),
bboxes_ignore=gt_ignore_instances['bboxes'].cpu().numpy(),
labels_ignore=gt_ignore_instances['labels'].cpu().numpy())
result = dict()
pred = data_sample['pred_instances']
result['img_id'] = data_sample['img_id']
result['bboxes'] = pred['bboxes'].cpu().numpy()
result['scores'] = pred['scores'].cpu().numpy()
result['labels'] = pred['labels'].cpu().numpy()

result['pred_bbox_scores'] = []
for label in range(len(self.dataset_meta['CLASSES'])):
index = np.where(result['labels'] == label)[0]
pred_bbox_scores = np.hstack([
result['bboxes'][index], result['scores'][index].reshape(
(-1, 1))
])
result['pred_bbox_scores'].append(pred_bbox_scores)

self.results.append((ann, result))

def compute_metrics(self, results: list) -> dict:
"""Compute the metrics from processed results.

Args:
results (list): The processed results of each batch.
Returns:
dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""
def evaluate(self, *args, **kwargs) -> dict:
logger: MMLogger = MMLogger.get_current_instance()
gts, preds = zip(*results)

preds, _ = zip(*self._results)
tmp_dir = None
if self.outfile_prefix is None:
tmp_dir = tempfile.TemporaryDirectory()
@@ -319,35 +303,50 @@ def compute_metrics(self, results: list) -> dict:
zip_path = self.merge_results(preds, outfile_prefix)
logger.info(f'The submission file save at {zip_path}')
return eval_results
else:
# convert predictions to coco format and dump to json file
elif self.format_only:
_ = self.results2json(preds, outfile_prefix)
if self.format_only:
logger.info('results are saved in '
f'{osp.dirname(outfile_prefix)}')
return eval_results

if self.metric == 'mAP':
assert isinstance(self.iou_thrs, list)
dataset_name = self.dataset_meta['CLASSES']
dets = [pred['pred_bbox_scores'] for pred in preds]

mean_aps = []
for iou_thr in self.iou_thrs:
logger.info(f'\n{"-" * 15}iou_thr: {iou_thr}{"-" * 15}')
mean_ap, _ = eval_rbbox_map(
dets,
gts,
scale_ranges=self.scale_ranges,
iou_thr=iou_thr,
use_07_metric=self.use_07_metric,
box_type=self.predict_box_type,
dataset=dataset_name,
logger=logger)
mean_aps.append(mean_ap)
eval_results[f'AP{int(iou_thr * 100):02d}'] = round(mean_ap, 3)
eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
eval_results.move_to_end('mAP', last=False)
else:
raise NotImplementedError
return eval_results
logger.info('results are saved in '
f'{osp.dirname(outfile_prefix)}')
return eval_results

metric_results = self.compute(*args, **kwargs)
self.reset()
classwise_result = metric_results['classwise_result']
del metric_results['classwise_result']

classes = self.dataset_meta['CLASSES']
header = ['class', 'gts', 'dets', 'recall', 'ap']

for i, iou_thr in enumerate(self.iou_thrs):
for j, scale_range in enumerate(self.scale_ranges):
table_title = f' IoU thr: {iou_thr} '
if scale_range != (None, None):
table_title += f'Scale range: {scale_range} '

table_data = [header]
aps = []
for idx, _ in enumerate(classes):
class_results = classwise_result[idx]
recalls = class_results['recalls'][i, j]
recall = 0 if len(recalls) == 0 else recalls[-1]
row_data = [
classes[idx], class_results['num_gts'][i, j],
class_results['num_dets'],
round(recall, 3),
round(class_results['ap'][i, j], 3)
]
table_data.append(row_data)
if class_results['num_gts'][i, j] > 0:
aps.append(class_results['ap'][i, j])

mean_ap = np.mean(aps) if aps else 0
table_data.append(['mAP', '', '', '', f'{mean_ap:.3f}'])
table = AsciiTable(table_data, title=table_title)
table.inner_footing_row_border = True
print_log('\n' + table.table, logger='current')

evaluate_results = {
f'dota/{k}': round(float(v), 3)
for k, v in metric_results.items()
}
return evaluate_results
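
Once all batches have gone through `process`, calling `evaluate` yields the mmeval results re-keyed under the 'dota/' prefix. The exact key names come from `DOTAMeanAP`, so the keys and value shown here are assumed examples only:

results = metric.evaluate()
print(results)
# e.g. {'dota/mAP': 0.712, ...}  -- each value is a plain float rounded
# to three decimals by the comprehension above.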
2 changes: 1 addition & 1 deletion setup.cfg
@@ -3,7 +3,7 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmrotate
known_third_party = PIL,cv2,e2cnn,matplotlib,mmcv,mmdet,mmengine,numpy,parameterized,pycocotools,pytest,pytorch_sphinx_theme,terminaltables,torch,ts,yaml
known_third_party = PIL,cv2,e2cnn,matplotlib,mmcv,mmdet,mmengine,mmeval,numpy,parameterized,pycocotools,pytest,pytorch_sphinx_theme,terminaltables,torch,ts,yaml
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
