Skip to content

Commit

Permalink
Fixed tests for more exact argument passing
Browse files Browse the repository at this point in the history
  • Loading branch information
ian-coccimiglio committed Apr 23, 2024
1 parent 56f15b0 commit 09cb888
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
13 changes: 10 additions & 3 deletions src/napari_segment_everything/_tests/test_mobile_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@ def test_mobile_sam():
# load a color examp
image = data.coffee()

bounding_boxes = get_bounding_boxes(image, imgsz=1024, device="cuda")
bounding_boxes = get_bounding_boxes(
image,
detector_model="YOLOv8",
imgsz=1024,
device="cuda",
conf=0.4,
iou=0.9,
)
segmentations = get_mobileSAMv2(image, bounding_boxes)

assert len(segmentations) == 11
Expand All @@ -27,5 +34,5 @@ def test_bbox():
return segmentations


seg = test_bbox()
# seg, regions=test_mobile_sam()
# seg = test_bbox()
seg = test_mobile_sam()
15 changes: 13 additions & 2 deletions src/napari_segment_everything/sam_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,17 +153,28 @@ def get_sam_automatic_mask_generator(
return sam_anything_predictor


def get_bounding_boxes(image, detector_model, device="cpu", conf=0.5, iou=0.2):
def get_bounding_boxes(
image,
detector_model,
device="cpu",
conf=0.4,
iou=0.5,
imgsz=1024,
max_det=400,
):
if detector_model == "YOLOv8":
model = YoloDetector(
str(get_weights_path("ObjectAwareModel")), device="cuda"
)
bounding_boxes = model.get_bounding_boxes(
image, conf=conf, iou=iou, imgsz=imgsz, max_det=max_det
)
elif detector_model == "Finetuned":
model = RcnnDetector(
str(get_weights_path("ObjectAwareModel_Cell_FT")),
device="cuda",
)
bounding_boxes = model.get_bounding_boxes(image, conf=conf, iou=iou)
bounding_boxes = model.get_bounding_boxes(image, conf=conf, iou=iou)
print(bounding_boxes)
return bounding_boxes

Expand Down

0 comments on commit 09cb888

Please sign in to comment.