From 904f89f0fc9302334fa88ec459070075989cc2ca Mon Sep 17 00:00:00 2001
From: Ian Coccimiglio
Date: Wed, 24 Apr 2024 20:45:52 -0700
Subject: [PATCH 1/3] Added cuda and detections count for Rcnn

---
 .../minimal_detection/prompt_generator.py   | 5 ++++-
 src/napari_segment_everything/sam_helper.py | 5 ++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/napari_segment_everything/minimal_detection/prompt_generator.py b/src/napari_segment_everything/minimal_detection/prompt_generator.py
index c8859d7..efae597 100644
--- a/src/napari_segment_everything/minimal_detection/prompt_generator.py
+++ b/src/napari_segment_everything/minimal_detection/prompt_generator.py
@@ -7,6 +7,7 @@
 """
 import cv2
 import os, sys
+from skimage.transform import resize
 from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn
 from torchvision.transforms import ToTensor
 from torchvision.ops import nms
@@ -110,7 +111,9 @@ def __init__(self, model_path, device, trainable=True):
         super().__init__(model_path, trainable)
         self.model_type = "FasterRCNN"
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model = fasterrcnn_mobilenet_v3_large_fpn().to(device)
+        self.model = fasterrcnn_mobilenet_v3_large_fpn(
+            box_detections_per_img=500,
+        ).to(device)
         self.model.load_state_dict(torch.load(model_path))
 
     def train(self, training_data):
diff --git a/src/napari_segment_everything/sam_helper.py b/src/napari_segment_everything/sam_helper.py
index 6efe185..9f8a9c5 100644
--- a/src/napari_segment_everything/sam_helper.py
+++ b/src/napari_segment_everything/sam_helper.py
@@ -160,7 +160,7 @@ def get_bounding_boxes(
     conf=0.4,
     iou=0.5,
     imgsz=1024,
-    max_det=400,
+    max_det=2000,
 ):
     if detector_model == "YOLOv8":
         model = YoloDetector(
@@ -171,8 +171,7 @@
         )
     elif detector_model == "Finetuned":
         model = RcnnDetector(
-            str(get_weights_path("ObjectAwareModel_Cell_FT")),
-            device="cuda",
+            str(get_weights_path("ObjectAwareModel_Cell_FT")), device="cuda"
         )
     bounding_boxes = model.get_bounding_boxes(image, conf=conf, iou=iou)
     print(bounding_boxes)

From 33e26a0bba10d1b96134315f76a66c5746a9b1af Mon Sep 17 00:00:00 2001
From: Ian Coccimiglio
Date: Thu, 25 Apr 2024 08:53:19 -0700
Subject: [PATCH 2/3] Added tests

---
 .../_tests/test_mobile_sam.py               | 129 +++++++++++++++++-
 .../minimal_detection/prompt_generator.py   |   2 +-
 src/napari_segment_everything/sam_helper.py |   7 +-
 3 files changed, 129 insertions(+), 9 deletions(-)

diff --git a/src/napari_segment_everything/_tests/test_mobile_sam.py b/src/napari_segment_everything/_tests/test_mobile_sam.py
index bb6597f..395c211 100644
--- a/src/napari_segment_everything/_tests/test_mobile_sam.py
+++ b/src/napari_segment_everything/_tests/test_mobile_sam.py
@@ -3,10 +3,38 @@
 from napari_segment_everything.sam_helper import (
     get_mobileSAMv2,
     get_bounding_boxes,
+    add_properties_to_label_image,
+    SAM_WEIGHTS_URL,
+    get_weights_path,
+    get_device,
+    get_sam_automatic_mask_generator,
 )
+from napari_segment_everything.minimal_detection.prompt_generator import (
+    RcnnDetector,
+    YoloDetector,
+)
+import os
+import requests
+from gdown.parse_url import parse_url
+
+
+def test_urls():
+    """
+    Tests whether all the urls for the model weights exist.
+    """
+    for url in SAM_WEIGHTS_URL.values():
+        if url.startswith("https://drive.google.com/"):
+            _, path_exists = parse_url(url)
+            assert path_exists
+        else:
+            req = requests.head(url)
+            assert req.status_code == 200
 
 
 def test_mobile_sam():
+    """
+    Tests the mobileSAMv2 process pipeline
+    """
     # load a color examp
     image = data.coffee()
 
@@ -24,15 +52,106 @@
 
 
 def test_bbox():
-    # load a color examp
+    """
+    Test whether bboxes can be generated
+    """
     image = data.coffee()
 
     bounding_boxes = get_bounding_boxes(
         image, detector_model="Finetuned", device="cuda", conf=0.01, iou=0.99
     )
     print(f"Length of bounding boxes: {len(bounding_boxes)}")
-    segmentations = get_mobileSAMv2(image, bounding_boxes)
-    return segmentations
+    assert len(bounding_boxes) > 0
+
+
+def test_RCNN():
+    """
+    Test RCNN object detection on CPU and CUDA devices.
+    """
+    image = data.coffee()
+    model_path = str(get_weights_path("ObjectAwareModel_Cell_FT"))
+    assert os.path.exists(model_path)
+    rcnn_cpu = RcnnDetector(model_path, device="cpu")
+    rcnn_cuda = RcnnDetector(model_path, device="cuda")
+    bbox_cpu = rcnn_cpu.get_bounding_boxes(image, conf=0.5, iou=0.2)
+    bbox_cuda = rcnn_cuda.get_bounding_boxes(image, conf=0.5, iou=0.2)
+    assert len(bbox_cpu) == 6
+    assert len(bbox_cuda) == 6
+
+
+def test_YOLO():
+    """
+    Test YOLO object detection on CPU and CUDA devices.
+    """
+    image = data.coffee()
+    model_path = str(get_weights_path("ObjectAwareModel"))
+    assert os.path.exists(model_path)
+    yolo_cpu = YoloDetector(model_path, device="cpu")
+    yolo_cuda = YoloDetector(model_path, device="cuda")
+    bbox_cpu = yolo_cpu.get_bounding_boxes(
+        image, conf=0.5, iou=0.2, max_det=400, imgsz=1024
+    )
+    bbox_cuda = yolo_cuda.get_bounding_boxes(
+        image, conf=0.5, iou=0.2, max_det=400, imgsz=1024
+    )
+    assert len(bbox_cpu) == 8
+    assert len(bbox_cuda) == 8
+
+
+def test_weights_path():
+    weights_path = get_weights_path("default")
+    assert os.path.exists(os.path.dirname(weights_path))
+
+
+def test_labels():
+    """
+    Tests whether region properties can be generated for segmentations for different models
+    """
+    image = data.coffee()
+    device = get_device()
+
+    bbox_yolo = get_bounding_boxes(
+        image,
+        detector_model="YOLOv8",
+        imgsz=1024,
+        device=device,
+        conf=0.4,
+        iou=0.9,
+    )
+    bbox_rcnn = get_bounding_boxes(
+        image,
+        detector_model="Finetuned",
+        imgsz=1024,
+        device=device,
+        conf=0.4,
+        iou=0.9,
+    )
+
+    segmentations_rcnn = get_mobileSAMv2(image, bbox_rcnn)
+    segmentations_yolo = get_mobileSAMv2(image, bbox_yolo)
+    segmentations_vit_b = get_sam_automatic_mask_generator(
+        "vit_b",
+        points_per_side=4,
+        pred_iou_thresh=0.2,
+        stability_score_thresh=0.5,
+        box_nms_thresh=0.1,
+        crop_n_layers=0,
+    ).generate(image)
+
+    add_properties_to_label_image(image, segmentations_rcnn)
+    add_properties_to_label_image(image, segmentations_yolo)
+    add_properties_to_label_image(image, segmentations_vit_b)
+    props_rcnn = segmentations_rcnn[0].keys()
+    assert len(props_rcnn) == 10
+    props_yolo = segmentations_yolo[0].keys()
+    assert len(props_yolo) == 10
+    props_vit_b = segmentations_vit_b[0].keys()
+    assert len(props_vit_b) == 13
 
-# seg = test_bbox()
-seg = test_mobile_sam()
+test_urls()
+test_bbox()
+test_mobile_sam()
+test_RCNN()
+test_YOLO()
+test_weights_path()
+test_labels()
diff --git a/src/napari_segment_everything/minimal_detection/prompt_generator.py b/src/napari_segment_everything/minimal_detection/prompt_generator.py
index efae597..8e68bc7 100644
--- a/src/napari_segment_everything/minimal_detection/prompt_generator.py
+++ b/src/napari_segment_everything/minimal_detection/prompt_generator.py
@@ -110,7 +110,7 @@ class RcnnDetector(BaseDetector):
     def __init__(self, model_path, device, trainable=True):
         super().__init__(model_path, trainable)
         self.model_type = "FasterRCNN"
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = device
         self.model = fasterrcnn_mobilenet_v3_large_fpn(
             box_detections_per_img=500,
         ).to(device)
diff --git a/src/napari_segment_everything/sam_helper.py b/src/napari_segment_everything/sam_helper.py
index e5bf8aa..1e765f0 100644
--- a/src/napari_segment_everything/sam_helper.py
+++ b/src/napari_segment_everything/sam_helper.py
@@ -164,17 +164,17 @@
 ):
     if detector_model == "YOLOv8":
         model = YoloDetector(
-            str(get_weights_path("ObjectAwareModel")), device="cuda"
+            str(get_weights_path("ObjectAwareModel")), device=device
         )
         bounding_boxes = model.get_bounding_boxes(
            image, conf=conf, iou=iou, imgsz=imgsz, max_det=max_det
         )
     elif detector_model == "Finetuned":
         model = RcnnDetector(
-            str(get_weights_path("ObjectAwareModel_Cell_FT")), device="cuda"
+            str(get_weights_path("ObjectAwareModel_Cell_FT")), device=device
         )
     bounding_boxes = model.get_bounding_boxes(image, conf=conf, iou=iou)
-    print(bounding_boxes)
+    # print(bounding_boxes)
     return bounding_boxes
 
 
@@ -286,6 +286,7 @@ def filter_labels_3d_multi(
 
 
 def add_properties_to_label_image(orig_image, sorted_results):
     hsv_image = color.rgb2hsv(orig_image)
+    # switch to this? https://forum.image.sc/t/looking-for-a-faster-version-of-rgb2hsv/95214/12
     hue = 255 * hsv_image[:, :, 0]
     saturation = 255 * hsv_image[:, :, 1]

From 1553bab4ab59903c8fa1d713b8fc18ff3e2d87b6 Mon Sep 17 00:00:00 2001
From: Ian Coccimiglio
Date: Thu, 25 Apr 2024 09:01:39 -0700
Subject: [PATCH 3/3] Added docstring, removed dependence on skimage.resize

---
 src/napari_segment_everything/_tests/test_mobile_sam.py | 4 ++++
 .../minimal_detection/prompt_generator.py               | 1 -
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/napari_segment_everything/_tests/test_mobile_sam.py b/src/napari_segment_everything/_tests/test_mobile_sam.py
index 395c211..bf87bb2 100644
--- a/src/napari_segment_everything/_tests/test_mobile_sam.py
+++ b/src/napari_segment_everything/_tests/test_mobile_sam.py
@@ -98,6 +98,9 @@ def test_YOLO():
 
 
 def test_weights_path():
+    """
+    Tests whether the weights directory exists on the operating system
+    """
     weights_path = get_weights_path("default")
     assert os.path.exists(os.path.dirname(weights_path))
 
@@ -148,6 +151,7 @@ def test_labels():
     props_vit_b = segmentations_vit_b[0].keys()
     assert len(props_vit_b) == 13
 
+
 test_urls()
 test_bbox()
 test_mobile_sam()
diff --git a/src/napari_segment_everything/minimal_detection/prompt_generator.py b/src/napari_segment_everything/minimal_detection/prompt_generator.py
index 8e68bc7..278b2e4 100644
--- a/src/napari_segment_everything/minimal_detection/prompt_generator.py
+++ b/src/napari_segment_everything/minimal_detection/prompt_generator.py
@@ -7,7 +7,6 @@
 """
 import cv2
 import os, sys
-from skimage.transform import resize
 from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn
 from torchvision.transforms import ToTensor
 from torchvision.ops import nms
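
Usage note (sketch, not part of the patches above): a minimal example of the device-aware detection API this series ends up with. The coffee() sample image and the conf/iou values are illustrative assumptions taken from the tests, not something the patches require.

    from skimage import data

    from napari_segment_everything.sam_helper import (
        get_bounding_boxes,
        get_device,
        get_mobileSAMv2,
    )

    image = data.coffee()
    device = get_device()  # device string passed through to the detector, as in test_labels

    # "Finetuned" selects the FasterRCNN detector, which now runs on the
    # requested device instead of always using CUDA, and allows up to 500
    # boxes per image (box_detections_per_img=500 from patch 1/3).
    boxes = get_bounding_boxes(
        image, detector_model="Finetuned", device=device, conf=0.4, iou=0.5
    )
    masks = get_mobileSAMv2(image, boxes)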