diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml
index 0480091..e12b20a 100644
--- a/.github/workflows/test_and_deploy.yml
+++ b/.github/workflows/test_and_deploy.yml
@@ -23,7 +23,7 @@ jobs:
     strategy:
       matrix:
         platform: [ubuntu-latest, windows-latest, macos-latest]
-        python-version: ['3.8', '3.9', '3.10']
+        python-version: ['3.9', '3.10']

     steps:
       - uses: actions/checkout@v3
diff --git a/LICENSE b/LICENSE
index e6bb677..8946808 100644
--- a/LICENSE
+++ b/LICENSE
@@ -2,9 +2,12 @@
 Copyright (c) 2024, Brian Northan
 All rights reserved.

-This work re-uses some code from napari sam (see license here https://github.com/MIC-DKFZ/napari-sam/blob/main/LICENSE)
+This work re-uses some code from the following sources:
+- napari-sam (License: https://github.com/MIC-DKFZ/napari-sam/blob/main/LICENSE)
+- napari-segment-anything (License: https://github.com/royerlab/napari-segment-anything/blob/main/LICENSE)
+- MobileSAMv2 (License: https://github.com/ChaoningZhang/MobileSAM/blob/master/LICENSE)

-And also re-uses some code from napari-segment-anything (see license here https://github.com/royerlab/napari-segment-anything/blob/main/LICENSE)
+Additionally, training data is from Cellpose's annotated dataset, which is licensed under CC BY-NC (see: https://github.com/MouseLand/cellpose?tab=readme-ov-file)

 Redistribution and use in source and binary forms, with or without
 modification, are permitted provided that the following conditions are met:
diff --git a/src/napari_segment_everything/_tests/test_mobile_sam.py b/src/napari_segment_everything/_tests/test_mobile_sam.py
index bf87bb2..8a09143 100644
--- a/src/napari_segment_everything/_tests/test_mobile_sam.py
+++ b/src/napari_segment_everything/_tests/test_mobile_sam.py
@@ -17,18 +17,30 @@
 import requests

 from gdown.parse_url import parse_url

+device = get_device()
+if device == "mps":
+    device = "cpu"
+#%%
 def test_urls():
     """
-    Tests whether all the urls for the model weights exist.
+    Tests whether all the URLs for the model weights exist and are accessible.
""" - for url in SAM_WEIGHTS_URL.values(): - if url.startswith("https://drive.google.com/"): - _, path_exists = parse_url(url) - assert path_exists - else: - req = requests.head(url) - assert req.status_code == 200 + TIMEOUT = 1 + for name, url in SAM_WEIGHTS_URL.items(): + try: + if url.startswith("https://drive.google.com/"): + _, path_exists = parse_url(url) + assert ( + path_exists + ), f"Google Drive URL path wasn't parsed correctly: {url}" + else: + req = requests.head(url, timeout=TIMEOUT) + assert ( + req.status_code == 200 + ), f"Failed to access URL: {url}, Status code: {req.status_code}" + except requests.exceptions.Timeout: + print(f"Request timed out for URL: {url}") def test_mobile_sam(): @@ -42,7 +54,7 @@ def test_mobile_sam(): image, detector_model="YOLOv8", imgsz=1024, - device="cuda", + device=device, conf=0.4, iou=0.9, ) @@ -57,7 +69,7 @@ def test_bbox(): """ image = data.coffee() bounding_boxes = get_bounding_boxes( - image, detector_model="Finetuned", device="cuda", conf=0.01, iou=0.99 + image, detector_model="YOLOv8", device=device, conf=0.9, iou=0.90 ) print(f"Length of bounding boxes: {len(bounding_boxes)}") assert len(bounding_boxes) > 0 @@ -70,12 +82,9 @@ def test_RCNN(): image = data.coffee() model_path = str(get_weights_path("ObjectAwareModel_Cell_FT")) assert os.path.exists(model_path) - rcnn_cpu = RcnnDetector(model_path, device="cpu") - rcnn_cuda = RcnnDetector(model_path, device="cuda") - bbox_cpu = rcnn_cpu.get_bounding_boxes(image, conf=0.5, iou=0.2) - bbox_cuda = rcnn_cuda.get_bounding_boxes(image, conf=0.5, iou=0.2) - assert len(bbox_cpu) == 6 - assert len(bbox_cuda) == 6 + rcnn = RcnnDetector(model_path, device=device) + bbox = rcnn.get_bounding_boxes(image, conf=0.5, iou=0.2) + assert len(bbox) == 6 def test_YOLO(): @@ -85,16 +94,11 @@ def test_YOLO(): image = data.coffee() model_path = str(get_weights_path("ObjectAwareModel")) assert os.path.exists(model_path) - yolo_cpu = YoloDetector(model_path, device="cpu") - yolo_cuda = YoloDetector(model_path, device="cuda") - bbox_cpu = yolo_cpu.get_bounding_boxes( - image, conf=0.5, iou=0.2, max_det=400, imgsz=1024 - ) - bbox_cuda = yolo_cuda.get_bounding_boxes( + yolo = YoloDetector(model_path, device=device) + bbox = yolo.get_bounding_boxes( image, conf=0.5, iou=0.2, max_det=400, imgsz=1024 ) - assert len(bbox_cpu) == 8 - assert len(bbox_cuda) == 8 + assert len(bbox) == 8 def test_weights_path(): @@ -110,7 +114,6 @@ def test_labels(): Tests whether region properties can be generated for segmentations for different models """ image = data.coffee() - device = get_device() bbox_yolo = get_bounding_boxes( image, @@ -150,8 +153,7 @@ def test_labels(): assert len(props_yolo) == 10 props_vit_b = segmentations_vit_b[0].keys() assert len(props_vit_b) == 13 - - + test_urls() test_bbox() test_mobile_sam() diff --git a/src/napari_segment_everything/minimal_detection/detect_and_segment.py b/src/napari_segment_everything/minimal_detection/detect_and_segment.py index 5b39b53..533cbb2 100644 --- a/src/napari_segment_everything/minimal_detection/detect_and_segment.py +++ b/src/napari_segment_everything/minimal_detection/detect_and_segment.py @@ -35,7 +35,7 @@ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] -def segment_from_bbox(bounding_boxes, predictor, mobilesamv2): +def segment_from_bbox(bounding_boxes, predictor, mobilesamv2, device): """ Segments everything given the bounding boxes of the objects and the mobileSAMv2 
     prediction model. Code from mobileSAMv2
@@ -43,7 +43,7 @@ def segment_from_bbox(bounding_boxes, predictor, mobilesamv2):
     input_boxes = predictor.transform.apply_boxes(
         bounding_boxes, predictor.original_size
     ) # Does this need to be transformed?
-    input_boxes = torch.from_numpy(input_boxes).cuda()
+    input_boxes = torch.from_numpy(input_boxes).to(device)

     sam_mask = []
     predicted_ious = []
diff --git a/src/napari_segment_everything/minimal_detection/prompt_generator.py b/src/napari_segment_everything/minimal_detection/prompt_generator.py
index 278b2e4..e89da6b 100644
--- a/src/napari_segment_everything/minimal_detection/prompt_generator.py
+++ b/src/napari_segment_everything/minimal_detection/prompt_generator.py
@@ -109,11 +109,13 @@ class RcnnDetector(BaseDetector):
     def __init__(self, model_path, device, trainable=True):
         super().__init__(model_path, trainable)
         self.model_type = "FasterRCNN"
+        if device == "mps":
+            device = "cpu"
         self.device = device
         self.model = fasterrcnn_mobilenet_v3_large_fpn(
             box_detections_per_img=500,
-        ).to(device)
-        self.model.load_state_dict(torch.load(model_path))
+        ).to(self.device)
+        self.model.load_state_dict(torch.load(model_path, map_location=self.device))

     def train(self, training_data):
         if self.trainable:
diff --git a/src/napari_segment_everything/sam_helper.py b/src/napari_segment_everything/sam_helper.py
index 165e5a1..12be8b8 100644
--- a/src/napari_segment_everything/sam_helper.py
+++ b/src/napari_segment_everything/sam_helper.py
@@ -132,8 +132,11 @@ def get_sam_automatic_mask_generator(
     crop_n_layers=1,
 ):

+    device = get_device()
+    if device == "mps":
+        device = "cpu"
     sam = sam_model_registry[model_type](get_weights_path(model_type))
-    sam.to(get_device())
+    sam.to(device)
     sam_anything_predictor = SamAutomaticMaskGenerator(
         sam,
         points_per_side=int(points_per_side),
@@ -178,7 +181,7 @@
     return bounding_boxes


-def get_mobileSAMv2(image=None, bounding_boxes=None):
+def get_mobileSAMv2(image=None, bounding_boxes=None, device=None):
     """
     Uses a SAM model to make predictions from bounding boxes.
@@ -201,7 +204,7 @@ def get_mobileSAMv2(image=None, bounding_boxes=None):
         return
     if image.ndim < 3:
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = device or get_device()
     # device = "cpu"
     weights_path_VIT = get_weights_path("efficientvit_l2")
     samV2 = create_MS_model()
@@ -209,11 +212,15 @@
     samV2.image_encoder = sam_model_registry["efficientvit_l2"](
         weights_path_VIT
     )
+    if device == "mps":
+        device = "cpu"
     samV2.to(device=device)
     samV2.eval()
     predictor = SamPredictorV2(samV2)
     predictor.set_image(image)
-    sam_masks = segment_from_bbox(bounding_boxes, predictor, samV2)
+    sam_masks = segment_from_bbox(
+        bounding_boxes, predictor, samV2, device=device
+    )

     del bounding_boxes
     gc.collect()
@@ -304,7 +311,7 @@
         # for small pixelated objects, circularity can be > 1 so we cap it
         if result["circularity"] > 1:
             result["circularity"] = 1
-
+
         result["solidity"] = regions[0].solidity
         intensity_pixels = intensity[coords]
         result["mean_intensity"] = np.mean(intensity_pixels)
diff --git a/src/napari_segment_everything/segment_everything.py b/src/napari_segment_everything/segment_everything.py
index 6785b76..996ed67 100644
--- a/src/napari_segment_everything/segment_everything.py
+++ b/src/napari_segment_everything/segment_everything.py
@@ -17,6 +17,7 @@
 from napari_segment_everything.sam_helper import (
     get_bounding_boxes,
     get_mobileSAMv2,
+    get_device,
 )

 import pickle
@@ -35,7 +36,7 @@
     QMessageBox,
     QTextBrowser,
     QProgressBar,
-    QApplication
+    QApplication,
 )
@@ -161,7 +162,6 @@ def on_index_changed(index):
         self.stacked_algorithm_params_layout = QStackedWidget()

-
         self.bbox_conf_spinner = LabeledSpinner(
             "Bounding Box Confidence", 0, 1, 0.1, None, is_double=True
         )
@@ -186,7 +186,6 @@ def on_index_changed(index):
         self.yolo_params_layout.addWidget(self.bbbox_max_det_spinner)
         self.widgetGroup1.setLayout(self.yolo_params_layout)

-
         self.points_per_side_spinner = LabeledSpinner(
             "Points per side", 4, 100, 32, None
         )
@@ -392,30 +391,29 @@ def open_project(self):
         image = project["image"]
         self.load_project(image, results)

-
-
     def load_project(self, image, results):
         self.results = results
-        self.results = sorted(self.results, key=lambda x: x['area'], reverse=False)
+        self.results = sorted(
+            self.results, key=lambda x: x["area"], reverse=False
+        )
         label_num = 1
         for result in self.results:
-            result['keep'] = True
-            result['label_num'] = label_num
+            result["keep"] = True
+            result["label_num"] = label_num
             label_num += 1
-
         self.image = image
         add_properties_to_label_image(self.image, self.results)
         self.viewer.add_image(image)
-
+
         self._3D_labels_layer.data = make_label_image_3d(self.results)
         self.viewer.dims.ndisplay = 3
         self._3D_labels_layer.translate = (-len(self.results), 0, 0)
-
+
         self.add_points()
         self.add_boxes()
         self.update_slider_min_max()
-
+
     def save_project(self):
         options = QFileDialog.Options()
         file_name, _ = QFileDialog.getSaveFileName(
@@ -485,7 +483,7 @@ def process(self):
            bounding_boxes = get_bounding_boxes(
                self.image,
                detector_model="Finetuned",
-                device="cuda",
+                device=get_device(),
                conf=bbox_conf,
                iou=bbox_iou,
            )
@@ -511,11 +509,10 @@ def process(self):
            )
            self.textBrowser_log.repaint()
            QApplication.processEvents()
-
            bounding_boxes = get_bounding_boxes(
                self.image,
                detector_model="YOLOv8",
-                device="cuda",
+                device=get_device(),
                conf=bbox_conf,
                iou=bbox_iou,
                imgsz=bbox_imgsz,
diff --git a/tox.ini b/tox.ini
index 3c2e09a..cbcf971 100644
--- a/tox.ini
+++ b/tox.ini
@@ -5,7 +5,7 @@

 [gh-actions]
 python =
-    3.8: py38
+#    3.8: py38
     3.9: py39
     3.10: py310
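
Note on the device-selection pattern in this changeset: the same fallback (resolve a device once, then downgrade "mps" to "cpu" for the model paths that are not MPS-compatible) recurs in test_mobile_sam.py, prompt_generator.py, and sam_helper.py. Below is a minimal sketch of how that fallback could be centralized in one helper. `resolve_device` is a hypothetical name and is not part of this changeset, and `get_device` is assumed (its body is not shown in this diff) to prefer CUDA, then Apple MPS, then CPU:

import torch


def get_device() -> str:
    # Assumed behavior of sam_helper.get_device (not shown in this diff):
    # prefer CUDA, then Apple MPS, then fall back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


def resolve_device(requested=None, allow_mps=False) -> str:
    # Hypothetical helper: use the requested device (or auto-detect one),
    # and downgrade "mps" to "cpu" for code paths that cannot run on MPS.
    device = requested or get_device()
    if device == "mps" and not allow_mps:
        device = "cpu"
    return device

With such a helper, each `if device == "mps": device = "cpu"` block above would reduce to `device = resolve_device()`, keeping the fallback policy in a single place.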