6 changes: 6 additions & 0 deletions pyproject.toml
@@ -129,3 +129,9 @@ testpaths = ["tests"]
python_files = "test_*.py"
python_classes = "Test*"
python_functions = "test_*"
filterwarnings = [
"ignore::UserWarning:torch.amp.autocast_mode",
"ignore::UserWarning:torch.functional",
"ignore:.*CUDA is not available.*:UserWarning",
"ignore:.*torch.meshgrid.*indexing argument.*:UserWarning",
]
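For context, these entries use pytest's -W-style filter syntax (action:message-regex:category:module). A rough stdlib equivalent of two of the filters, purely for illustration and not part of this change:

import warnings

# "ignore::UserWarning:torch.amp.autocast_mode" -- ignore UserWarnings raised
# from the torch.amp.autocast_mode module (empty message field matches all).
warnings.filterwarnings("ignore", category=UserWarning, module="torch.amp.autocast_mode")

# "ignore:.*CUDA is not available.*:UserWarning" -- ignore UserWarnings whose
# message matches the regex, regardless of the emitting module.
warnings.filterwarnings(
    "ignore", message=".*CUDA is not available.*", category=UserWarning
)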
131 changes: 130 additions & 1 deletion sam3/model/sam3_image_processor.py
@@ -8,13 +8,23 @@
from sam3.model import box_ops

from sam3.model.data_misc import FindStage, interpolate
from sam3.perflib.masks_ops import mask_iou
from torchvision.transforms import v2


class Sam3Processor:
""" """

def __init__(self, model, resolution=1008, device="cuda", confidence_threshold=0.5):
def __init__(self, model, resolution=1008, device="cuda", confidence_threshold=0.5, fuse_detections_iou_threshold=None):
"""
Args:
model: The SAM3 model
resolution: Image resolution for processing
device: Device to run on ('cuda' or 'cpu')
confidence_threshold: Minimum score to keep a detection (default: 0.5)
fuse_detections_iou_threshold: IoU threshold for fusing overlapping detections.
If None (default), fusion is disabled. Set to a value (e.g., 0.3) to enable fusion.
"""
self.model = model
self.resolution = resolution
self.device = device
@@ -27,6 +37,7 @@ def __init__(self, model, resolution=1008, device="cuda", confidence_threshold=0
]
)
self.confidence_threshold = confidence_threshold
self.fuse_detections_iou_threshold = fuse_detections_iou_threshold

self.find_stage = FindStage(
img_ids=torch.tensor([0], device=device, dtype=torch.long),
@@ -215,8 +226,126 @@ def _forward_grounding(self, state: Dict):
align_corners=False,
).sigmoid()

# Apply detection fusion if enabled (merges overlapping detections)
if self.fuse_detections_iou_threshold is not None and len(out_probs) > 0:
out_probs, out_masks, boxes = self._fuse_detections(
out_probs, out_masks, boxes, self.fuse_detections_iou_threshold
)

state["masks_logits"] = out_masks
state["masks"] = out_masks > 0.5
state["boxes"] = boxes
state["scores"] = out_probs
return state

def _fuse_detections(
self,
scores: torch.Tensor,
masks: torch.Tensor,
boxes: torch.Tensor,
iou_threshold: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Fuse overlapping detections by grouping them and merging masks.

Args:
scores: (N,) tensor of detection scores
masks: (N, 1, H, W) tensor of mask scores after sigmoid (before thresholding at 0.5)
boxes: (N, 4) tensor of bounding boxes in [x0, y0, x1, y1] format
iou_threshold: IoU threshold for grouping detections (detections with IoU > threshold are fused)

Returns:
Fused scores, masks, and boxes tensors
"""
if len(scores) == 0:
return scores, masks, boxes

# Convert masks to binary for IoU computation
masks_binary = (masks.squeeze(1) > 0.5) # (N, H, W)

# Compute pairwise IoU matrix
ious = mask_iou(masks_binary, masks_binary) # (N, N)

# Find connected components based on IoU threshold
# Use Union-Find to group overlapping detections
parent = list(range(len(scores)))

def find(x):
if parent[x] != x:
parent[x] = find(parent[x])
return parent[x]

def union(x, y):
px, py = find(x), find(y)
if px != py:
# Merge into the group with higher score
if scores[px] < scores[py]:
px, py = py, px
parent[py] = px

# Group detections that overlap above threshold
for i in range(len(scores)):
for j in range(i + 1, len(scores)):
if ious[i, j] > iou_threshold:
union(i, j)

# Find unique groups
groups = {}
for i in range(len(scores)):
root = find(i)
if root not in groups:
groups[root] = []
groups[root].append(i)

# Merge each group
fused_scores = []
fused_masks = []
fused_boxes = []

for group_indices in groups.values():
if len(group_indices) == 1:
# Single detection, keep as is - ensure shape is (1, 1, H, W) to match fused masks
fused_scores.append(scores[group_indices[0]])
single_mask = masks[group_indices[0]] # (1, H, W) from (N, 1, H, W)
if single_mask.dim() == 3:
single_mask = single_mask.unsqueeze(0) # (1, 1, H, W)
fused_masks.append(single_mask)
fused_boxes.append(boxes[group_indices[0]])
else:
# Multiple detections to fuse
group_masks_binary = masks_binary[group_indices] # (K, H, W)
group_scores = scores[group_indices]

# Merge masks: union of all masks in the group
merged_mask_binary = group_masks_binary.any(dim=0) # (H, W)

# Merge the pre-threshold mask scores by taking the per-pixel max across the group
group_masks_logits = masks[group_indices].squeeze(1) # (K, H, W)
merged_mask_logits = group_masks_logits.max(dim=0)[0] # (H, W)
# Keep pixels inside the merged (union) mask at or above the 0.5 threshold
merged_mask_logits = torch.where(
merged_mask_binary,
torch.clamp(merged_mask_logits, min=0.5),
merged_mask_logits
)
merged_mask_logits = merged_mask_logits.unsqueeze(0).unsqueeze(0) # (1, 1, H, W)

# Compute bounding box from merged mask
merged_mask_for_box = merged_mask_binary.unsqueeze(0).float() # (1, H, W)
merged_box = box_ops.masks_to_boxes(merged_mask_for_box)[0] # (4,)

# Use max score from the group
max_score = group_scores.max()

fused_scores.append(max_score)
fused_masks.append(merged_mask_logits)
fused_boxes.append(merged_box)

if len(fused_scores) == 0:
return scores, masks, boxes

fused_scores = torch.stack(fused_scores)
fused_masks = torch.cat(fused_masks, dim=0)
fused_boxes = torch.stack(fused_boxes)

return fused_scores, fused_masks, fused_boxes
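A minimal usage sketch of the new option, assuming only the constructor signature shown in this diff (model loading and the call path that reaches _forward_grounding are outside this change):

from sam3.model.sam3_image_processor import Sam3Processor

model = ...  # a loaded SAM3 model (loading code is not part of this diff)

# Default: fusion disabled, overlapping detections are returned separately.
processor = Sam3Processor(model, confidence_threshold=0.5)

# Fusion enabled: detections whose mask IoU exceeds 0.3 are merged into one
# detection (union of masks, max score, box recomputed from the merged mask).
fusing_processor = Sam3Processor(
    model,
    confidence_threshold=0.5,
    fuse_detections_iou_threshold=0.3,
)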
1 change: 1 addition & 0 deletions tests/__init__.py
@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
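To make the fusion step added in sam3_image_processor.py concrete, here is a small self-contained sketch of the same grouping-and-merging idea on toy masks; the IoU helper below is reimplemented inline for illustration and is not sam3.perflib.masks_ops.mask_iou:

import torch

def pairwise_mask_iou(m):
    # m: (N, H, W) bool masks; returns an (N, N) IoU matrix.
    flat = m.flatten(1).float()
    inter = flat @ flat.T
    areas = flat.sum(dim=1)
    union = areas[:, None] + areas[None, :] - inter
    return inter / union.clamp(min=1)

masks = torch.zeros(3, 8, 8, dtype=torch.bool)
masks[0, 1:5, 1:5] = True   # detection A
masks[1, 2:6, 2:6] = True   # detection B, overlaps A (IoU ~ 0.39)
masks[2, 6:8, 6:8] = True   # detection C, disjoint from A and B
scores = torch.tensor([0.9, 0.7, 0.8])

ious = pairwise_mask_iou(masks)
threshold = 0.3
# With this threshold, A and B fall into one group and C stays alone --
# the same grouping the union-find in _fuse_detections would produce here.
groups = [[0, 1], [2]]
for g in groups:
    merged = masks[g].any(dim=0)   # union of the group's masks
    score = scores[g].max()        # keep the group's best score
    print(g, int(merged.sum()), float(score))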