Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mac and non-GPU support #6

Merged
merged 14 commits into from
Apr 29, 2024
2 changes: 1 addition & 1 deletion .github/workflows/test_and_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
strategy:
matrix:
platform: [ubuntu-latest, windows-latest, macos-latest]
python-version: ['3.8', '3.9', '3.10']
python-version: ['3.9', '3.10']

steps:
- uses: actions/checkout@v3
Expand Down
7 changes: 5 additions & 2 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
Copyright (c) 2024, Brian Northan
All rights reserved.

This work re-uses some code from napari sam (see license here https://github.com/MIC-DKFZ/napari-sam/blob/main/LICENSE)
This work re-uses some code from the following sources:
- napari sam (License: https://github.com/MIC-DKFZ/napari-sam/blob/main/LICENSE)
- napari-segment-anything (License: https://github.com/royerlab/napari-segment-anything/blob/main/LICENSE)
- MobileSAMv2 (License: https://github.com/ChaoningZhang/MobileSAM/blob/master/LICENSE)

And also re-uses some code from napari-segment-anything (see license here https://github.com/royerlab/napari-segment-anything/blob/main/LICENSE)
Additionally, training data is from Cellpose's annotated dataset, which is licensed CC-by-NC (see: https://github.com/MouseLand/cellpose?tab=readme-ov-file)

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
Expand Down
56 changes: 29 additions & 27 deletions src/napari_segment_everything/_tests/test_mobile_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,30 @@
import requests
from gdown.parse_url import parse_url

device = get_device()
if device == "mps":
device = "cpu"

#%%
def test_urls():
"""
Tests whether all the urls for the model weights exist.
Tests whether all the urls for the model weights exist and are accessible.
"""
for url in SAM_WEIGHTS_URL.values():
if url.startswith("https://drive.google.com/"):
_, path_exists = parse_url(url)
assert path_exists
else:
req = requests.head(url)
assert req.status_code == 200
TIMEOUT = 1
for name, url in SAM_WEIGHTS_URL.items():
try:
if url.startswith("https://drive.google.com/"):
_, path_exists = parse_url(url)
assert (
path_exists
), f"Google Drive URL path wasn't parsed correctly: {url}"
else:
req = requests.head(url, timeout=TIMEOUT)
assert (
req.status_code == 200
), f"Failed to access URL: {url}, Status code: {req.status_code}"
except requests.exceptions.Timeout:
print(f"Request timed out for URL: {url}")


def test_mobile_sam():
Expand All @@ -42,7 +54,7 @@ def test_mobile_sam():
image,
detector_model="YOLOv8",
imgsz=1024,
device="cuda",
device=device,
conf=0.4,
iou=0.9,
)
Expand All @@ -57,7 +69,7 @@ def test_bbox():
"""
image = data.coffee()
bounding_boxes = get_bounding_boxes(
image, detector_model="Finetuned", device="cuda", conf=0.01, iou=0.99
image, detector_model="YOLOv8", device=device, conf=0.9, iou=0.90
)
print(f"Length of bounding boxes: {len(bounding_boxes)}")
assert len(bounding_boxes) > 0
Expand All @@ -70,12 +82,9 @@ def test_RCNN():
image = data.coffee()
model_path = str(get_weights_path("ObjectAwareModel_Cell_FT"))
assert os.path.exists(model_path)
rcnn_cpu = RcnnDetector(model_path, device="cpu")
rcnn_cuda = RcnnDetector(model_path, device="cuda")
bbox_cpu = rcnn_cpu.get_bounding_boxes(image, conf=0.5, iou=0.2)
bbox_cuda = rcnn_cuda.get_bounding_boxes(image, conf=0.5, iou=0.2)
assert len(bbox_cpu) == 6
assert len(bbox_cuda) == 6
rcnn = RcnnDetector(model_path, device=device)
bbox = rcnn.get_bounding_boxes(image, conf=0.5, iou=0.2)
assert len(bbox) == 6


def test_YOLO():
Expand All @@ -85,16 +94,11 @@ def test_YOLO():
image = data.coffee()
model_path = str(get_weights_path("ObjectAwareModel"))
assert os.path.exists(model_path)
yolo_cpu = YoloDetector(model_path, device="cpu")
yolo_cuda = YoloDetector(model_path, device="cuda")
bbox_cpu = yolo_cpu.get_bounding_boxes(
image, conf=0.5, iou=0.2, max_det=400, imgsz=1024
)
bbox_cuda = yolo_cuda.get_bounding_boxes(
yolo = YoloDetector(model_path, device=device)
bbox = yolo.get_bounding_boxes(
image, conf=0.5, iou=0.2, max_det=400, imgsz=1024
)
assert len(bbox_cpu) == 8
assert len(bbox_cuda) == 8
assert len(bbox) == 8


def test_weights_path():
Expand All @@ -110,7 +114,6 @@ def test_labels():
Tests whether region properties can be generated for segmentations for different models
"""
image = data.coffee()
device = get_device()

bbox_yolo = get_bounding_boxes(
image,
Expand Down Expand Up @@ -150,8 +153,7 @@ def test_labels():
assert len(props_yolo) == 10
props_vit_b = segmentations_vit_b[0].keys()
assert len(props_vit_b) == 13



test_urls()
test_bbox()
test_mobile_sam()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,18 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]


def segment_from_bbox(bounding_boxes, predictor, mobilesamv2):
def segment_from_bbox(bounding_boxes, predictor, mobilesamv2, device):
"""
Segments everything given the bounding boxes of the objects and the mobileSAMv2 prediction model.
Code from mobileSAMv2
"""
input_boxes = predictor.transform.apply_boxes(
bounding_boxes, predictor.original_size
) # Does this need to be transformed?
input_boxes = torch.from_numpy(input_boxes).cuda()
if device == "cuda":
input_boxes = torch.from_numpy(input_boxes).cuda()
elif device == "cpu":
input_boxes = torch.from_numpy(input_boxes)
sam_mask = []

predicted_ious = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,13 @@ class RcnnDetector(BaseDetector):
def __init__(self, model_path, device, trainable=True):
super().__init__(model_path, trainable)
self.model_type = "FasterRCNN"
if device == "mps":
device = "cpu"
self.device = device
self.model = fasterrcnn_mobilenet_v3_large_fpn(
box_detections_per_img=500,
).to(device)
self.model.load_state_dict(torch.load(model_path))
).to(self.device)
self.model.load_state_dict(torch.load(model_path, map_location=self.device))

def train(self, training_data):
if self.trainable:
Expand Down
16 changes: 11 additions & 5 deletions src/napari_segment_everything/sam_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,11 @@ def get_sam_automatic_mask_generator(
crop_n_layers=1,
):

device = get_device()
if device == "mps":
device = "cpu"
sam = sam_model_registry[model_type](get_weights_path(model_type))
sam.to(get_device())
sam.to()
sam_anything_predictor = SamAutomaticMaskGenerator(
sam,
points_per_side=int(points_per_side),
Expand Down Expand Up @@ -178,7 +181,7 @@ def get_bounding_boxes(
return bounding_boxes


def get_mobileSAMv2(image=None, bounding_boxes=None):
def get_mobileSAMv2(image=None, bounding_boxes=None, device=get_device()):
"""
Uses a SAM model to make predictions from bounding boxes.

Expand All @@ -201,19 +204,22 @@ def get_mobileSAMv2(image=None, bounding_boxes=None):
return
if image.ndim < 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"
weights_path_VIT = get_weights_path("efficientvit_l2")
samV2 = create_MS_model()

samV2.image_encoder = sam_model_registry["efficientvit_l2"](
weights_path_VIT
)
if device == "mps":
device="cpu"
samV2.to(device=device)
samV2.eval()
predictor = SamPredictorV2(samV2)
predictor.set_image(image)
sam_masks = segment_from_bbox(bounding_boxes, predictor, samV2)
sam_masks = segment_from_bbox(
bounding_boxes, predictor, samV2, device=device
)
del bounding_boxes

gc.collect()
Expand Down Expand Up @@ -304,7 +310,7 @@ def add_properties_to_label_image(orig_image, sorted_results):
# for small pixelated objects, circularity can be > 1 so we cap it
if result["circularity"] > 1:
result["circularity"] = 1

result["solidity"] = regions[0].solidity
intensity_pixels = intensity[coords]
result["mean_intensity"] = np.mean(intensity_pixels)
Expand Down
27 changes: 12 additions & 15 deletions src/napari_segment_everything/segment_everything.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from napari_segment_everything.sam_helper import (
get_bounding_boxes,
get_mobileSAMv2,
get_device,
)
import pickle

Expand All @@ -35,7 +36,7 @@
QMessageBox,
QTextBrowser,
QProgressBar,
QApplication
QApplication,
)


Expand Down Expand Up @@ -161,7 +162,6 @@ def on_index_changed(index):

self.stacked_algorithm_params_layout = QStackedWidget()


self.bbox_conf_spinner = LabeledSpinner(
"Bounding Box Confidence", 0, 1, 0.1, None, is_double=True
)
Expand All @@ -186,7 +186,6 @@ def on_index_changed(index):
self.yolo_params_layout.addWidget(self.bbbox_max_det_spinner)
self.widgetGroup1.setLayout(self.yolo_params_layout)


self.points_per_side_spinner = LabeledSpinner(
"Points per side", 4, 100, 32, None
)
Expand Down Expand Up @@ -392,30 +391,29 @@ def open_project(self):
image = project["image"]
self.load_project(image, results)



def load_project(self, image, results):
self.results = results
self.results = sorted(self.results, key=lambda x: x['area'], reverse=False)
self.results = sorted(
self.results, key=lambda x: x["area"], reverse=False
)
label_num = 1
for result in self.results:
result['keep'] = True
result['label_num'] = label_num
result["keep"] = True
result["label_num"] = label_num
label_num += 1


self.image = image
add_properties_to_label_image(self.image, self.results)
self.viewer.add_image(image)

self._3D_labels_layer.data = make_label_image_3d(self.results)
self.viewer.dims.ndisplay = 3
self._3D_labels_layer.translate = (-len(self.results), 0, 0)

self.add_points()
self.add_boxes()
self.update_slider_min_max()

def save_project(self):
options = QFileDialog.Options()
file_name, _ = QFileDialog.getSaveFileName(
Expand Down Expand Up @@ -485,7 +483,7 @@ def process(self):
bounding_boxes = get_bounding_boxes(
self.image,
detector_model="Finetuned",
device="cuda",
device=get_device(),
conf=bbox_conf,
iou=bbox_iou,
)
Expand All @@ -511,11 +509,10 @@ def process(self):
)
self.textBrowser_log.repaint()
QApplication.processEvents()

bounding_boxes = get_bounding_boxes(
self.image,
detector_model="YOLOv8",
device="cuda",
device=get_device(),
conf=bbox_conf,
iou=bbox_iou,
imgsz=bbox_imgsz,
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ isolated_build=true

[gh-actions]
python =
3.8: py38
# 3.8: py38
3.9: py39
3.10: py310

Expand Down
Loading