Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cell Finetuned Model + Tests + OOP #3

Merged
merged 11 commits into from
Apr 23, 2024
36 changes: 31 additions & 5 deletions src/napari_segment_everything/_tests/test_mobile_sam.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,38 @@
# import skimage example images
from skimage import data
from napari_segment_everything.sam_helper import get_mobileSAMv2, get_bounding_boxes
from napari_segment_everything.sam_helper import (
get_mobileSAMv2,
get_bounding_boxes,
)


def test_mobile_sam():
    """Segment the scikit-image coffee picture from YOLOv8 bounding boxes.

    Bounding boxes are requested from ``get_bounding_boxes`` with the
    "YOLOv8" detector and then handed to ``get_mobileSAMv2``. The test
    passes when exactly 11 segmentations come back.
    """
    # Load a colour example image.
    image = data.coffee()

    # Detect the boxes exactly once. An earlier, differently parameterised
    # call whose result was immediately overwritten has been dropped because
    # it only added a redundant detector run.
    # NOTE(review): device is hard-coded to "cuda", so this presumably needs
    # a CUDA-capable machine - confirm before running in CI.
    bounding_boxes = get_bounding_boxes(
        image,
        detector_model="YOLOv8",
        imgsz=1024,
        device="cuda",
        conf=0.4,
        iou=0.9,
    )
    segmentations = get_mobileSAMv2(image, bounding_boxes)

    assert len(segmentations) == 11


def test_bbox():
    """Segment the scikit-image coffee picture from "Finetuned" detector boxes.

    Prints how many bounding boxes were produced, asserts that exactly 11
    segmentations result, and returns those segmentations to the caller.
    """
    # Colour example image used as the fixed test input.
    coffee = data.coffee()

    bounding_boxes = get_bounding_boxes(
        coffee,
        detector_model="Finetuned",
        device="cuda",
        conf=0.01,
        iou=0.99,
    )
    print(f"Length of bounding boxes: {len(bounding_boxes)}")

    masks = get_mobileSAMv2(coffee, bounding_boxes)
    assert len(masks) == 11
    return masks


# Only run a test by hand when this file is executed as a script. Without the
# guard the call below ran on import, i.e. during pytest collection.
if __name__ == "__main__":
    # seg = test_bbox()
    seg = test_mobile_sam()
9 changes: 5 additions & 4 deletions src/napari_segment_everything/interactive_launch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import napari
from napari_segment_everything import segment_everything

# Create the viewer and dock the Segment Everything widget exactly once
# (it was previously imported and docked twice).
viewer = napari.Viewer()
viewer.window.add_dock_widget(
    segment_everything.NapariSegmentEverything(viewer)
)

# Hand control to napari's event loop instead of blocking on input().
napari.run()
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,7 @@
from segment_anything.utils.amg import calculate_stability_score
import gc

import sys

"""
This is necessary because the torch weights were
pickled with its environment, which messed up the imports
"""

current_dir = os.path.dirname(__file__)
obj_detect_dir = os.path.join(current_dir, "object_detection")
sys.path.insert(0, obj_detect_dir)

from .object_detection.ultralytics.prompt_mobilesamv2 import ObjectAwareModel


def create_OA_model(weights_path):
    """Construct and return an ``ObjectAwareModel`` from ``weights_path``."""
    return ObjectAwareModel(weights_path)


def create_MS_model():
Expand All @@ -51,32 +35,6 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]


def detect_bbox(
    object_aware_model,
    image,
    imgsz=1024,
    conf=0.4,
    iou=0.9,
    device="cpu",
    max_det=300,
):
    """
    Call an object aware model on ``image`` and return whatever it produces.

    The model is invoked as ``object_aware_model(image, **options)`` where the
    options are ``device``, ``retina_masks`` (always ``True``), ``imgsz``,
    ``conf``, ``iou`` and ``max_det``. The result is passed back untouched.
    """
    # Keyword order mirrors the call the model has always received.
    inference_options = {
        "device": device,
        "retina_masks": True,
        "imgsz": imgsz,
        "conf": conf,
        "iou": iou,
        "max_det": max_det,
    }
    return object_aware_model(image, **inference_options)


def segment_from_bbox(bounding_boxes, predictor, mobilesamv2):
"""
Segments everything given the bounding boxes of the objects and the mobileSAMv2 prediction model.
Expand All @@ -90,7 +48,7 @@ def segment_from_bbox(bounding_boxes, predictor, mobilesamv2):

predicted_ious = []
stability_scores = []

image_embedding = predictor.features
image_embedding = torch.repeat_interleave(image_embedding, 400, dim=0)

Expand Down Expand Up @@ -133,7 +91,7 @@ def segment_from_bbox(bounding_boxes, predictor, mobilesamv2):
stability_scores.extend(stability_score.flatten().tolist())

sam_mask = torch.cat(sam_mask)
#predicted_ious = pred_ious.cpu().numpy()
# predicted_ious = pred_ious.cpu().numpy()
cpu_segmentations = sam_mask.cpu().numpy()
del sam_mask

Expand All @@ -148,5 +106,9 @@ def segment_from_bbox(bounding_boxes, predictor, mobilesamv2):
"predicted_iou": predicted_ious[idx],
"stability_score": stability_scores[idx],
}
if (
cpu_segmentations[idx].max() < 1
): # this means that bboxes won't always == segmentations
continue
curr_anns.append(ann)
return curr_anns
157 changes: 157 additions & 0 deletions src/napari_segment_everything/minimal_detection/prompt_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 22 23:58:14 2024

@author: ian
"""
import cv2
import os, sys
from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn
from torchvision.transforms import ToTensor
from torchvision.ops import nms
import torch

"""
This is necessary because the torch weights were
pickled with its environment, which messed up the imports
"""

current_dir = os.path.dirname(__file__)
obj_detect_dir = os.path.join(current_dir, "object_detection")
sys.path.insert(0, obj_detect_dir)

from napari_segment_everything.minimal_detection.object_detection.ultralytics.prompt_mobilesamv2 import (
ObjectAwareModel,
)


class BaseDetector:
    """Shared bookkeeping and method stubs for bounding-box detectors.

    Subclasses are expected to override ``train`` and ``predict``; the
    versions defined here only raise ``NotImplementedError``.
    """

    def __init__(self, model_path, trainable=False):
        """Remember where the model lives and whether it may be trained.

        Parameters
        ----------
        model_path : str or os.PathLike
            Location of the model weights.
        trainable : bool, optional
            Stored on the instance as ``self.trainable``. The default is False.
        """
        self.model_path = model_path
        # basename() honours the platform's path separators, and fspath()
        # lets pathlib.Path objects through; a plain split("/") did neither.
        self.model_name = os.path.basename(os.fspath(model_path))
        self.trainable = trainable

    def train(self, training_data):
        """Train the detector; must be provided by a subclass."""
        raise NotImplementedError()

    def predict(self, image_data):
        """Run the detector on ``image_data``; must be provided by a subclass."""
        raise NotImplementedError()


class YoloDetector(BaseDetector):
    """Bounding-box detector that wraps the YOLOv8 ``ObjectAwareModel``."""

    def __init__(self, model_path, device, trainable=False):
        """Load the object-aware model from ``model_path``.

        Parameters
        ----------
        model_path : str
            Location of the weights handed to ``ObjectAwareModel``.
        device : str
            Device identifier forwarded to every ``predict`` call.
        trainable : bool, optional
            Accepted but not forwarded to the base class, so
            ``self.trainable`` is always False.
            NOTE(review): this matches ``train`` being a stub - confirm intent.
        """
        super().__init__(model_path)
        self.model_type = "YOLOv8"
        self.model = ObjectAwareModel(model_path)
        self.device = device

    def train(self, training_data=None):
        """Report that training is unsupported.

        ``training_data`` is ignored; it is accepted (optionally) so the
        signature stays compatible with ``BaseDetector.train`` while existing
        zero-argument calls keep working.
        """
        print(
            "YOLO detector is not yet trainable, use RcnnDetector for training"
        )

    def get_bounding_boxes(
        self,
        image_data,
        retina_masks=True,
        imgsz=1024,
        conf=0.4,
        iou=0.9,
        max_det=400,
    ):
        """
        Generates a series of bounding boxes in xyxy-format from an image, using the YOLOv8 ObjectAwareModel.

        Parameters
        ----------
        image_data : numpy.ndarray
            Image array. It is passed through ``cv2.cvtColor`` with
            ``COLOR_BGR2RGB`` before prediction.
            NOTE(review): that conversion expects a multi-channel image, so
            grayscale input presumably has to be expanded by the caller - confirm.
        retina_masks : BOOL, optional
            Forwarded to the model's ``predict`` call. The default is True.
        imgsz : INT, optional
            Size of the input image. The default is 1024.
        conf : FLOAT, optional
            Confidence threshold for the bounding boxes. Lower means more boxes will be detected. The default is 0.4.
        iou : FLOAT, optional
            Threshold for how many intersecting bounding boxes should be allowed. Lower means fewer intersecting boxes will be returned. The default is 0.9.
        max_det : INT, optional
            Maximum number of detections that will be returned. The default is 400.

        Returns
        -------
        bounding_boxes : numpy.ndarray
            An array of boxes in xyxy-coordinates.
        """

        print("Predicting bounding boxes for image data")
        # COLOR_BGR2RGB reverses the channel order.
        # NOTE(review): presumably this turns an RGB input into the BGR
        # layout the model expects - confirm.
        image_cv2 = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)

        obj_results = self.model.predict(
            image_cv2,
            device=self.device,
            # Honour the caller's choice; this used to be hard-coded to True.
            retina_masks=retina_masks,
            imgsz=imgsz,
            conf=conf,
            iou=iou,
            max_det=max_det,
        )
        # Only one image is submitted, so only the first result is read.
        bounding_boxes = obj_results[0].boxes.xyxy.cpu().numpy()

        return bounding_boxes

    def __str__(self):
        """Return a three-line summary: model name, type and trainable flag."""
        s = f"\n{'Model':<10}: {self.model_name}\n"
        s += f"{'Type':<10}: {str(self.model_type)}\n"
        s += f"{'Trainable':<10}: {str(self.trainable)}"
        return s


class RcnnDetector(BaseDetector):
    """Bounding-box detector built on torchvision's Faster R-CNN (MobileNetV3-Large FPN)."""

    def __init__(self, model_path, device, trainable=True):
        """Build the network and load its state dict from ``model_path``.

        Parameters
        ----------
        model_path : str
            Location of a state dict compatible with
            ``fasterrcnn_mobilenet_v3_large_fpn()``.
        device : str or torch.device or None
            Device for the model and for inference inputs. ``None`` selects
            "cuda" when available and "cpu" otherwise.
        trainable : bool, optional
            Stored via the base class. The default is True.
        """
        super().__init__(model_path, trainable)
        self.model_type = "FasterRCNN"
        # Use one device for both the model and the inputs. Previously
        # self.device was auto-detected while the model went to ``device``,
        # so the two could disagree and inference would fail.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model = fasterrcnn_mobilenet_v3_large_fpn().to(self.device)
        # map_location lets weights saved on a GPU load on a CPU-only host.
        # NOTE: torch.load unpickles the file - only load trusted weights.
        self.model.load_state_dict(
            torch.load(model_path, map_location=self.device)
        )

    def train(self, training_data):
        """Print the model path and training data when the detector is trainable.

        No optimisation is performed here; when ``self.trainable`` is falsy
        the call does nothing.
        """
        if self.trainable:
            print("Training model")
            print(self.model_path)
            print(training_data)

    def _get_transform(self, train):
        """Compose the torchvision v2 transforms applied before the model.

        A random horizontal flip (p=0.5) is included only when ``train`` is
        truthy; conversion to scaled float and to a pure tensor always is.
        """
        from torchvision.transforms import v2 as T

        transforms = []
        if train:
            transforms.append(T.RandomHorizontalFlip(0.5))
        transforms.append(T.ToDtype(torch.float, scale=True))
        transforms.append(T.ToPureTensor())
        return T.Compose(transforms)

    @torch.inference_mode()
    def get_bounding_boxes(self, image_data, conf=0.5, iou=0.2):
        """Predict boxes for one image, then apply NMS and a score cut-off.

        Parameters
        ----------
        image_data : numpy.ndarray
            Image array; it is passed through ``cv2.cvtColor`` with
            ``COLOR_BGR2RGB`` and only the first three channels are kept.
        conf : float, optional
            Boxes are kept only when their score is strictly greater than
            this value. The default is 0.5.
        iou : float, optional
            ``iou_threshold`` given to ``torchvision.ops.nms``. The default is 0.2.

        Returns
        -------
        numpy.ndarray
            The surviving boxes, moved to the CPU.
        """
        image_cv2 = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)

        print("Predicting bounding boxes for image data")
        convert_tensor = ToTensor()
        eval_transform = self._get_transform(train=False)
        tensor_image = convert_tensor(image_cv2)
        x = eval_transform(tensor_image)
        # Drop any channel beyond the third and move to the model's device.
        x = x[:3, ...].to(self.device)
        self.model.eval()
        predictions = self.model([x])
        pred = predictions[0]
        # Suppress overlapping boxes first, then apply the score threshold.
        idx_after = nms(pred["boxes"], pred["scores"], iou_threshold=iou)
        pred_boxes = pred["boxes"][idx_after]
        pred_scores = pred["scores"][idx_after]
        pred_boxes_conf = pred_boxes[pred_scores > conf]
        return pred_boxes_conf.cpu().numpy()

    def __str__(self):
        """Return a three-line summary: model name, type and trainable flag."""
        s = f"\n{'Model':<10}: {self.model_name}\n"
        s += f"{'Type':<10}: {str(self.model_type)}\n"
        s += f"{'Trainable':<10}: {str(self.trainable)}"
        return s
Loading
Loading