From f737a5391f5c78a60e83ce10da357a33f8d2833e Mon Sep 17 00:00:00 2001 From: ajlynch Date: Wed, 7 Sep 2022 11:08:32 -0400 Subject: [PATCH 1/2] mmdet smqtk plugin implementation, tests, and config files --- .../impls/detect_image_objects/mmdet_base.py | 280 ++++++++++++++++ .../mmdet_config/_base_/default_runtime.py | 27 ++ .../_base_/models/retinanet_r50_fpn.py | 60 ++++ .../mmdet_config/common/mstrain_3x_coco.py | 76 +++++ tests/data/mmdet_config/retinanet/README.md | 53 +++ .../data/mmdet_config/retinanet/metafile.yml | 312 ++++++++++++++++++ ...tinanet_r50_fpn_mstrain_640-800_3x_coco.py | 5 + .../detect_image_objects/test_mmdet_base.py | 154 +++++++++ .../test_mmdet_implementation.py | 22 ++ 9 files changed, 989 insertions(+) create mode 100644 smqtk_detection/impls/detect_image_objects/mmdet_base.py create mode 100644 tests/data/mmdet_config/_base_/default_runtime.py create mode 100644 tests/data/mmdet_config/_base_/models/retinanet_r50_fpn.py create mode 100644 tests/data/mmdet_config/common/mstrain_3x_coco.py create mode 100644 tests/data/mmdet_config/retinanet/README.md create mode 100644 tests/data/mmdet_config/retinanet/metafile.yml create mode 100644 tests/data/mmdet_config/retinanet/retinanet_r50_fpn_mstrain_640-800_3x_coco.py create mode 100644 tests/impls/detect_image_objects/test_mmdet_base.py create mode 100644 tests/impls/detect_image_objects/test_mmdet_implementation.py diff --git a/smqtk_detection/impls/detect_image_objects/mmdet_base.py b/smqtk_detection/impls/detect_image_objects/mmdet_base.py new file mode 100644 index 0000000..5dd8b23 --- /dev/null +++ b/smqtk_detection/impls/detect_image_objects/mmdet_base.py @@ -0,0 +1,280 @@ +import logging +from typing import Any, Dict, Hashable, Iterable, Iterator, List +from typing import Optional, Tuple, TypeVar, Union +from typing_extensions import Protocol, runtime_checkable +import warnings + +import numpy as np +from mmdet.apis import init_detector, inference_detector +from 
LOG = logging.getLogger(__name__)
T_co = TypeVar("T_co", covariant=True)


class MMDetectionBase(DetectImageObjects):
    """
    Plugin base wrapping the loading and application of mmdetection models on
    images to yield object detections and classifications.

    It is expected that classes will be derived from this base that concretely
    define the data augmentation appropriate to transform input imagery for
    the configured network.

    This plugin expects input image matrices to be in the dimension format
    ``[H x W]`` or ``[H x W x C]``. It is *not* required that all input
    imagery have matching shape values.

    When the input to ``detect_objects`` is an uncountable iterable, like a
    generic generator or stream source, the ``num_workers`` value should
    usually be either 0 or 1. This is due to the input iterable being copied
    for each worker, which may not result in desired behavior:

    * when the input iterable involves non-trivial operations per yield,
      these operations are duplicated for each copy of the iterable as
      traversed on each worker (e.g. each worker loading every image from
      disk even though it only operates on a subset).

    * when the input iterable yields real-time data or is otherwise **not**
      idempotent (e.g. a webcam stream), each worker's copy of the iterable
      yields different values for equivalent "indices", since what is
      returned is conditional on when ``next()`` is requested.

    :param config_path: Filesystem path to the mmdet model configuration for
        use in the model initialization.
    :param load_device: The device to load the model onto.
    :param batch_size: Optional prediction batch size. When not provided, a
        batch size of 1 is used during inference.
    :param weights_uri: Optional reference to the model weights file to use.
        If not provided, any checkpoint referenced by the mmdet configuration
        is used; mmdet otherwise initializes with random weights.
    :param model_lazy_load: If the model should be lazy-loaded on the first
        inference request (``True`` value), or if we should load the model
        up-front (``False`` value).
    :param num_workers: The number of workers to use for data loading. See
        torch ``DataLoader`` for ``num_workers`` value meanings.
    """

    def __init__(
        self,
        config_path: str,
        load_device: Union[int, str] = "cuda:0",
        batch_size: Optional[int] = None,
        weights_uri: Optional[str] = None,
        model_lazy_load: bool = True,
        num_workers: Optional[int] = None,
    ):
        self._mmdet_config_path = config_path

        # Keep the primitive (int/str) device reference so get_config() can
        # round-trip the constructor arguments verbatim.
        self._load_device_prim = load_device
        self._batch_size = batch_size
        self._weights_uri = weights_uri
        self._model_lazy_load = model_lazy_load
        self._num_workers = num_workers

        self._model_device = torch.device(load_device)
        self._model: Optional[torch.nn.Module] = None
        # Sequence of class labels; populated on model load.
        self._classes = None

        if not model_lazy_load:
            self._lazy_load_model()

    def _lazy_load_model(self) -> torch.nn.Module:
        """
        Initialize the model and set the weights, storing on the requested
        device. If the model is already initialized, simply return it. This
        method is idempotent and always returns the same model instance once
        loaded.

        :return: The loaded, evaluation-mode torch module.
        """
        if self._model is None:
            model = init_detector(
                self._mmdet_config_path,
                self._weights_uri,
                device=self._load_device_prim,
            )
            # init_detector already places the model; this re-assertion of
            # device plus eval() makes inference mode explicit.
            model.to(self._model_device).eval()
            self._model = model
            # Integer class labels, one per model output class.
            self._classes = range(model.bbox_head.num_classes)
        return self._model

    def detect_objects(
        self,
        img_iter: Iterable[np.ndarray]
    ) -> Iterable[Iterable[Tuple[AxisAlignedBoundingBox, Dict[Hashable, float]]]]:
        """
        Detect objects in each image of the input iterable, yielding one
        iterable of ``(AxisAlignedBoundingBox, {label: score})`` pairs per
        input image, in input order.
        """
        model = self._lazy_load_model()
        # BUG FIX: the original compared ``len(batch) < None`` when no batch
        # size was configured (default), raising TypeError; default to
        # single-image batches instead.
        batch_size = self._batch_size if self._batch_size is not None else 1
        batch: List[np.ndarray] = []
        with torch.no_grad():
            for img in img_iter:
                batch.append(img)
                if len(batch) < batch_size:
                    continue
                # For each per-image output, yield an iteration converting
                # outputs into the interface-defined data-types.
                for single_output in self._forward(model, batch):
                    yield self._format_detections(single_output)
                batch = []
            # BUG FIX: the original silently dropped a trailing partial
            # batch, never yielding detections for the final images.
            if batch:
                for single_output in self._forward(model, batch):
                    yield self._format_detections(single_output)

    def _forward(self, model: torch.nn.Module, batch_inputs: List[Any]) -> List[Any]:
        """
        Run a forward pass on a model given some batch inputs.

        This is a separate method to allow for potential subclasses to
        override this.

        :param model: Torch module as loaded by mmdet to perform forward
            passes with.
        :param batch_inputs: mmdet formatted batch inputs following the
            format described by [1]: list[str/ndarray] or tuple[str/ndarray].

        :return: Sequence of outputs, one per batch input, each interpretable
            by ``_format_detections`` / ``_iterate_output``.

        [1]: https://mmdetection.readthedocs.io/en/latest/api.html#mmdet.apis.inference_detector
        """
        return inference_detector(model, batch_inputs)

    def _format_detections(
        self,
        single_output: List[np.ndarray],
    ) -> List[Tuple[AxisAlignedBoundingBox, Dict[Hashable, float]]]:
        """
        Convert one image's mmdet result into interface-defined detections.

        :param single_output: mmdet per-image result: one ``[N x 5]`` array
            per class, rows being ``[x1, y1, x2, y2, score]``.

        :return: Flat list of ``(AxisAlignedBoundingBox, {label: score})``
            pairs for this image.
        """
        # Template scoring dict with a zero entry per known class.
        zero_dict: Dict[Hashable, float] = {lbl: 0. for lbl in self._classes}

        dets: List[Tuple[AxisAlignedBoundingBox, Dict[Hashable, float]]] = []
        # The outer index is the class label; inner rows are that class's
        # detected boxes.
        # BUG FIX: the original broke out after the first box of each class,
        # used the *box* index as the class label, and returned per-class
        # nested lists instead of the flat per-image iterable declared by
        # ``detect_objects``.
        for cls_i, cls_bboxes in enumerate(single_output):
            cls_label = self._classes[cls_i]
            for bbox in cls_bboxes:
                aabb = AxisAlignedBoundingBox(
                    [bbox[0], bbox[1]], [bbox[2], bbox[3]]
                )
                class_dict = zero_dict.copy()
                class_dict[cls_label] = float(bbox[4])
                dets.append((aabb, class_dict))
        return dets

    def _iterate_output(
        self, single_output: List[Any]
    ) -> Iterable[Tuple[AxisAlignedBoundingBox, Dict[Hashable, float]]]:
        """
        Given the model's output for a single image's input, yield out a
        number of ``(AxisAlignedBoundingBox, dict)`` pairs representing
        detections.

        :param single_output: mmdet formatted per-image results output.
        """
        bboxes = np.vstack(single_output)
        scores = bboxes[:, -1]
        labels = np.concatenate([
            np.full(cls_bboxes.shape[0], i, dtype=np.int32)
            for i, cls_bboxes in enumerate(single_output)
        ])
        # BUG FIX: removed detectron2 leftover ``single_output['instances']``
        # which would raise TypeError (string index into a list).
        for box, cls_idx, score in zip(bboxes, labels, scores):
            yield (
                AxisAlignedBoundingBox(box[:2], box[2:-1]),
                {cls_idx: float(score)}
            )

    def get_config(self) -> Dict[str, Any]:
        """Return the JSON-compliant constructor arguments of this instance."""
        return {
            "mmdet_config": self._mmdet_config_path,
            "load_device": self._load_device_prim,
            "batch_size": self._batch_size,
            "weights_uri": self._weights_uri,
            "model_lazy_load": self._model_lazy_load,
            "num_workers": self._num_workers,
        }


def _trivial_batch_collator(batch: Any) -> Any:
    """
    A batch collator that does nothing.
    """
    return batch


def _to_tensor(np_image: np.ndarray) -> torch.Tensor:
    """
    Common transform to go from ``[H x W [x C]]`` numpy image matrix to a
    ``[[C x] H x W]`` torch float32 tensor image. Image pixel scale is left
    alone.
    """
    aug_image = np_image.astype(np.float32)
    if aug_image.ndim == 3:
        # HWC -> CHW, the channel layout torch models expect.
        aug_image = aug_image.transpose([2, 0, 1])
    return torch.as_tensor(aug_image)


def _aug_one_image(
    image: np.ndarray,
    aug: AutoAugment,
    gt_bboxes: Optional[List] = None,
) -> Dict[str, Union[torch.Tensor, int]]:
    """
    Common augmentation operation for mmdet inference passes.

    :param image: Image matrix to be augmented.
    :param aug: Augmentation to be performed on the input image.
    :param gt_bboxes: Optional ground-truth boxes to transform alongside the
        image. Defaults to an empty list.

    :return: mmdet input dict with the augmented image tensor and original
        image height and width attributes.
    """
    # BUG FIX: the original used a mutable default argument (``gt_bboxes=[]``)
    # which is shared across calls.
    if gt_bboxes is None:
        gt_bboxes = []

    # Pass along original image height and width for downstream resizing.
    height, width = image.shape[:2]

    # Apply aug. ``aug_input`` will contain the changed image matrix after
    # the ``aug`` call (mmdet AutoAugment mutates its input dict).
    aug_input = {'image': image, 'gt_bboxes': gt_bboxes}
    aug(aug_input)

    # Convert from numpy-common format to torch.Tensor-common format.
    aug_image = _to_tensor(aug_input['image'])

    return {
        "image": aug_image,
        "height": height,
        "width": width,
    }
+auto_scale_lr = dict(enable=False, base_batch_size=16) diff --git a/tests/data/mmdet_config/_base_/models/retinanet_r50_fpn.py b/tests/data/mmdet_config/_base_/models/retinanet_r50_fpn.py new file mode 100644 index 0000000..56e43fa --- /dev/null +++ b/tests/data/mmdet_config/_base_/models/retinanet_r50_fpn.py @@ -0,0 +1,60 @@ +# model settings +model = dict( + type='RetinaNet', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_input', + num_outs=5), + bbox_head=dict( + type='RetinaHead', + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + # model training and testing settings + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100)) diff --git a/tests/data/mmdet_config/common/mstrain_3x_coco.py b/tests/data/mmdet_config/common/mstrain_3x_coco.py new file mode 100644 index 0000000..80ec8b8 --- /dev/null +++ b/tests/data/mmdet_config/common/mstrain_3x_coco.py @@ -0,0 +1,76 @@ +_base_ = '../_base_/default_runtime.py' +# 
dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +# In mstrain 3x config, img_scale=[(1333, 640), (1333, 800)], +# multiscale_mode='range' +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='Resize', + img_scale=[(1333, 640), (1333, 800)], + multiscale_mode='range', + keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 800), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] + +# Use RepeatDataset to speed up training +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + type='RepeatDataset', + times=3, + dataset=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_train2017.json', + img_prefix=data_root + 'train2017/', + pipeline=train_pipeline)), + val=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + 'annotations/instances_val2017.json', + img_prefix=data_root + 'val2017/', + pipeline=test_pipeline)) +evaluation = dict(interval=1, metric='bbox') + +# optimizer +optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=None) + +# learning policy +# Experiments show that using step=[9, 11] has higher performance +lr_config = dict( + 
policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[9, 11]) +runner = dict(type='EpochBasedRunner', max_epochs=12) diff --git a/tests/data/mmdet_config/retinanet/README.md b/tests/data/mmdet_config/retinanet/README.md new file mode 100644 index 0000000..b9e0a2a --- /dev/null +++ b/tests/data/mmdet_config/retinanet/README.md @@ -0,0 +1,53 @@ +# RetinaNet + +> [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) + + + +## Abstract + +The highest accuracy object detectors to date are based on a two-stage approach popularized by R-CNN, where a classifier is applied to a sparse set of candidate object locations. In contrast, one-stage detectors that are applied over a regular, dense sampling of possible object locations have the potential to be faster and simpler, but have trailed the accuracy of two-stage detectors thus far. In this paper, we investigate why this is the case. We discover that the extreme foreground-background class imbalance encountered during training of dense detectors is the central cause. We propose to address this class imbalance by reshaping the standard cross entropy loss such that it down-weights the loss assigned to well-classified examples. Our novel Focal Loss focuses training on a sparse set of hard examples and prevents the vast number of easy negatives from overwhelming the detector during training. To evaluate the effectiveness of our loss, we design and train a simple dense detector we call RetinaNet. Our results show that when trained with the focal loss, RetinaNet is able to match the speed of previous one-stage detectors while surpassing the accuracy of all existing state-of-the-art two-stage detectors. + +
+ +
+ +## Results and Models + +| Backbone | Style | Lr schd | Mem (GB) | Inf time (fps) | box AP | Config | Download | +| :-------------: | :-----: | :----------: | :------: | :------------: | :----: | :-------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| R-18-FPN | pytorch | 1x | 1.7 | | 31.7 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_r18_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r18_fpn_1x_coco/retinanet_r18_fpn_1x_coco_20220407_171055-614fd399.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r18_fpn_1x_coco/retinanet_r18_fpn_1x_coco_20220407_171055.log.json) | +| R-18-FPN | pytorch | 1x(1 x 8 BS) | 5.0 | | 31.7 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_r18_fpn_1x8_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r18_fpn_1x8_1x_coco/retinanet_r18_fpn_1x8_1x_coco_20220407_171255-4ea310d7.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r18_fpn_1x8_1x_coco/retinanet_r18_fpn_1x8_1x_coco_20220407_171255.log.json) | +| R-50-FPN | caffe | 1x | 3.5 | 18.6 | 36.3 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_r50_caffe_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_caffe_fpn_1x_coco/retinanet_r50_caffe_fpn_1x_coco_20200531-f11027c5.pth) \| 
[log](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_caffe_fpn_1x_coco/retinanet_r50_caffe_fpn_1x_coco_20200531_012518.log.json) | +| R-50-FPN | pytorch | 1x | 3.8 | 19.0 | 36.5 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_r50_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130_002941.log.json) | +| R-50-FPN (FP16) | pytorch | 1x | 2.8 | 31.6 | 36.4 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_r50_fpn_fp16_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/fp16/retinanet_r50_fpn_fp16_1x_coco/retinanet_r50_fpn_fp16_1x_coco_20200702-0dbfb212.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/fp16/retinanet_r50_fpn_fp16_1x_coco/retinanet_r50_fpn_fp16_1x_coco_20200702_020127.log.json) | +| R-50-FPN | pytorch | 2x | - | - | 37.4 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_r50_fpn_2x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_2x_coco/retinanet_r50_fpn_2x_coco_20200131-fdb43119.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_2x_coco/retinanet_r50_fpn_2x_coco_20200131_114738.log.json) | +| R-101-FPN | caffe | 1x | 5.5 | 14.7 | 38.5 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_r101_caffe_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r101_caffe_fpn_1x_coco/retinanet_r101_caffe_fpn_1x_coco_20200531-b428fa0f.pth) \| 
[log](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r101_caffe_fpn_1x_coco/retinanet_r101_caffe_fpn_1x_coco_20200531_012536.log.json) | +| R-101-FPN | pytorch | 1x | 5.7 | 15.0 | 38.5 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_r101_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r101_fpn_1x_coco/retinanet_r101_fpn_1x_coco_20200130-7a93545f.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r101_fpn_1x_coco/retinanet_r101_fpn_1x_coco_20200130_003055.log.json) | +| R-101-FPN | pytorch | 2x | - | - | 38.9 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_r101_fpn_2x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r101_fpn_2x_coco/retinanet_r101_fpn_2x_coco_20200131-5560aee8.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r101_fpn_2x_coco/retinanet_r101_fpn_2x_coco_20200131_114859.log.json) | +| X-101-32x4d-FPN | pytorch | 1x | 7.0 | 12.1 | 39.9 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_x101_32x4d_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_32x4d_fpn_1x_coco/retinanet_x101_32x4d_fpn_1x_coco_20200130-5c8b7ec4.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_32x4d_fpn_1x_coco/retinanet_x101_32x4d_fpn_1x_coco_20200130_003004.log.json) | +| X-101-32x4d-FPN | pytorch | 2x | - | - | 40.1 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_x101_32x4d_fpn_2x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_32x4d_fpn_2x_coco/retinanet_x101_32x4d_fpn_2x_coco_20200131-237fc5e1.pth) \| 
[log](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_32x4d_fpn_2x_coco/retinanet_x101_32x4d_fpn_2x_coco_20200131_114812.log.json) | +| X-101-64x4d-FPN | pytorch | 1x | 10.0 | 8.7 | 41.0 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_x101_64x4d_fpn_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130_003008.log.json) | +| X-101-64x4d-FPN | pytorch | 2x | - | - | 40.8 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_x101_64x4d_fpn_2x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_2x_coco/retinanet_x101_64x4d_fpn_2x_coco_20200131-bca068ab.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_2x_coco/retinanet_x101_64x4d_fpn_2x_coco_20200131_114833.log.json) | + +## Pre-trained Models + +We also train some models with longer schedules and multi-scale training. The users could finetune them for downstream tasks. 
+ +| Backbone | Style | Lr schd | Mem (GB) | box AP | Config | Download | +| :-------------: | :-----: | :-----: | :------: | :----: | :-----------------------------------------------------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| R-50-FPN | pytorch | 3x | 3.5 | 39.5 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_r50_fpn_mstrain_640-800_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_mstrain_3x_coco/retinanet_r50_fpn_mstrain_3x_coco_20210718_220633-88476508.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_mstrain_3x_coco/retinanet_r50_fpn_mstrain_3x_coco_20210718_220633-88476508.log.json) | +| R-101-FPN | caffe | 3x | 5.4 | 40.7 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_r101_caffe_fpn_mstrain_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r101_caffe_fpn_mstrain_3x_coco/retinanet_r101_caffe_fpn_mstrain_3x_coco_20210721_063439-88a8a944.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r101_caffe_fpn_mstrain_3x_coco/retinanet_r101_caffe_fpn_mstrain_3x_coco_20210721_063439-88a8a944.log.json) | +| R-101-FPN | pytorch | 3x | 5.4 | 41 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_r101_fpn_mstrain_640-800_3x_coco.py) | 
[model](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r101_fpn_mstrain_3x_coco/retinanet_r101_fpn_mstrain_3x_coco_20210720_214650-7ee888e0.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r101_fpn_mstrain_3x_coco/retinanet_r101_fpn_mstrain_3x_coco_20210720_214650-7ee888e0.log.json) | +| X-101-64x4d-FPN | pytorch | 3x | 9.8 | 41.6 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_x101_64x4d_fpn_mstrain_640-800_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_mstrain_3x_coco/retinanet_x101_64x4d_fpn_mstrain_3x_coco_20210719_051838-022c2187.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_mstrain_3x_coco/retinanet_x101_64x4d_fpn_mstrain_3x_coco_20210719_051838-022c2187.log.json) | + +## Citation + +```latex +@inproceedings{lin2017focal, + title={Focal loss for dense object detection}, + author={Lin, Tsung-Yi and Goyal, Priya and Girshick, Ross and He, Kaiming and Doll{\'a}r, Piotr}, + booktitle={Proceedings of the IEEE international conference on computer vision}, + year={2017} +} +``` diff --git a/tests/data/mmdet_config/retinanet/metafile.yml b/tests/data/mmdet_config/retinanet/metafile.yml new file mode 100644 index 0000000..2080771 --- /dev/null +++ b/tests/data/mmdet_config/retinanet/metafile.yml @@ -0,0 +1,312 @@ +Collections: + - Name: RetinaNet + Metadata: + Training Data: COCO + Training Techniques: + - SGD with Momentum + - Weight Decay + Training Resources: 8x V100 GPUs + Architecture: + - Focal Loss + - FPN + - ResNet + Paper: + URL: https://arxiv.org/abs/1708.02002 + Title: "Focal Loss for Dense Object Detection" + README: configs/retinanet/README.md + Code: + URL: https://github.com/open-mmlab/mmdetection/blob/v2.0.0/mmdet/models/detectors/retinanet.py#L6 + Version: v2.0.0 + +Models: + - Name: retinanet_r18_fpn_1x_coco + In Collection: RetinaNet + Config: 
configs/retinanet/retinanet_r18_fpn_1x_coco.py + Metadata: + Training Memory (GB): 1.7 + Training Resources: 8x V100 GPUs + Epochs: 12 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 31.7 + Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r18_fpn_1x_coco/retinanet_r18_fpn_1x_coco_20220407_171055-614fd399.pth + + - Name: retinanet_r18_fpn_1x8_1x_coco + In Collection: RetinaNet + Config: configs/retinanet/retinanet_r18_fpn_1x8_1x_coco.py + Metadata: + Training Memory (GB): 5.0 + Training Resources: 1x V100 GPUs + Epochs: 12 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 31.7 + Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r18_fpn_1x8_1x_coco/retinanet_r18_fpn_1x8_1x_coco_20220407_171255-4ea310d7.pth + + - Name: retinanet_r50_caffe_fpn_1x_coco + In Collection: RetinaNet + Config: configs/retinanet/retinanet_r50_caffe_fpn_1x_coco.py + Metadata: + Training Memory (GB): 3.5 + inference time (ms/im): + - value: 53.76 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (800, 1333) + Epochs: 12 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 36.3 + Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_caffe_fpn_1x_coco/retinanet_r50_caffe_fpn_1x_coco_20200531-f11027c5.pth + + - Name: retinanet_r50_fpn_1x_coco + In Collection: RetinaNet + Config: configs/retinanet/retinanet_r50_fpn_1x_coco.py + Metadata: + Training Memory (GB): 3.8 + inference time (ms/im): + - value: 52.63 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (800, 1333) + Epochs: 12 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 36.5 + Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth + + - Name: retinanet_r50_fpn_fp16_1x_coco + In Collection: RetinaNet + Config: 
configs/retinanet/retinanet_r50_fpn_fp16_1x_coco.py + Metadata: + Training Memory (GB): 2.8 + Training Techniques: + - SGD with Momentum + - Weight Decay + - Mixed Precision Training + inference time (ms/im): + - value: 31.65 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP16 + resolution: (800, 1333) + Epochs: 12 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 36.4 + Weights: https://download.openmmlab.com/mmdetection/v2.0/fp16/retinanet_r50_fpn_fp16_1x_coco/retinanet_r50_fpn_fp16_1x_coco_20200702-0dbfb212.pth + + - Name: retinanet_r50_fpn_2x_coco + In Collection: RetinaNet + Config: configs/retinanet/retinanet_r50_fpn_2x_coco.py + Metadata: + Epochs: 24 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 37.4 + Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_2x_coco/retinanet_r50_fpn_2x_coco_20200131-fdb43119.pth + + - Name: retinanet_r50_fpn_mstrain_3x_coco + In Collection: RetinaNet + Config: configs/retinanet/retinanet_r50_fpn_mstrain_640-800_3x_coco.py + Metadata: + Epochs: 36 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 39.5 + Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_mstrain_3x_coco/retinanet_r50_fpn_mstrain_3x_coco_20210718_220633-88476508.pth + + - Name: retinanet_r101_caffe_fpn_1x_coco + In Collection: RetinaNet + Config: configs/retinanet/retinanet_r101_caffe_fpn_1x_coco.py + Metadata: + Training Memory (GB): 5.5 + inference time (ms/im): + - value: 68.03 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (800, 1333) + Epochs: 12 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 38.5 + Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r101_caffe_fpn_1x_coco/retinanet_r101_caffe_fpn_1x_coco_20200531-b428fa0f.pth + + - Name: retinanet_r101_caffe_fpn_mstrain_3x_coco + In Collection: RetinaNet + Config: 
configs/retinanet/retinanet_r101_caffe_fpn_1x_coco.py + Metadata: + Epochs: 36 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 40.7 + Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r101_caffe_fpn_mstrain_3x_coco/retinanet_r101_caffe_fpn_mstrain_3x_coco_20210721_063439-88a8a944.pth + + - Name: retinanet_r101_fpn_1x_coco + In Collection: RetinaNet + Config: configs/retinanet/retinanet_r101_fpn_1x_coco.py + Metadata: + Training Memory (GB): 5.7 + inference time (ms/im): + - value: 66.67 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (800, 1333) + Epochs: 12 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 38.5 + Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r101_fpn_1x_coco/retinanet_r101_fpn_1x_coco_20200130-7a93545f.pth + + - Name: retinanet_r101_fpn_2x_coco + In Collection: RetinaNet + Config: configs/retinanet/retinanet_r101_fpn_2x_coco.py + Metadata: + Training Memory (GB): 5.7 + inference time (ms/im): + - value: 66.67 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (800, 1333) + Epochs: 24 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 38.9 + Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r101_fpn_2x_coco/retinanet_r101_fpn_2x_coco_20200131-5560aee8.pth + + - Name: retinanet_r101_fpn_mstrain_3x_coco + In Collection: RetinaNet + Config: configs/retinanet/retinanet_r101_fpn_2x_coco.py + Metadata: + Epochs: 36 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 41 + Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r101_fpn_mstrain_3x_coco/retinanet_r101_fpn_mstrain_3x_coco_20210720_214650-7ee888e0.pth + + - Name: retinanet_x101_32x4d_fpn_1x_coco + In Collection: RetinaNet + Config: configs/retinanet/retinanet_x101_32x4d_fpn_1x_coco.py + Metadata: + Training Memory (GB): 7.0 + 
inference time (ms/im): + - value: 82.64 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (800, 1333) + Epochs: 12 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 39.9 + Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_32x4d_fpn_1x_coco/retinanet_x101_32x4d_fpn_1x_coco_20200130-5c8b7ec4.pth + + - Name: retinanet_x101_32x4d_fpn_2x_coco + In Collection: RetinaNet + Config: configs/retinanet/retinanet_x101_32x4d_fpn_2x_coco.py + Metadata: + Training Memory (GB): 7.0 + inference time (ms/im): + - value: 82.64 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (800, 1333) + Epochs: 24 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 40.1 + Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_32x4d_fpn_2x_coco/retinanet_x101_32x4d_fpn_2x_coco_20200131-237fc5e1.pth + + - Name: retinanet_x101_64x4d_fpn_1x_coco + In Collection: RetinaNet + Config: configs/retinanet/retinanet_x101_64x4d_fpn_1x_coco.py + Metadata: + Training Memory (GB): 10.0 + inference time (ms/im): + - value: 114.94 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (800, 1333) + Epochs: 12 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 41.0 + Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_1x_coco/retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af1.pth + + - Name: retinanet_x101_64x4d_fpn_2x_coco + In Collection: RetinaNet + Config: configs/retinanet/retinanet_x101_64x4d_fpn_2x_coco.py + Metadata: + Training Memory (GB): 10.0 + inference time (ms/im): + - value: 114.94 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (800, 1333) + Epochs: 24 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 40.8 + Weights: 
https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_2x_coco/retinanet_x101_64x4d_fpn_2x_coco_20200131-bca068ab.pth + + - Name: retinanet_x101_64x4d_fpn_mstrain_3x_coco + In Collection: RetinaNet + Config: configs/retinanet/retinanet_x101_64x4d_fpn_mstrain_640-800_3x_coco.py + Metadata: + Epochs: 36 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 41.6 + Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_mstrain_3x_coco/retinanet_x101_64x4d_fpn_mstrain_3x_coco_20210719_051838-022c2187.pth diff --git a/tests/data/mmdet_config/retinanet/retinanet_r50_fpn_mstrain_640-800_3x_coco.py b/tests/data/mmdet_config/retinanet/retinanet_r50_fpn_mstrain_640-800_3x_coco.py new file mode 100644 index 0000000..02a2c29 --- /dev/null +++ b/tests/data/mmdet_config/retinanet/retinanet_r50_fpn_mstrain_640-800_3x_coco.py @@ -0,0 +1,5 @@ +_base_ = [ + '../_base_/models/retinanet_r50_fpn.py', '../common/mstrain_3x_coco.py' +] +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) diff --git a/tests/impls/detect_image_objects/test_mmdet_base.py b/tests/impls/detect_image_objects/test_mmdet_base.py new file mode 100644 index 0000000..3f4c00f --- /dev/null +++ b/tests/impls/detect_image_objects/test_mmdet_base.py @@ -0,0 +1,154 @@ +import gc +import logging +from pathlib import Path +from typing import Any +from typing import Callable +from typing import Type +import unittest.mock as mock + +from smqtk_image_io import AxisAlignedBoundingBox +import numpy as np +import pytest +import torch +from torch.utils.data import Dataset, IterableDataset + +from mmdet.datasets.pipelines.auto_augment import AutoAugment +from mmdet.datasets import build_dataloader +from torch.utils.data import Dataset + + +# noinspection PyProtectedMember +from smqtk_detection.impls.detect_image_objects.mmdet_base import ( + _trivial_batch_collator, + _to_tensor, + _aug_one_image, + 
MMDetectionBase,
+)
+
+
+LOG = logging.getLogger(__name__)
+
+PARAM_np_dtype = pytest.mark.parametrize(
+    "np_dtype",
+    [np.uint8, np.float32, np.float64],
+    ids=lambda v: f"dtype={v.__name__}"
+)
+
+TEST_CONFIG_PATH = Path(__file__).parents[2] / "data/mmdet_config/retinanet/retinanet_r50_fpn_mstrain_640-800_3x_coco.py"
+
+
+def test_trivial_batch_collator() -> None:
+    """
+    Should literally do nothing and return what was input.
+    """
+    new_type = type("rando_new_type", (), {})
+    new_object = new_type()
+    val = _trivial_batch_collator(new_object)
+    assert val is new_object
+
+
+@PARAM_np_dtype
+def test_to_tensor_nochan(np_dtype: np.dtype) -> None:
+    """
+    Test that a tensor is appropriately output for the given ndarray with NO
+    channel dimension.
+    """
+    n = np.empty((15, 16), dtype=np_dtype)
+    t = _to_tensor(n)
+    assert t.shape == (15, 16)
+    assert t.dtype == torch.float32
+
+
+@PARAM_np_dtype
+def test_to_tensor_1chan(np_dtype: np.dtype) -> None:
+    """
+    Test that a tensor is appropriately output for the given ndarray with ONE
+    channel dimension.
+    """
+    n = np.empty((17, 18, 1), dtype=np_dtype)
+    t = _to_tensor(n)
+    assert t.shape == (1, 17, 18)
+    assert t.dtype == torch.float32
+
+
+@PARAM_np_dtype
+def test_to_tensor_3chan(np_dtype: np.dtype) -> None:
+    """
+    Test that a tensor is appropriately output for the given ndarray with THREE
+    channel dimension.
+    """
+    n = np.empty((19, 20, 3), dtype=np_dtype)
+    t = _to_tensor(n)
+    assert t.shape == (3, 19, 20)
+    assert t.dtype == torch.float32
+
+
+def test_aug_one_image() -> None:
+    """
+    Test single-image augmentation pass with mock augmentation.
+    """
+    m_aug = mock.MagicMock(spec=AutoAugment)
+    test_image = np.zeros((128, 224, 3), dtype=np.uint8)
+    test_boxes = np.zeros((1, 4), dtype=np.uint8)
+
+    ret = _aug_one_image(test_image, m_aug, test_boxes)
+
+    # With the augmentation mocked out, the pass is expected to act as a
+    # "NoOp": the output dict should reflect the input image exactly —
+    # channels moved first, and original height/width reported.
+    assert isinstance(ret, dict)
+    assert len(ret) == 3
+    assert 'image' in ret
+    assert isinstance(ret['image'], torch.Tensor)
+    assert ret['image'].shape == (3, 128, 224)
+    assert 'height' in ret
+    assert ret['height'] == 128
+    assert 'width' in ret
+    assert ret['width'] == 224
+
+class TestMMDetBase:
+
+    class StubPlugin(MMDetectionBase):
+        """ Stub implementation of abstract base class used to test base class
+        provided functionality. """
+        def _get_augmentation(self) -> AutoAugment: ...
+
+    @classmethod
+    def teardown_class(cls) -> None:
+        del cls.StubPlugin
+        # Clean-up locally defined pluggable implementation.
+        gc.collect()
+
+    @mock.patch.object(MMDetectionBase, "_lazy_load_model")
+    def test_init_lazy_load(self, _: Any) -> None:
+        """
+        Test that with lazy load on, construction does NOT attempt to load the
+        model.
+        """
+
+        inst = self.StubPlugin(TEST_CONFIG_PATH.as_posix(), model_lazy_load=True)
+
+        inst._lazy_load_model.assert_not_called()  # type: ignore
+
+    @mock.patch.object(MMDetectionBase, "_lazy_load_model")
+    def test_init_eager_load(self, _: mock.Mock) -> None:
+        """
+        Test that with lazy load off, the model attempts to initialize
+        immediately.
+        """
+        inst = self.StubPlugin(TEST_CONFIG_PATH.as_posix(), model_lazy_load=False)
+
+        inst._lazy_load_model.assert_called_once()  # type: ignore
+
+    @pytest.mark.parametrize("initially_lazy", [False, True], ids=lambda v: f"initially_lazy={v}")
+    def test_lazy_load_model_idempotent(self, initially_lazy: bool) -> None:
+        """
+        Test that the lazy loading function returns the same object on
+        successive calls.
+        """
+        inst = self.StubPlugin(TEST_CONFIG_PATH.as_posix(), model_lazy_load=initially_lazy)
+        inst_model1 = inst._lazy_load_model()
+        inst_model2 = inst._lazy_load_model()
+        assert inst_model1 is not None
+        assert inst_model1 is inst_model2
+
diff --git a/tests/impls/detect_image_objects/test_mmdet_implementation.py b/tests/impls/detect_image_objects/test_mmdet_implementation.py
new file mode 100644
index 0000000..dcbe526
--- /dev/null
+++ b/tests/impls/detect_image_objects/test_mmdet_implementation.py
@@ -0,0 +1,22 @@
+import numpy as np
+from pathlib import Path
+
+from smqtk_detection.impls.detect_image_objects.mmdet_base import MMDetectionBase
+
+TEST_CONFIG_PATH = Path(__file__).parents[2] / "data/mmdet_config/retinanet/retinanet_r50_fpn_mstrain_640-800_3x_coco.py"
+
+
+class TestMMDetectionBaseReal:
+
+    def test_smoke_random(self) -> None:
+        """
+        Smoke-test running a model on a random RGB image.
+ """ + cfg_fpath = TEST_CONFIG_PATH + inst = MMDetectionBase( + cfg_fpath.as_posix(), + batch_size=1, model_lazy_load=False, num_workers=0 + ) + + random_image = (np.random.rand(244, 244, 3) * 255).astype(np.uint8) + results = list(list(inst.detect_objects([random_image]))[0]) From 30cb555de2501bf8357a8b1954e5e46f2a4b0543 Mon Sep 17 00:00:00 2001 From: ajlynch Date: Wed, 7 Sep 2022 11:12:51 -0400 Subject: [PATCH 2/2] removed unnecessary print statements --- smqtk_detection/impls/detect_image_objects/mmdet_base.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/smqtk_detection/impls/detect_image_objects/mmdet_base.py b/smqtk_detection/impls/detect_image_objects/mmdet_base.py index 5dd8b23..8ff320d 100644 --- a/smqtk_detection/impls/detect_image_objects/mmdet_base.py +++ b/smqtk_detection/impls/detect_image_objects/mmdet_base.py @@ -126,8 +126,6 @@ def detect_objects( img_iter: Iterable[np.array] )-> Iterable[Iterable[Tuple[AxisAlignedBoundingBox, Dict[Hashable, float]]]]: model = self._lazy_load_model() - print("model", model) - print("loaded model", self._classes) batch = [] with torch.no_grad(): for batch_input in img_iter: @@ -138,7 +136,6 @@ def detect_objects( batch_output = self._forward(model, batch) # For each output, yield an iteration converting outputs into # the interface-defined data-types - print("batch out", len(batch_output)) for output_dict in batch_output: yield self._format_detections(output_dict) batch = [] @@ -170,7 +167,6 @@ def _format_detections( preds ): # Empty dict to fill - print(self._classes) zero_dict: Dict[Hashable, float] = {lbl: 0. for lbl in self._classes} # Loop over each prediction and format result @@ -181,7 +177,6 @@ def _format_detections( for i, bbox in enumerate(pred): a_bboxes.append(AxisAlignedBoundingBox( [bbox[0], bbox[1]], [bbox[2], bbox[3]])) - #print(i, len(self._classes)) class_dict = zero_dict.copy() class_dict[self._classes[i]] = bbox[4] score_dicts.append(class_dict)