From 09cb8884641ebb896dd6073a6d69ccdc5f030f33 Mon Sep 17 00:00:00 2001 From: Ian Coccimiglio Date: Tue, 23 Apr 2024 05:58:46 -0700 Subject: [PATCH] Fixed tests for more exact argument passing --- .../_tests/test_mobile_sam.py | 13 ++++++++++--- src/napari_segment_everything/sam_helper.py | 15 +++++++++++++-- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/napari_segment_everything/_tests/test_mobile_sam.py b/src/napari_segment_everything/_tests/test_mobile_sam.py index ad811c4..bb6597f 100644 --- a/src/napari_segment_everything/_tests/test_mobile_sam.py +++ b/src/napari_segment_everything/_tests/test_mobile_sam.py @@ -10,7 +10,14 @@ def test_mobile_sam(): # load a color examp image = data.coffee() - bounding_boxes = get_bounding_boxes(image, imgsz=1024, device="cuda") + bounding_boxes = get_bounding_boxes( + image, + detector_model="YOLOv8", + imgsz=1024, + device="cuda", + conf=0.4, + iou=0.9, + ) segmentations = get_mobileSAMv2(image, bounding_boxes) assert len(segmentations) == 11 @@ -27,5 +34,5 @@ def test_bbox(): return segmentations -seg = test_bbox() -# seg, regions=test_mobile_sam() +# seg = test_bbox() +seg = test_mobile_sam() diff --git a/src/napari_segment_everything/sam_helper.py b/src/napari_segment_everything/sam_helper.py index b24399d..6efe185 100644 --- a/src/napari_segment_everything/sam_helper.py +++ b/src/napari_segment_everything/sam_helper.py @@ -153,17 +153,28 @@ def get_sam_automatic_mask_generator( return sam_anything_predictor -def get_bounding_boxes(image, detector_model, device="cpu", conf=0.5, iou=0.2): +def get_bounding_boxes( + image, + detector_model, + device="cpu", + conf=0.4, + iou=0.5, + imgsz=1024, + max_det=400, +): if detector_model == "YOLOv8": model = YoloDetector( str(get_weights_path("ObjectAwareModel")), device="cuda" ) + bounding_boxes = model.get_bounding_boxes( + image, conf=conf, iou=iou, imgsz=imgsz, max_det=max_det + ) elif detector_model == "Finetuned": model = RcnnDetector( str(get_weights_path("ObjectAwareModel_Cell_FT")), device="cuda", ) - bounding_boxes = model.get_bounding_boxes(image, conf=conf, iou=iou) + bounding_boxes = model.get_bounding_boxes(image, conf=conf, iou=iou) print(bounding_boxes) return bounding_boxes