diff --git a/pyproject.toml b/pyproject.toml
index e4998de0..609b3bb9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -129,3 +129,9 @@ testpaths = ["tests"]
 python_files = "test_*.py"
 python_classes = "Test*"
 python_functions = "test_*"
+filterwarnings = [
+    "ignore::UserWarning:torch.amp.autocast_mode",
+    "ignore::UserWarning:torch.functional",
+    "ignore:.*CUDA is not available.*:UserWarning",
+    "ignore:.*torch.meshgrid.*indexing argument.*:UserWarning",
+]
diff --git a/sam3/model/sam3_image_processor.py b/sam3/model/sam3_image_processor.py
index 4d98fbfb..57630881 100644
--- a/sam3/model/sam3_image_processor.py
+++ b/sam3/model/sam3_image_processor.py
@@ -8,13 +8,23 @@
 from sam3.model import box_ops
 from sam3.model.data_misc import FindStage, interpolate
+from sam3.perflib.masks_ops import mask_iou
 from torchvision.transforms import v2
 
 
 class Sam3Processor:
     """ """
 
-    def __init__(self, model, resolution=1008, device="cuda", confidence_threshold=0.5):
+    def __init__(self, model, resolution=1008, device="cuda", confidence_threshold=0.5, fuse_detections_iou_threshold=None):
+        """
+        Args:
+            model: The SAM3 model
+            resolution: Image resolution for processing
+            device: Device to run on ('cuda' or 'cpu')
+            confidence_threshold: Minimum score to keep a detection (default: 0.5)
+            fuse_detections_iou_threshold: IoU threshold for fusing overlapping detections.
+                If None (default), fusion is disabled. Set to a value (e.g., 0.3) to enable fusion.
+        """
         self.model = model
         self.resolution = resolution
         self.device = device
@@ -27,6 +37,7 @@ def __init__(self, model, resolution=1008, device="cuda", confidence_threshold=0
             ]
         )
         self.confidence_threshold = confidence_threshold
+        self.fuse_detections_iou_threshold = fuse_detections_iou_threshold
 
         self.find_stage = FindStage(
             img_ids=torch.tensor([0], device=device, dtype=torch.long),
@@ -215,8 +226,126 @@ def _forward_grounding(self, state: Dict):
             align_corners=False,
         ).sigmoid()
 
+        # Apply detection fusion if enabled (merges overlapping detections)
+        if self.fuse_detections_iou_threshold is not None and len(out_probs) > 0:
+            out_probs, out_masks, boxes = self._fuse_detections(
+                out_probs, out_masks, boxes, self.fuse_detections_iou_threshold
+            )
+
         state["masks_logits"] = out_masks
         state["masks"] = out_masks > 0.5
         state["boxes"] = boxes
         state["scores"] = out_probs
         return state
+
+    def _fuse_detections(
+        self,
+        scores: torch.Tensor,
+        masks: torch.Tensor,
+        boxes: torch.Tensor,
+        iou_threshold: float,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Fuse overlapping detections by grouping them and merging masks.
+
+        Args:
+            scores: (N,) tensor of detection scores
+            masks: (N, 1, H, W) tensor of mask logits (before thresholding)
+            boxes: (N, 4) tensor of bounding boxes in [x0, y0, x1, y1] format
+            iou_threshold: IoU threshold for grouping detections (detections with IoU > threshold are fused)
+
+        Returns:
+            Fused scores, masks, and boxes tensors
+        """
+        if len(scores) == 0:
+            return scores, masks, boxes
+
+        # Convert masks to binary for IoU computation
+        masks_binary = (masks.squeeze(1) > 0.5)  # (N, H, W)
+
+        # Compute pairwise IoU matrix
+        ious = mask_iou(masks_binary, masks_binary)  # (N, N)
+
+        # Find connected components based on IoU threshold
+        # Use Union-Find to group overlapping detections
+        parent = list(range(len(scores)))
+
+        def find(x):
+            if parent[x] != x:
+                parent[x] = find(parent[x])
+            return parent[x]
+
+        def union(x, y):
+            px, py = find(x), find(y)
+            if px != py:
+                # Merge into the group with higher score
+                if scores[px] < scores[py]:
+                    px, py = py, px
+                parent[py] = px
+
+        # Group detections that overlap above threshold
+        for i in range(len(scores)):
+            for j in range(i + 1, len(scores)):
+                if ious[i, j] > iou_threshold:
+                    union(i, j)
+
+        # Find unique groups
+        groups = {}
+        for i in range(len(scores)):
+            root = find(i)
+            if root not in groups:
+                groups[root] = []
+            groups[root].append(i)
+
+        # Merge each group
+        fused_scores = []
+        fused_masks = []
+        fused_boxes = []
+
+        for group_indices in groups.values():
+            if len(group_indices) == 1:
+                # Single detection, keep as is - ensure shape is (1, 1, H, W) to match fused masks
+                fused_scores.append(scores[group_indices[0]])
+                single_mask = masks[group_indices[0]]  # (1, H, W) from (N, 1, H, W)
+                if single_mask.dim() == 3:
+                    single_mask = single_mask.unsqueeze(0)  # (1, 1, H, W)
+                fused_masks.append(single_mask)
+                fused_boxes.append(boxes[group_indices[0]])
+            else:
+                # Multiple detections to fuse
+                group_masks_binary = masks_binary[group_indices]  # (K, H, W)
+                group_scores = scores[group_indices]
+
+                # Merge masks: union of all masks in the group
+                merged_mask_binary = group_masks_binary.any(dim=0)  # (H, W)
+
+                # Use the mask logits (before thresholding) and take max for merged regions
+                group_masks_logits = masks[group_indices].squeeze(1)  # (K, H, W)
+                merged_mask_logits = group_masks_logits.max(dim=0)[0]  # (H, W)
+                # Set merged regions to high confidence
+                merged_mask_logits = torch.where(
+                    merged_mask_binary,
+                    torch.clamp(merged_mask_logits, min=0.5),
+                    merged_mask_logits
+                )
+                merged_mask_logits = merged_mask_logits.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
+
+                # Compute bounding box from merged mask
+                merged_mask_for_box = merged_mask_binary.unsqueeze(0).float()  # (1, H, W)
+                merged_box = box_ops.masks_to_boxes(merged_mask_for_box)[0]  # (4,)
+
+                # Use max score from the group
+                max_score = group_scores.max()
+
+                fused_scores.append(max_score)
+                fused_masks.append(merged_mask_logits)
+                fused_boxes.append(merged_box)
+
+        if len(fused_scores) == 0:
+            return scores, masks, boxes
+
+        fused_scores = torch.stack(fused_scores)
+        fused_masks = torch.cat(fused_masks, dim=0)
+        fused_boxes = torch.stack(fused_boxes)
+
+        return fused_scores, fused_masks, fused_boxes
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..46d37d2a
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
diff --git a/tests/test_sam3_image_processor.py b/tests/test_sam3_image_processor.py
new file mode 100644
index 00000000..29559fbf
--- /dev/null
+++ b/tests/test_sam3_image_processor.py
@@ -0,0 +1,260 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+"""
+Tests for detection fusion functionality in Sam3Processor.
+
+Note: These tests must be run in an environment with the sam3 package dependencies
+installed (torch, torchvision, etc.). Run with:
+    conda activate sam3
+    pytest tests/test_sam3_image_processor.py
+"""
+
+import pytest
+
+# Check for required dependencies
+try:
+    import torch
+except ImportError:
+    pytest.skip("torch not available", allow_module_level=True)
+
+try:
+    from sam3.model.sam3_image_processor import Sam3Processor
+    from sam3.perflib.masks_ops import mask_iou
+except ImportError as e:
+    pytest.skip(f"sam3 package not available: {e}", allow_module_level=True)
+
+
+class TestSam3ProcessorFusion:
+    """Tests for detection fusion functionality in Sam3Processor"""
+
+    def test_fusion_disabled_by_default(self):
+        """Test that fusion is disabled when fuse_detections_iou_threshold is None"""
+        # Create a mock model (we'll test the fusion logic directly)
+        class MockModel:
+            pass
+
+        processor = Sam3Processor(MockModel(), device="cpu")
+        assert processor.fuse_detections_iou_threshold is None
+
+    def test_fusion_enabled_when_threshold_set(self):
+        """Test that fusion is enabled when fuse_detections_iou_threshold is set"""
+        class MockModel:
+            pass
+
+        processor = Sam3Processor(
+            MockModel(), device="cpu", fuse_detections_iou_threshold=0.3
+        )
+        assert processor.fuse_detections_iou_threshold == 0.3
+
+    def test_fuse_detections_empty_input(self):
+        """Test fusion with empty detections"""
+        class MockModel:
+            pass
+
+        processor = Sam3Processor(MockModel(), device="cpu")
+        scores = torch.tensor([], dtype=torch.float32)
+        masks = torch.zeros((0, 1, 10, 10))
+        boxes = torch.zeros((0, 4))
+
+        fused_scores, fused_masks, fused_boxes = processor._fuse_detections(
+            scores, masks, boxes, iou_threshold=0.3
+        )
+
+        assert len(fused_scores) == 0
+        assert fused_masks.shape[0] == 0
+        assert len(fused_boxes) == 0
+
+    def test_fuse_detections_single_detection(self):
+        """Test fusion with a single detection (should remain unchanged)"""
+        class MockModel:
+            pass
+
+        processor = Sam3Processor(MockModel(), device="cpu")
+        scores = torch.tensor([0.8], dtype=torch.float32)
+        masks = torch.zeros((1, 1, 10, 10))
+        masks[0, 0, 2:5, 2:5] = 1.0  # Small square mask
+        boxes = torch.tensor([[2.0, 2.0, 5.0, 5.0]], dtype=torch.float32)
+
+        fused_scores, fused_masks, fused_boxes = processor._fuse_detections(
+            scores, masks, boxes, iou_threshold=0.3
+        )
+
+        assert len(fused_scores) == 1
+        assert fused_scores[0] == scores[0]
+        assert fused_masks.shape == (1, 1, 10, 10)
+        torch.testing.assert_close(fused_masks, masks)
+        assert len(fused_boxes) == 1
+
+    def test_fuse_detections_non_overlapping(self):
+        """Test that non-overlapping detections are not fused"""
+        class MockModel:
+            pass
+
+        processor = Sam3Processor(MockModel(), device="cpu")
+        scores = torch.tensor([0.8, 0.7], dtype=torch.float32)
+        masks = torch.zeros((2, 1, 20, 20))
+        # Two non-overlapping masks
+        masks[0, 0, 2:5, 2:5] = 1.0  # Top-left
+        masks[1, 0, 15:18, 15:18] = 1.0  # Bottom-right
+        boxes = torch.tensor(
+            [[2.0, 2.0, 5.0, 5.0], [15.0, 15.0, 18.0, 18.0]], dtype=torch.float32
+        )
+
+        fused_scores, fused_masks, fused_boxes = processor._fuse_detections(
+            scores, masks, boxes, iou_threshold=0.3
+        )
+
+        # Should still have 2 detections (not fused)
+        assert len(fused_scores) == 2
+        assert fused_masks.shape[0] == 2
+        assert len(fused_boxes) == 2
+
+    def test_fuse_detections_overlapping(self):
+        """Test that overlapping detections are fused"""
+        class MockModel:
+            pass
+
+        processor = Sam3Processor(MockModel(), device="cpu")
+        scores = torch.tensor([0.8, 0.7], dtype=torch.float32)
+        masks = torch.zeros((2, 1, 20, 20))
+        # Two overlapping masks with high overlap (same center, different sizes)
+        masks[0, 0, 5:10, 5:10] = 1.0  # 5x5 square
+        masks[1, 0, 6:11, 6:11] = 1.0  # 5x5 square, overlaps by 4x4 = 16 pixels
+        boxes = torch.tensor(
+            [[5.0, 5.0, 10.0, 10.0], [6.0, 6.0, 11.0, 11.0]], dtype=torch.float32
+        )
+
+        fused_scores, fused_masks, fused_boxes = processor._fuse_detections(
+            scores, masks, boxes, iou_threshold=0.3
+        )
+
+        # Should be fused into 1 detection
+        assert len(fused_scores) == 1
+        assert fused_masks.shape[0] == 1
+        assert len(fused_boxes) == 1
+        # Score should be max of the two
+        assert fused_scores[0] == 0.8
+
+    def test_fuse_detections_mask_union(self):
+        """Test that fused masks are the union of overlapping masks"""
+        class MockModel:
+            pass
+
+        processor = Sam3Processor(MockModel(), device="cpu")
+        scores = torch.tensor([0.8, 0.7], dtype=torch.float32)
+        masks = torch.zeros((2, 1, 20, 20))
+        # Two overlapping masks with sufficient overlap
+        masks[0, 0, 5:10, 5:10] = 1.0  # Left square
+        masks[1, 0, 6:11, 6:11] = 1.0  # Right square, overlaps significantly
+        boxes = torch.tensor(
+            [[5.0, 5.0, 10.0, 10.0], [6.0, 6.0, 11.0, 11.0]], dtype=torch.float32
+        )
+
+        fused_scores, fused_masks, fused_boxes = processor._fuse_detections(
+            scores, masks, boxes, iou_threshold=0.3
+        )
+
+        # Should be fused
+        assert len(fused_scores) == 1
+
+        # Check that fused mask contains union of both masks
+        fused_mask_binary = fused_masks[0, 0] > 0.5
+        original_mask1_binary = masks[0, 0] > 0.5
+        original_mask2_binary = masks[1, 0] > 0.5
+        union_binary = original_mask1_binary | original_mask2_binary
+
+        # Fused mask should be at least as large as union
+        assert (fused_mask_binary >= union_binary).all()
+
+    def test_fuse_detections_multiple_groups(self):
+        """Test fusion with multiple groups of overlapping detections"""
+        class MockModel:
+            pass
+
+        processor = Sam3Processor(MockModel(), device="cpu")
+        scores = torch.tensor([0.9, 0.8, 0.7, 0.6], dtype=torch.float32)
+        masks = torch.zeros((4, 1, 30, 30))
+        # Group 1: masks 0 and 1 overlap significantly
+        masks[0, 0, 5:10, 5:10] = 1.0
+        masks[1, 0, 6:11, 6:11] = 1.0  # High overlap
+        # Group 2: masks 2 and 3 overlap significantly (separate location)
+        masks[2, 0, 20:25, 20:25] = 1.0
+        masks[3, 0, 21:26, 21:26] = 1.0  # High overlap
+        boxes = torch.tensor(
+            [
+                [5.0, 5.0, 10.0, 10.0],
+                [6.0, 6.0, 11.0, 11.0],
+                [20.0, 20.0, 25.0, 25.0],
+                [21.0, 21.0, 26.0, 26.0],
+            ],
+            dtype=torch.float32,
+        )
+
+        fused_scores, fused_masks, fused_boxes = processor._fuse_detections(
+            scores, masks, boxes, iou_threshold=0.3
+        )
+
+        # Should have 2 fused detections (one per group)
+        assert len(fused_scores) == 2
+        assert fused_masks.shape[0] == 2
+        assert len(fused_boxes) == 2
+        # Scores should be max of each group
+        assert fused_scores[0] == 0.9  # max(0.9, 0.8)
+        assert fused_scores[1] == 0.7  # max(0.7, 0.6)
+
+    def test_fuse_detections_iou_threshold(self):
+        """Test that IoU threshold correctly controls fusion"""
+        class MockModel:
+            pass
+
+        processor = Sam3Processor(MockModel(), device="cpu")
+        scores = torch.tensor([0.8, 0.7], dtype=torch.float32)
+        masks = torch.zeros((2, 1, 20, 20))
+        # Two masks with low overlap
+        masks[0, 0, 5:10, 5:10] = 1.0
+        masks[1, 0, 9:14, 9:14] = 1.0  # Small overlap
+
+        # Compute actual IoU
+        mask1_binary = masks[0, 0] > 0.5
+        mask2_binary = masks[1, 0] > 0.5
+        ious = mask_iou(mask1_binary.unsqueeze(0), mask2_binary.unsqueeze(0))
+        actual_iou = ious[0, 0].item()
+
+        boxes = torch.tensor(
+            [[5.0, 5.0, 10.0, 10.0], [9.0, 9.0, 14.0, 14.0]], dtype=torch.float32
+        )
+
+        # With threshold below actual IoU - should fuse
+        fused_scores_low, _, _ = processor._fuse_detections(
+            scores, masks, boxes, iou_threshold=actual_iou - 0.1
+        )
+        assert len(fused_scores_low) == 1
+
+        # With threshold above actual IoU - should not fuse
+        fused_scores_high, _, _ = processor._fuse_detections(
+            scores, masks, boxes, iou_threshold=actual_iou + 0.1
+        )
+        assert len(fused_scores_high) == 2
+
+    def test_fuse_detections_score_ordering(self):
+        """Test that fusion preserves the highest score from each group"""
+        class MockModel:
+            pass
+
+        processor = Sam3Processor(MockModel(), device="cpu")
+        # Lower score first to test that max is used
+        scores = torch.tensor([0.6, 0.9], dtype=torch.float32)
+        masks = torch.zeros((2, 1, 20, 20))
+        masks[0, 0, 5:10, 5:10] = 1.0
+        masks[1, 0, 6:11, 6:11] = 1.0  # High overlap
+        boxes = torch.tensor(
+            [[5.0, 5.0, 10.0, 10.0], [6.0, 6.0, 11.0, 11.0]], dtype=torch.float32
+        )
+
+        fused_scores, _, _ = processor._fuse_detections(
+            scores, masks, boxes, iou_threshold=0.3
+        )
+
+        # Should use max score (0.9) even though it was second
+        assert len(fused_scores) == 1
+        assert fused_scores[0] == 0.9
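
Usage sketch (not part of the diff): a minimal example that mirrors the unit tests above, building a Sam3Processor on CPU with fusion enabled and calling _fuse_detections directly on two heavily overlapping synthetic detections. The MockModel stand-in and the hand-made tensors are illustrative only; in normal use a real SAM3 model is passed and the fusion step runs automatically inside _forward_grounding whenever fuse_detections_iou_threshold is set.

import torch

from sam3.model.sam3_image_processor import Sam3Processor


class MockModel:
    """Placeholder model; only the fusion helper is exercised here."""


# Passing an IoU threshold enables fusion (None, the default, disables it).
processor = Sam3Processor(MockModel(), device="cpu", fuse_detections_iou_threshold=0.3)

# Two overlapping synthetic detections on a 20x20 canvas (IoU is about 0.47).
scores = torch.tensor([0.8, 0.7])
masks = torch.zeros((2, 1, 20, 20))
masks[0, 0, 5:10, 5:10] = 1.0
masks[1, 0, 6:11, 6:11] = 1.0
boxes = torch.tensor([[5.0, 5.0, 10.0, 10.0], [6.0, 6.0, 11.0, 11.0]])

fused_scores, fused_masks, fused_boxes = processor._fuse_detections(
    scores, masks, boxes, iou_threshold=processor.fuse_detections_iou_threshold
)
print(fused_scores)  # tensor([0.8000]): the pair is merged and the max score is kept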