diff --git a/src/napari_segment_everything/_tests/test_mobile_sam.py b/src/napari_segment_everything/_tests/test_mobile_sam.py
index d8397d7..bb6597f 100644
--- a/src/napari_segment_everything/_tests/test_mobile_sam.py
+++ b/src/napari_segment_everything/_tests/test_mobile_sam.py
@@ -1,12 +1,38 @@
 # import skimage example images
 from skimage import data
-from napari_segment_everything.sam_helper import get_mobileSAMv2, get_bounding_boxes
+from napari_segment_everything.sam_helper import (
+    get_mobileSAMv2,
+    get_bounding_boxes,
+)
+
 
 def test_mobile_sam():
     # load a color example
     image = data.coffee()
-
-    bounding_boxes = get_bounding_boxes(image, imgsz=1024, device = 'cuda')
+
+    bounding_boxes = get_bounding_boxes(
+        image,
+        detector_model="YOLOv8",
+        imgsz=1024,
+        device="cuda",
+        conf=0.4,
+        iou=0.9,
+    )
+    segmentations = get_mobileSAMv2(image, bounding_boxes)
+
+    assert len(segmentations) == 11
+
+
+def test_bbox():
+    # load a color example
+    image = data.coffee()
+    bounding_boxes = get_bounding_boxes(
+        image, detector_model="Finetuned", device="cuda", conf=0.01, iou=0.99
+    )
+    print(f"Length of bounding boxes: {len(bounding_boxes)}")
     segmentations = get_mobileSAMv2(image, bounding_boxes)
-
-    assert len(segmentations) == 11
\ No newline at end of file
+    return segmentations
+
+
+# seg = test_bbox()
+seg = test_mobile_sam()
diff --git a/src/napari_segment_everything/interactive_launch.py b/src/napari_segment_everything/interactive_launch.py
index 08ddd2e..72251a0 100644
--- a/src/napari_segment_everything/interactive_launch.py
+++ b/src/napari_segment_everything/interactive_launch.py
@@ -1,8 +1,9 @@
 import napari
-from napari_segment_everything import segment_everything 
+from napari_segment_everything import segment_everything
 
 viewer = napari.Viewer()
-viewer.window.add_dock_widget(segment_everything.NapariSegmentEverything(viewer))
-
-k=input("press close to exit")
+viewer.window.add_dock_widget(
+    segment_everything.NapariSegmentEverything(viewer)
+)
+napari.run()
diff --git a/src/napari_segment_everything/minimal_detection/detect_and_segment.py b/src/napari_segment_everything/minimal_detection/detect_and_segment.py
index a223d73..5b39b53 100644
--- a/src/napari_segment_everything/minimal_detection/detect_and_segment.py
+++ b/src/napari_segment_everything/minimal_detection/detect_and_segment.py
@@ -8,23 +8,7 @@ from segment_anything.utils.amg import calculate_stability_score
 import gc
-import sys
-
-"""
-This is necessary because the torch weights were
-pickled with its environment, which messed up the imports
-"""
-
 current_dir = os.path.dirname(__file__)
-obj_detect_dir = os.path.join(current_dir, "object_detection")
-sys.path.insert(0, obj_detect_dir)
-
-from .object_detection.ultralytics.prompt_mobilesamv2 import ObjectAwareModel
-
-
-def create_OA_model(weights_path):
-    object_aware_model = ObjectAwareModel(weights_path)
-    return object_aware_model
 
 
 def create_MS_model():
@@ -51,32 +35,6 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
         yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
 
 
-def detect_bbox(
-    object_aware_model,
-    image,
-    imgsz=1024,
-    conf=0.4,
-    iou=0.9,
-    device="cpu",
-    max_det=300,
-):
-    """
-    Uses an object aware model to produce bounding boxes for a given image at image_path.
-
-    Returns a list of bounding boxes, as well as extra properties.
- """ - obj_results = object_aware_model( - image, - device=device, - retina_masks=True, - imgsz=imgsz, - conf=conf, - iou=iou, - max_det=max_det, - ) - return obj_results - - def segment_from_bbox(bounding_boxes, predictor, mobilesamv2): """ Segments everything given the bounding boxes of the objects and the mobileSAMv2 prediction model. @@ -90,7 +48,7 @@ def segment_from_bbox(bounding_boxes, predictor, mobilesamv2): predicted_ious = [] stability_scores = [] - + image_embedding = predictor.features image_embedding = torch.repeat_interleave(image_embedding, 400, dim=0) @@ -133,7 +91,7 @@ def segment_from_bbox(bounding_boxes, predictor, mobilesamv2): stability_scores.extend(stability_score.flatten().tolist()) sam_mask = torch.cat(sam_mask) - #predicted_ious = pred_ious.cpu().numpy() + # predicted_ious = pred_ious.cpu().numpy() cpu_segmentations = sam_mask.cpu().numpy() del sam_mask @@ -148,5 +106,9 @@ def segment_from_bbox(bounding_boxes, predictor, mobilesamv2): "predicted_iou": predicted_ious[idx], "stability_score": stability_scores[idx], } + if ( + cpu_segmentations[idx].max() < 1 + ): # this means that bboxes won't always == segmentations + continue curr_anns.append(ann) return curr_anns diff --git a/src/napari_segment_everything/minimal_detection/prompt_generator.py b/src/napari_segment_everything/minimal_detection/prompt_generator.py new file mode 100644 index 0000000..c8859d7 --- /dev/null +++ b/src/napari_segment_everything/minimal_detection/prompt_generator.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Mon Apr 22 23:58:14 2024 + +@author: ian +""" +import cv2 +import os, sys +from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn +from torchvision.transforms import ToTensor +from torchvision.ops import nms +import torch + +""" +This is necessary because the torch weights were +pickled with its environment, which messed up the imports +""" + +current_dir = os.path.dirname(__file__) +obj_detect_dir = os.path.join(current_dir, "object_detection") +sys.path.insert(0, obj_detect_dir) + +from napari_segment_everything.minimal_detection.object_detection.ultralytics.prompt_mobilesamv2 import ( + ObjectAwareModel, +) + + +class BaseDetector: + def __init__(self, model_path, trainable=False): + self.model_path = model_path + self.model_name = self.model_path.split("/")[-1] + self.trainable = trainable + + def train(self, training_data): + raise NotImplementedError() + + def predict(self, image_data): + raise NotImplementedError() + + +class YoloDetector(BaseDetector): + def __init__(self, model_path, device, trainable=False): + super().__init__(model_path) + self.model_type = "YOLOv8" + self.model = ObjectAwareModel(model_path) + self.device = device + + def train(self): + print( + "YOLO detector is not yet trainable, use RcnnDetector for training" + ) + + def get_bounding_boxes( + self, + image_data, + retina_masks=True, + imgsz=1024, + conf=0.4, + iou=0.9, + max_det=400, + ): + """ + Generates a series of bounding boxes in xyxy-format from an image, using the YOLOv8 ObjectAwareModel. + + Parameters + ---------- + image : numpy.ndarray + A 2D-image in grayscale or RGB. + imgsz : INT, optional + Size of the input image. The default is 1024. + conf : FLOAT, optional + Confidence threshold for the bounding boxes. Lower means more boxes will be detected. The default is 0.4. + iou : FLOAT, optional + Threshold for how many intersecting bounding boxes should be allowed. Lower means fewer intersecting boxes will be returned. 
+            The default is 0.9.
+        max_det : INT, optional
+            Maximum number of detections that will be returned. The default is 400.
+
+        Returns
+        -------
+        bounding_boxes : numpy.ndarray
+            An array of boxes in xyxy-coordinates.
+        """
+
+        print("Predicting bounding boxes for image data")
+        image_cv2 = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)
+
+        obj_results = self.model.predict(
+            image_cv2,
+            device=self.device,
+            retina_masks=True,
+            imgsz=imgsz,
+            conf=conf,
+            iou=iou,
+            max_det=max_det,
+        )
+        bounding_boxes = obj_results[0].boxes.xyxy.cpu().numpy()
+
+        return bounding_boxes
+
+    def __str__(self):
+        s = f"\n{'Model':<10}: {self.model_name}\n"
+        s += f"{'Type':<10}: {str(self.model_type)}\n"
+        s += f"{'Trainable':<10}: {str(self.trainable)}"
+        return s
+
+
+class RcnnDetector(BaseDetector):
+    def __init__(self, model_path, device, trainable=True):
+        super().__init__(model_path, trainable)
+        self.model_type = "FasterRCNN"
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.model = fasterrcnn_mobilenet_v3_large_fpn().to(device)
+        self.model.load_state_dict(torch.load(model_path))
+
+    def train(self, training_data):
+        if self.trainable:
+            print("Training model")
+            print(self.model_path)
+            print(training_data)
+
+    def _get_transform(self, train):
+        from torchvision.transforms import v2 as T
+
+        transforms = []
+        if train:
+            transforms.append(T.RandomHorizontalFlip(0.5))
+        transforms.append(T.ToDtype(torch.float, scale=True))
+        transforms.append(T.ToPureTensor())
+        return T.Compose(transforms)
+
+    @torch.inference_mode()
+    def get_bounding_boxes(self, image_data, conf=0.5, iou=0.2):
+        image_cv2 = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)
+
+        print("Predicting bounding boxes for image data")
+        convert_tensor = ToTensor()
+        eval_transform = self._get_transform(train=False)
+        tensor_image = convert_tensor(image_cv2)
+        x = eval_transform(tensor_image)
+        # convert RGBA -> RGB and move to device
+        x = x[:3, ...].to(self.device)
+        self.model.eval()
+        predictions = self.model([x])
+        pred = predictions[0]
+        # print(pred)
+        idx_after = nms(pred["boxes"], pred["scores"], iou_threshold=iou)
+        pred_boxes = pred["boxes"][idx_after]
+        pred_scores = pred["scores"][idx_after]
+        pred_boxes_conf = pred_boxes[pred_scores > conf]
+        return pred_boxes_conf.cpu().numpy()
+
+    def __str__(self):
+        s = f"\n{'Model':<10}: {self.model_name}\n"
+        s += f"{'Type':<10}: {str(self.model_type)}\n"
+        s += f"{'Trainable':<10}: {str(self.trainable)}"
+        return s
diff --git a/src/napari_segment_everything/sam_helper.py b/src/napari_segment_everything/sam_helper.py
index 6ebb362..6efe185 100644
--- a/src/napari_segment_everything/sam_helper.py
+++ b/src/napari_segment_everything/sam_helper.py
@@ -1,9 +1,7 @@
 from segment_anything import SamPredictor
 from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator
 from napari_segment_everything.minimal_detection.detect_and_segment import (
-    create_OA_model,
     create_MS_model,
-    detect_bbox,
     segment_from_bbox,
 )
 from napari_segment_everything.minimal_detection.mobilesamv2 import (
@@ -11,16 +9,19 @@
     SamPredictor as SamPredictorV2,
 )
 
+from napari_segment_everything.minimal_detection.prompt_generator import (
+    RcnnDetector,
+    YoloDetector,
+)
+
 import urllib.request
 import warnings
 from pathlib import Path
 from typing import Optional
-from skimage.measure import label
-from skimage.measure import regionprops
+from skimage.measure import regionprops, label
 from skimage import color
 import cv2
 import torch
-
 import toolz as tz
 
 from napari.utils import progress
 
@@ -37,6 +38,7 @@
     "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
     "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
     "ObjectAwareModel": "https://drive.google.com/uc?id=1_vb_0SHBUnQhtg5SEE24kOog9_5Qpk5Z/ObjectAwareModel.pt",
+    "ObjectAwareModel_Cell_FT": "https://drive.google.com/uc?id=1efZ40ti87O346dJW5lp7inCZ84N_nugS/ObjectAwareModel_Cell_FT.pt",
     "efficientvit_l2": "https://drive.google.com/uc?id=10Emd1k9obcXZZALiqlW8FLIYZTfLs-xu/l2.pt",
 }
 
@@ -99,6 +101,7 @@ def get_weights_path(model_type: str) -> Optional[Path]:
 
     return weight_path
 
+
 def get_device():
     if torch.cuda.is_available():
         return "cuda"
@@ -106,7 +109,7 @@ def get_device():
         return "mps"
     else:
         return "cpu"
-    
+
 
 def get_sam(model_type: str):
     sam = sam_model_registry[model_type](get_weights_path(model_type))
@@ -151,35 +154,54 @@ def get_sam_automatic_mask_generator(
 
 
 def get_bounding_boxes(
-    image, imgsz=1024, conf=0.4, iou=0.9, device="cpu", max_det=400
+    image,
+    detector_model,
+    device="cpu",
+    conf=0.4,
+    iou=0.5,
+    imgsz=1024,
+    max_det=400,
 ):
-
-    image_cv2 = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
-    weights_path_OA = get_weights_path("ObjectAwareModel")
-    objAwareModel = create_OA_model(weights_path_OA)
-
-    """Uses an object-aware model (YOLOv8) to determine the bounding boxes of objects"""
-    obj_results = detect_bbox(
-        objAwareModel,
-        image_cv2,
-        device=device,
-        imgsz=imgsz,
-        conf=conf,
-        iou=iou,
-        max_det=max_det,
-    )
-
-    bounding_boxes = obj_results[0].boxes.xyxy.cpu().numpy()
+    if detector_model == "YOLOv8":
+        model = YoloDetector(
+            str(get_weights_path("ObjectAwareModel")), device="cuda"
+        )
+        bounding_boxes = model.get_bounding_boxes(
+            image, conf=conf, iou=iou, imgsz=imgsz, max_det=max_det
+        )
+    elif detector_model == "Finetuned":
+        model = RcnnDetector(
+            str(get_weights_path("ObjectAwareModel_Cell_FT")),
+            device="cuda",
+        )
+        bounding_boxes = model.get_bounding_boxes(image, conf=conf, iou=iou)
+    print(bounding_boxes)
+    return bounding_boxes
 
-    print(f"Discovered {len(bounding_boxes)} objects")
-
-    return bounding_boxes
 
 
 def get_mobileSAMv2(image=None, bounding_boxes=None):
+    """
+    Uses a SAM model to make predictions from bounding boxes.
+
+    Parameters
+    ----------
+    image : numpy.ndarray, optional
+        A 2D-image in grayscale or RGB. The default is None.
+    bounding_boxes : numpy.ndarray, optional
+        An array of boxes in xyxy-coordinates. The default is None.
+
+    Returns
+    -------
+    sam_masks : LIST
+        A list of result dictionaries, one for each segmentation mask.
+        Each sam_mask has keys for segmentation, area, predicted_iou, and stability_score.
+ + """ if image is None: print("Upload an image first") return + if image.ndim < 3: + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) device = "cuda" if torch.cuda.is_available() else "cpu" # device = "cpu" weights_path_VIT = get_weights_path("efficientvit_l2") diff --git a/src/napari_segment_everything/segment_everything.py b/src/napari_segment_everything/segment_everything.py index 9348d29..3e4c7c6 100644 --- a/src/napari_segment_everything/segment_everything.py +++ b/src/napari_segment_everything/segment_everything.py @@ -65,7 +65,7 @@ def __init__(self, napari_viewer, parent=None): ) self.load_image(self.im_layer_widget.value) - self._pts_layer = self.viewer.add_points( name = "SAM points") + self._pts_layer = self.viewer.add_points(name="SAM points") self._boxes_layer = self.viewer.add_shapes( name="SAM box", @@ -117,12 +117,18 @@ def initUI(self): self.open_project_button.clicked.connect(self.open_project) self.sam_layout.addWidget(self.open_project_button) - # Dropdown for selecting the recipe + # Dropdown for selecting the recipe recipe_label = QLabel( "Select Recipe" ) # Dropdown for selecting the recipe self.recipe_dropdown = QComboBox() - self.recipe_dropdown.addItems(["Mobile SAM v2", "Sam Automatic Mask Generator"]) + self.recipe_dropdown.addItems( + [ + "Mobile SAM v2", + "Mobile SAM Finetuned", + "Sam Automatic Mask Generator", + ] + ) recipe_layout = QHBoxLayout() recipe_layout.addWidget(recipe_label) recipe_layout.addWidget(self.recipe_dropdown) @@ -321,7 +327,6 @@ def initUI(self): layout.addWidget(self.textBrowser_log) layout.addWidget(self.progressBar) - self.setLayout(layout) def change_slider(self, index): @@ -369,11 +374,13 @@ def save_project(self): def process(self): if self.image is None: return - + recipe_selection = self.recipe_dropdown.currentText() points_per_side = self.points_per_side_spinner.spinner.value() pred_iou_thresh = self.pred_iou_thresh_spinner.spinner.value() - stability_score_thresh = self.stability_score_thresh_spinner.spinner.value() + stability_score_thresh = ( + self.stability_score_thresh_spinner.spinner.value() + ) box_nms_thresh = self.box_nms_thresh_spinner.spinner.value() crop_n_layers = self.crop_n_layers_spinner.spinner.value() @@ -390,37 +397,83 @@ def process(self): box_nms_thresh=box_nms_thresh, crop_n_layers=crop_n_layers, ) - - self.textBrowser_log.append("Running SAM automatic mask generator recipe") - self.textBrowser_log.append(f"SAM prompt is grid of {points_per_side} by {points_per_side} points") + + self.textBrowser_log.append( + "Running SAM automatic mask generator recipe" + ) + self.textBrowser_log.append( + f"SAM prompt is grid of {points_per_side} by {points_per_side} points" + ) self.progressBar.setValue(30) self.textBrowser_log.append("Generating 3D labels with vitb model") - + self.results = self._predictor.generate(self.image) + elif recipe_selection == "Mobile SAM Finetuned": + self.textBrowser_log.append("Running finetuned mobileSAMv2 recipe") + self.progressBar.setValue(20) + self.textBrowser_log.append( + "Detecting bounding boxes with FasterRCNN Object Aware Model" + ) + self.textBrowser_log.repaint() + QApplication.processEvents() + + bounding_boxes = get_bounding_boxes( + self.image, + detector_model="Finetuned", + device="cuda", + conf=0.4, + iou=0.5, + ) + self.textBrowser_log.append( + f"SAM prompt is {len(bounding_boxes)} bounding boxes" + ) + self.progressBar.setValue(40) + self.textBrowser_log.append( + "Generating 3D labels with efficientvit_l2 model" + ) + + self.results = get_mobileSAMv2(self.image, 
+
+            for result, bbox in zip(self.results, bounding_boxes):
+                result["prompt_bbox"] = bbox
+
         elif recipe_selection == "Mobile SAM v2":
             self.textBrowser_log.append("Running mobileSAMv2 recipe")
             self.progressBar.setValue(20)
-            self.textBrowser_log.append("Detecting bounding boxes with YOLO Object Aware Model")
+
+            self.textBrowser_log.append(
+                "Detecting bounding boxes with YOLO Object Aware Model"
+            )
             self.textBrowser_log.repaint()
             QApplication.processEvents()
-
-            bounding_boxes = get_bounding_boxes(self.image, imgsz=1024, iou = 0.5, conf=0.01, max_det=10000, device='cuda')
-
-            self.textBrowser_log.append(f"SAM prompt is {len(bounding_boxes)} bounding boxes")
+
+            bounding_boxes = get_bounding_boxes(
+                self.image,
+                detector_model="YOLOv8",
+                device="cuda",
+                conf=0.4,
+                iou=0.5,
+            )
+
+            self.textBrowser_log.append(
+                f"SAM prompt is {len(bounding_boxes)} bounding boxes"
+            )
             self.progressBar.setValue(40)
-            self.textBrowser_log.append("Generating 3D labels with efficientvit_l2 model")
-
+            self.textBrowser_log.append(
+                "Generating 3D labels with efficientvit_l2 model"
+            )
+
             self.results = get_mobileSAMv2(self.image, bounding_boxes)
-
+
             for result, bbox in zip(self.results, bounding_boxes):
                 result["prompt_bbox"] = bbox
-
+
         self.results = sorted(
             self.results, key=lambda x: x["area"], reverse=False
         )
-        self.textBrowser_log.append(str(len(self.results))+" objects found")
+        self.textBrowser_log.append(str(len(self.results)) + " objects found")
         self.progressBar.setValue(60)
 
         label_num = 1
@@ -428,9 +481,9 @@ def process(self):
             result["keep"] = True
             result["label_num"] = label_num
             label_num += 1
-        
+
         self.textBrowser_log.append("Adding properties to label image")
-        add_properties_to_label_image(self.image, self.results) 
+        add_properties_to_label_image(self.image, self.results)
         label_image = make_label_image_3d(self.results)
 
         self.add_points()
@@ -639,7 +692,7 @@ def load_image(self, im_layer: Optional[Image]) -> None:
         self.min_max_area_slider.max_spinbox.setRange(0, max_area)
         self.min_max_area_slider.min_slider.setRange(0, max_area)
         self.min_max_area_slider.max_slider.setRange(0, max_area)
-        
+
         for i in range(len(self.viewer.layers)):
             if self.viewer.layers[i].name == "SAM box":
                 self.viewer.layers.move(i, -1)
@@ -668,7 +721,7 @@ def add_points(self):
     def add_boxes(self):
         # delete old boxes
         self._boxes_layer.data = []
-        self.viewer.dims.ndisplay = 2 
+        self.viewer.dims.ndisplay = 2
         for result in self.results:
             # if point_coords is a key
             if "prompt_bbox" in result:
diff --git a/src/napari_segment_everything/test_images/010_img.png b/src/napari_segment_everything/test_images/010_img.png
new file mode 100755
index 0000000..c9c7f10
Binary files /dev/null and b/src/napari_segment_everything/test_images/010_img.png differ
diff --git a/src/napari_segment_everything/test_images/016_img.png b/src/napari_segment_everything/test_images/016_img.png
new file mode 100755
index 0000000..c837131
Binary files /dev/null and b/src/napari_segment_everything/test_images/016_img.png differ
diff --git a/src/napari_segment_everything/test_images/040_img.png b/src/napari_segment_everything/test_images/040_img.png
new file mode 100755
index 0000000..bef52f1
Binary files /dev/null and b/src/napari_segment_everything/test_images/040_img.png differ
diff --git a/src/napari_segment_everything/test_images/057_img.png b/src/napari_segment_everything/test_images/057_img.png
new file mode 100755
index 0000000..9ff3103
Binary files /dev/null and b/src/napari_segment_everything/test_images/057_img.png differ
diff --git a/src/napari_segment_everything/test_images/067_img.png b/src/napari_segment_everything/test_images/067_img.png
new file mode 100755
index 0000000..d05605a
Binary files /dev/null and b/src/napari_segment_everything/test_images/067_img.png differ
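
Usage note: the snippet below is a minimal sketch of how the two-stage API introduced in this patch is meant to be called, based on the tests and docstrings above. The CUDA device string and the confidence/IoU values are illustrative assumptions taken from the test file; pass device="cpu" on machines without a GPU, or detector_model="Finetuned" to prompt SAM with the FasterRCNN cell detector instead of YOLOv8.

# Minimal usage sketch (assumes the model weights can be downloaded and a CUDA device is available).
from skimage import data

from napari_segment_everything.sam_helper import (
    get_bounding_boxes,
    get_mobileSAMv2,
)

image = data.coffee()  # RGB example image

# Stage 1: the YOLOv8 object-aware detector proposes xyxy boxes to use as SAM prompts.
bounding_boxes = get_bounding_boxes(
    image, detector_model="YOLOv8", device="cuda", imgsz=1024, conf=0.4, iou=0.9
)

# Stage 2: MobileSAMv2 turns each box prompt into a mask dictionary with
# "segmentation", "area", "predicted_iou", and "stability_score" keys.
segmentations = get_mobileSAMv2(image, bounding_boxes)
print(f"{len(segmentations)} masks generated")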