diff --git a/README.md b/README.md
index 669242df..d3493320 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# SAM 3: Segment Anything with Concepts
+# SAM 3: Segment Anything with Concepts (with Apple Metal/MPS and CPU support)
Meta Superintelligence Labs
diff --git a/examples/live_camera_segmentation.py b/examples/live_camera_segmentation.py
new file mode 100644
index 00000000..283149fc
--- /dev/null
+++ b/examples/live_camera_segmentation.py
@@ -0,0 +1,970 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+"""
+Live Camera Segmentation with SAM3
+
+This script captures video from a device camera and runs real-time segmentation
+using SAM3. It supports text-based detection or interactive point/box prompts.
+
+Usage:
+ # Detect objects using text prompt
+ python live_camera_segmentation.py --prompt "person"
+
+ # Detect multiple object types using comma-separated prompts
+ python live_camera_segmentation.py --prompt "person, car, dog, cat"
+
+ # Use specific camera device
+ python live_camera_segmentation.py --camera 0 --prompt "cat"
+
+ # Specify device (cuda, mps, or cpu)
+ python live_camera_segmentation.py --device mps --prompt "dog"
+
+ # Interactive mode - click to add box prompts
+ python live_camera_segmentation.py --interactive
+
+ # Skip frames with tracking (masks follow objects between full inference frames)
+ python live_camera_segmentation.py --prompt "person" --skip-frames 5 --track
+
+Controls:
+ - 'q' or ESC: Quit
+ - 'r': Reset/clear all segments
+ - 's': Save current frame
+ - 'p': Pause/resume
+ - Left click + drag: Draw box prompt (in interactive mode)
+ - 't': Enter new text prompt
+"""
+
+import argparse
+import time
+from collections import deque
+from typing import Optional, Tuple
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+
+from sam3.utils.device import get_device, get_device_str, setup_device_optimizations
+
+
+class LiveCameraSegmenter:
+ """Real-time camera segmentation using SAM3."""
+
+ # Color palette for different object masks (BGR format for OpenCV)
+ COLORS = [
+ (255, 0, 0), # Blue
+ (0, 255, 0), # Green
+ (0, 0, 255), # Red
+ (255, 255, 0), # Cyan
+ (255, 0, 255), # Magenta
+ (0, 255, 255), # Yellow
+ (128, 0, 255), # Purple
+ (255, 128, 0), # Orange
+ (0, 128, 255), # Light blue
+ (128, 255, 0), # Lime
+ ]
+
+ def __init__(
+ self,
+ camera_id: int = 0,
+ device: Optional[str] = None,
+ text_prompt: str = "object",
+ confidence_threshold: float = 0.3,
+ checkpoint_path: Optional[str] = None,
+ interactive: bool = False,
+ process_every_n_frames: int = 1,
+ use_half_precision: bool = False,
+ enable_tracking: bool = False,
+ ):
+ """
+ Initialize the live camera segmenter.
+
+ Args:
+ camera_id: Camera device ID (default 0 for primary camera)
+ device: Device to run on ('cuda', 'mps', 'cpu', or None for auto)
+ text_prompt: Text description of objects to detect
+ confidence_threshold: Confidence threshold for detections
+ checkpoint_path: Optional path to model checkpoint
+ interactive: Enable interactive box-based prompting
+ process_every_n_frames: Only process every N frames (higher = faster but less smooth)
+ use_half_precision: Use float16 for faster inference (may reduce accuracy)
+ enable_tracking: Enable mask tracking between skipped frames
+ """
+ self.camera_id = camera_id
+ self.device_str = device if device else get_device_str()
+ self.device = torch.device(self.device_str)
+ self.text_prompt = text_prompt
+ self.confidence_threshold = confidence_threshold
+ self.interactive = interactive
+ self.process_every_n_frames = process_every_n_frames
+ self.use_half_precision = use_half_precision
+ self.enable_tracking = enable_tracking
+ self.frame_count = 0
+
+ # State
+ self.paused = False
+ self.state = None
+ self.fps_history = deque(maxlen=30)
+
+ # Tracking state
+ self.tracker = None
+ self.tracker_state = None
+ self.last_masks = None
+ self.last_boxes = None
+ self.last_scores = None # Store confidence scores
+ self.last_labels = None # Store per-object labels for multi-prompt mode
+ self.video_height = None
+ self.video_width = None
+
+ # For interactive box drawing
+ self.drawing = False
+ self.box_start = None
+ self.box_end = None
+
+ print(f"Initializing SAM3 on device: {self.device}")
+ self._load_model(checkpoint_path)
+
+ def _load_model(self, checkpoint_path: Optional[str] = None):
+ """Load the SAM3 model and processor."""
+ from sam3.model_builder import build_sam3_image_model
+ from sam3.model.sam3_image_processor import Sam3Processor
+
+ # Setup device-specific optimizations (MPS memory, CUDA TF32, etc.)
+ setup_device_optimizations()
+
+ print("Loading SAM3 model...")
+ model = build_sam3_image_model(
+ device=self.device_str,
+ checkpoint_path=checkpoint_path,
+ load_from_HF=checkpoint_path is None,
+ eval_mode=True,
+ enable_segmentation=True,
+ )
+
+ # Convert to half precision for faster inference (CUDA only - MPS doesn't support it)
+ if self.use_half_precision:
+ if self.device_str == "mps":
+ print("Warning: Half precision not supported on MPS due to Metal limitations, using float32")
+ self.use_half_precision = False
+ else:
+ print("Converting model to half precision (float16)...")
+ model = model.half()
+
+ self.processor = Sam3Processor(
+ model=model,
+ resolution=1008, # Fixed resolution due to precomputed positional encodings
+ device=self.device_str,
+ confidence_threshold=self.confidence_threshold,
+ )
+ print("Model loaded successfully!")
+
+ # For tracking between keyframes, we use optical flow instead of the full SAM3 tracker
+ # This provides lightweight motion-based tracking without device compatibility issues
+ if self.enable_tracking:
+ print("Tracking mode enabled - using optical flow for inter-frame tracking")
+ self.prev_gray = None # Store previous frame for optical flow
+
+ def _load_tracker(self, checkpoint_path: Optional[str] = None):
+ """Load the SAM3 tracker for mask propagation between frames."""
+ from sam3.model_builder import build_tracker
+
+ print("Loading SAM3 tracker for inter-frame tracking...")
+
+ # Build tracker with backbone for processing new frames
+ self.tracker = build_tracker(
+ apply_temporal_disambiguation=True,
+ with_backbone=True,
+ )
+ self.tracker = self.tracker.to(self.device)
+ self.tracker.eval()
+
+ # Try to load tracker weights from the same source as the main model
+ # The tracker shares weights with the main SAM3 model
+ import os
+ tracker_ckpt_path = None
+
+ # Use provided checkpoint path first
+ if checkpoint_path and os.path.exists(checkpoint_path):
+ tracker_ckpt_path = checkpoint_path
+ else:
+ # Check common locations for the checkpoint
+ # Get the directory where this script is located
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ possible_paths = [
+ os.path.join(script_dir, "sam3.pt"), # Same folder as script (examples/)
+ "sam3.pt",
+ "./sam3.pt",
+ "../sam3.pt",
+ "examples/sam3.pt",
+ os.path.expanduser("~/.cache/huggingface/hub/models--facebook--sam3/sam3.pt"),
+ ]
+
+ for path in possible_paths:
+ if os.path.exists(path):
+ tracker_ckpt_path = path
+ break
+
+ if tracker_ckpt_path is None:
+ print("Warning: Could not find sam3.pt checkpoint for tracker.")
+ print("Please ensure sam3.pt is in the current directory or provide --checkpoint path.")
+ print("Tracking will be disabled.")
+ self.tracker = None
+ return
+
+ print(f"Loading tracker weights from: {tracker_ckpt_path}")
+ tracker_state_dict = torch.load(tracker_ckpt_path, map_location=self.device, weights_only=False)
+
+ # Filter and load tracker-compatible weights
+        tracker_keys = set(self.tracker.state_dict().keys())
+ filtered_state_dict = {k: v for k, v in tracker_state_dict.items() if k in tracker_keys}
+ self.tracker.load_state_dict(filtered_state_dict, strict=False)
+
+ print("Tracker loaded successfully!")
+
+ def _init_tracker_state(self, height: int, width: int):
+ """Initialize tracking state for a video stream."""
+ self.video_height = height
+ self.video_width = width
+ # Reset masks and optical flow state
+ self.last_masks = None
+ self.last_boxes = None
+ self.last_scores = None
+ self.last_labels = None
+ self.prev_gray = None
+
+ def _track_frame(self, frame: np.ndarray, frame_idx: int) -> Optional[torch.Tensor]:
+ """
+ Use optical flow to track masks to a new frame.
+
+ This provides lightweight motion-based tracking between keyframes
+ without needing the full SAM3 tracker model.
+
+ Returns the tracked masks or None if tracking isn't available.
+ """
+ if self.last_masks is None or len(self.last_masks) == 0:
+ return None
+
+ if self.prev_gray is None:
+ return self.last_masks
+
+ try:
+ # Convert current frame to grayscale
+ curr_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+
+ # Calculate dense optical flow using Farneback method
+ flow = cv2.calcOpticalFlowFarneback(
+ self.prev_gray, curr_gray,
+ None,
+ pyr_scale=0.5,
+ levels=3,
+ winsize=15,
+ iterations=3,
+ poly_n=5,
+ poly_sigma=1.2,
+ flags=0
+ )
+
+ # Create coordinate grids for remapping
+ h, w = curr_gray.shape
+ flow_map_x = np.arange(w).reshape(1, -1).repeat(h, axis=0).astype(np.float32)
+ flow_map_y = np.arange(h).reshape(-1, 1).repeat(w, axis=1).astype(np.float32)
+
+ # Add flow to get new positions
+ flow_map_x += flow[:, :, 0]
+ flow_map_y += flow[:, :, 1]
+
+ # Warp each mask using the flow
+ tracked_masks = []
+ for mask in self.last_masks:
+ # Convert mask to numpy for warping
+ if isinstance(mask, torch.Tensor):
+ mask_np = mask.cpu().numpy().squeeze()
+ else:
+ mask_np = mask.squeeze()
+
+ # Ensure mask is the right size
+ if mask_np.shape != (h, w):
+ mask_np = cv2.resize(mask_np.astype(np.float32), (w, h))
+
+ # Warp mask using optical flow
+ warped_mask = cv2.remap(
+ mask_np.astype(np.float32),
+ flow_map_x, flow_map_y,
+ interpolation=cv2.INTER_LINEAR,
+ borderMode=cv2.BORDER_CONSTANT,
+ borderValue=0
+ )
+
+ # Threshold to get binary mask
+ warped_mask = (warped_mask > 0.5).astype(np.float32)
+
+ # Convert back to tensor
+ tracked_masks.append(
+ torch.from_numpy(warped_mask).unsqueeze(0).to(self.device)
+ )
+
+ # Update prev_gray for next iteration
+ self.prev_gray = curr_gray
+
+ if tracked_masks:
+ return torch.stack(tracked_masks)
+
+ except Exception as e:
+ print(f"Optical flow tracking error: {e}")
+
+ return self.last_masks
+
+ def _add_mask_to_tracker(self, masks: torch.Tensor, frame: np.ndarray, frame_idx: int):
+ """Store frame for optical flow tracking."""
+ # Store grayscale frame for optical flow computation
+ self.prev_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+ # Masks are already stored in self.last_masks by the caller
+
+ def _process_frame(self, frame: np.ndarray) -> dict:
+ """Process a frame through SAM3."""
+ # Convert BGR to RGB PIL Image
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ pil_image = Image.fromarray(frame_rgb)
+
+ # Set the image
+ self.state = self.processor.set_image(pil_image, self.state)
+
+ # Run text-based detection
+ if not self.interactive:
+ # Support multiple prompts separated by commas
+ prompts = [p.strip() for p in self.text_prompt.split(',')]
+
+ if len(prompts) == 1:
+ # Single prompt - use normal detection
+ self.state = self.processor.set_text_prompt(prompts[0], self.state)
+ else:
+ # Multiple prompts - run detection for each and combine results
+ all_masks = []
+ all_boxes = []
+ all_scores = []
+ all_labels = []
+
+ for prompt in prompts:
+ # Reset geometric prompt for each detection
+ if "geometric_prompt" in self.state:
+ del self.state["geometric_prompt"]
+
+ self.state = self.processor.set_text_prompt(prompt, self.state)
+
+ masks = self.state.get("masks")
+ boxes = self.state.get("boxes")
+ scores = self.state.get("scores")
+
+ if masks is not None and masks.numel() > 0:
+ for i in range(len(masks)):
+ all_masks.append(masks[i:i+1])
+ if boxes is not None and i < len(boxes):
+ all_boxes.append(boxes[i:i+1])
+ if scores is not None and i < len(scores):
+ all_scores.append(scores[i:i+1])
+ all_labels.append(prompt)
+
+ # Combine all detections
+ if all_masks:
+ self.state["masks"] = torch.cat(all_masks, dim=0)
+ self.state["boxes"] = torch.cat(all_boxes, dim=0) if all_boxes else None
+ self.state["scores"] = torch.cat(all_scores, dim=0) if all_scores else None
+ self.state["labels"] = all_labels # Store labels for each detection
+ else:
+ self.state["masks"] = None
+ self.state["boxes"] = None
+ self.state["scores"] = None
+ self.state["labels"] = []
+
+ return self.state
+
+ def _add_box_prompt(self, box: Tuple[int, int, int, int], frame_size: Tuple[int, int]):
+ """Add a box prompt in interactive mode."""
+ if self.state is None:
+ return
+
+ h, w = frame_size
+ x1, y1, x2, y2 = box
+
+ # Convert to center format and normalize to [0, 1]
+ cx = (x1 + x2) / 2 / w
+ cy = (y1 + y2) / 2 / h
+ bw = abs(x2 - x1) / w
+ bh = abs(y2 - y1) / h
+
+ normalized_box = [cx, cy, bw, bh]
+ self.state = self.processor.add_geometric_prompt(
+ box=normalized_box,
+ label=True, # Positive box
+ state=self.state,
+ )
+
+ def _overlay_masks(
+ self,
+ frame: np.ndarray,
+ masks: torch.Tensor,
+ boxes: torch.Tensor = None,
+ scores: torch.Tensor = None,
+ labels: list = None,
+ alpha: float = 0.5,
+ ) -> np.ndarray:
+ """Overlay segmentation masks on the frame with labels and confidence scores."""
+ if masks is None or masks.numel() == 0:
+ return frame
+
+ overlay = frame.copy()
+ h, w = frame.shape[:2]
+
+ # masks shape: [N, 1, H, W]
+ masks_np = masks.squeeze(1).cpu().numpy()
+
+ # Get scores if available
+ scores_np = None
+ if scores is not None:
+ scores_np = scores.cpu().numpy()
+
+ # Get boxes if available
+ boxes_np = None
+ if boxes is not None:
+ boxes_np = boxes.cpu().numpy()
+
+ for i, mask in enumerate(masks_np):
+ # Resize mask to frame size if needed
+ if mask.shape != (h, w):
+ mask = cv2.resize(mask.astype(np.float32), (w, h)) > 0.5
+
+ # Get color for this mask
+ color = self.COLORS[i % len(self.COLORS)]
+
+ # Create colored overlay
+ mask_region = mask.astype(bool)
+ overlay[mask_region] = (
+ overlay[mask_region] * (1 - alpha) +
+ np.array(color) * alpha
+ ).astype(np.uint8)
+
+ # Draw contour
+ contours, _ = cv2.findContours(
+ mask.astype(np.uint8),
+ cv2.RETR_EXTERNAL,
+ cv2.CHAIN_APPROX_SIMPLE
+ )
+ cv2.drawContours(overlay, contours, -1, color, 2)
+
+ # Draw label with confidence score
+ # Find the top-center of the mask for label placement
+ if len(contours) > 0:
+ # Get bounding rect of largest contour
+ largest_contour = max(contours, key=cv2.contourArea)
+ x, y, cw, ch = cv2.boundingRect(largest_contour)
+
+ # Get confidence score
+ conf = scores_np[i] if scores_np is not None and i < len(scores_np) else 0.0
+
+ # Get label - use per-object label if available, otherwise use prompt
+ if labels is not None and i < len(labels):
+ obj_label = labels[i]
+ else:
+ obj_label = self.text_prompt.split(',')[0].strip() # Use first prompt as fallback
+
+ label = f"{obj_label}"
+ conf_text = f"{conf:.0%}"
+
+ # Draw label background
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ font_scale = 0.5
+ thickness = 1
+
+ # Get text sizes
+ (label_w, label_h), _ = cv2.getTextSize(label, font, font_scale, thickness)
+ (conf_w, conf_h), _ = cv2.getTextSize(conf_text, font, font_scale, thickness)
+
+ # Position at top of bounding box
+ label_x = x + cw // 2 - label_w // 2
+ label_y = max(y - 5, label_h + 5)
+
+ # Draw label background
+ cv2.rectangle(overlay,
+ (label_x - 2, label_y - label_h - 2),
+ (label_x + label_w + 2, label_y + 2),
+ color, -1)
+
+ # Draw label text
+ cv2.putText(overlay, label,
+ (label_x, label_y),
+ font, font_scale, (255, 255, 255), thickness)
+
+ # Draw confidence below label
+ conf_x = x + cw // 2 - conf_w // 2
+ conf_y = label_y + conf_h + 8
+
+ cv2.rectangle(overlay,
+ (conf_x - 2, conf_y - conf_h - 2),
+ (conf_x + conf_w + 2, conf_y + 2),
+ (0, 0, 0), -1)
+ cv2.putText(overlay, conf_text,
+ (conf_x, conf_y),
+ font, font_scale, (0, 255, 0), thickness)
+
+ return overlay
+
+ def _draw_boxes(self, frame: np.ndarray, boxes: torch.Tensor, scores: torch.Tensor = None) -> np.ndarray:
+ """Draw bounding boxes on the frame with labels."""
+ if boxes is None or boxes.numel() == 0:
+ return frame
+
+ boxes_np = boxes.cpu().numpy()
+ scores_np = scores.cpu().numpy() if scores is not None else None
+
+ for i, box in enumerate(boxes_np):
+ x1, y1, x2, y2 = box.astype(int)
+ color = self.COLORS[i % len(self.COLORS)]
+ cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
+
+ return frame
+
+ def _draw_object_panel(self, frame: np.ndarray, masks: torch.Tensor,
+ boxes: torch.Tensor, scores: torch.Tensor,
+ labels: list = None) -> np.ndarray:
+ """Draw an info panel on the right side showing detected objects."""
+ h, w = frame.shape[:2]
+
+ # Panel dimensions
+ panel_width = 200
+ panel_x = w - panel_width - 10
+
+ # Count objects
+ num_objects = len(masks) if masks is not None else 0
+
+ # Calculate panel height based on number of objects
+ header_height = 40
+ object_height = 50
+ panel_height = header_height + max(num_objects, 1) * object_height + 20
+
+ # Draw semi-transparent panel background
+ overlay = frame.copy()
+ cv2.rectangle(overlay,
+ (panel_x, 10),
+ (w - 10, min(10 + panel_height, h - 10)),
+ (0, 0, 0), -1)
+ frame = cv2.addWeighted(overlay, 0.7, frame, 0.3, 0)
+
+ # Draw panel header
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ cv2.putText(frame, "DETECTED OBJECTS",
+ (panel_x + 10, 35),
+ font, 0.5, (255, 255, 255), 1)
+ cv2.line(frame, (panel_x + 5, 45), (w - 15, 45), (100, 100, 100), 1)
+
+ if num_objects == 0:
+ cv2.putText(frame, "No objects found",
+ (panel_x + 10, 75),
+ font, 0.4, (150, 150, 150), 1)
+ return frame
+
+ # Draw each object
+ masks_np = masks.squeeze(1).cpu().numpy() if masks is not None else []
+ scores_np = scores.cpu().numpy() if scores is not None else []
+ boxes_np = boxes.cpu().numpy() if boxes is not None else []
+
+ for i in range(num_objects):
+ y_offset = header_height + 15 + i * object_height
+
+ if 10 + y_offset + 40 > h - 10:
+ # Panel would exceed frame height
+ cv2.putText(frame, f"... +{num_objects - i} more",
+ (panel_x + 10, 10 + y_offset),
+ font, 0.4, (150, 150, 150), 1)
+ break
+
+ color = self.COLORS[i % len(self.COLORS)]
+
+ # Color indicator
+ cv2.rectangle(frame,
+ (panel_x + 10, 10 + y_offset),
+ (panel_x + 25, 10 + y_offset + 15),
+ color, -1)
+
+ # Object label - use per-object label if available
+ if labels is not None and i < len(labels):
+ obj_label = labels[i]
+ else:
+ obj_label = self.text_prompt.split(',')[0].strip()
+
+ # Truncate label if too long
+ if len(obj_label) > 15:
+ obj_label = obj_label[:12] + "..."
+
+ cv2.putText(frame, obj_label,
+ (panel_x + 35, 10 + y_offset + 12),
+ font, 0.4, (255, 255, 255), 1)
+
+ # Confidence score
+ if i < len(scores_np):
+ conf = scores_np[i]
+ conf_color = (0, 255, 0) if conf > 0.7 else (0, 255, 255) if conf > 0.4 else (0, 0, 255)
+ cv2.putText(frame, f"Conf: {conf:.0%}",
+ (panel_x + 35, 10 + y_offset + 28),
+ font, 0.35, conf_color, 1)
+
+ # Bounding box size
+ if i < len(boxes_np):
+ box = boxes_np[i]
+ bw = int(box[2] - box[0])
+ bh = int(box[3] - box[1])
+ cv2.putText(frame, f"Size: {bw}x{bh}",
+ (panel_x + 100, 10 + y_offset + 28),
+ font, 0.35, (150, 150, 150), 1)
+
+ return frame
+
+ def _draw_info(self, frame: np.ndarray, fps: float, num_objects: int) -> np.ndarray:
+ """Draw information overlay on the frame."""
+ h, w = frame.shape[:2]
+
+ # Semi-transparent background for text
+ overlay = frame.copy()
+ info_height = 165 if self.enable_tracking else 140
+ cv2.rectangle(overlay, (10, 10), (350, info_height), (0, 0, 0), -1)
+ frame = cv2.addWeighted(overlay, 0.3, frame, 0.7, 0)
+
+ # Draw text
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ cv2.putText(frame, f"FPS: {fps:.1f}", (20, 35), font, 0.6, (255, 255, 255), 2)
+ cv2.putText(frame, f"Objects: {num_objects}", (20, 60), font, 0.6, (255, 255, 255), 2)
+ cv2.putText(frame, f"Device: {self.device_str}", (20, 85), font, 0.6, (255, 255, 255), 2)
+
+ mode = "Interactive" if self.interactive else f"Prompt: {self.text_prompt}"
+ cv2.putText(frame, f"Mode: {mode}", (20, 110), font, 0.6, (255, 255, 255), 2)
+ cv2.putText(frame, f"Threshold: {self.confidence_threshold:.2f}", (20, 135), font, 0.6, (255, 255, 255), 2)
+
+ if self.enable_tracking:
+ skip_info = f"Skip: {self.process_every_n_frames} (tracking ON)"
+ cv2.putText(frame, skip_info, (20, 160), font, 0.6, (0, 255, 0), 2)
+
+ # Draw controls hint at bottom
+ hint = "Q: Quit | R: Reset | S: Save | P: Pause | T: New prompt"
+ cv2.putText(frame, hint, (10, h - 10), font, 0.4, (200, 200, 200), 1)
+
+ return frame
+
+ def _draw_current_box(self, frame: np.ndarray) -> np.ndarray:
+ """Draw the box currently being drawn."""
+ if self.drawing and self.box_start and self.box_end:
+ cv2.rectangle(
+ frame,
+ self.box_start,
+ self.box_end,
+ (0, 255, 0),
+ 2
+ )
+ return frame
+
+ def _mouse_callback(self, event, x, y, flags, param):
+ """Handle mouse events for interactive mode."""
+ if not self.interactive:
+ return
+
+ if event == cv2.EVENT_LBUTTONDOWN:
+ self.drawing = True
+ self.box_start = (x, y)
+ self.box_end = (x, y)
+
+ elif event == cv2.EVENT_MOUSEMOVE:
+ if self.drawing:
+ self.box_end = (x, y)
+
+ elif event == cv2.EVENT_LBUTTONUP:
+ if self.drawing:
+ self.drawing = False
+ self.box_end = (x, y)
+
+ # Add the box prompt if it's a valid box
+ x1, y1 = self.box_start
+ x2, y2 = self.box_end
+ if abs(x2 - x1) > 5 and abs(y2 - y1) > 5:
+ frame_size = param # Passed as param
+ self._add_box_prompt((x1, y1, x2, y2), frame_size)
+
+ self.box_start = None
+ self.box_end = None
+
+ def run(self):
+ """Run the live camera segmentation loop."""
+ # Open camera
+ print(f"Opening camera {self.camera_id}...")
+ cap = cv2.VideoCapture(self.camera_id)
+
+ if not cap.isOpened():
+ print(f"Error: Could not open camera {self.camera_id}")
+ return
+
+ # Get camera properties
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ print(f"Camera resolution: {frame_width}x{frame_height}")
+
+ # Initialize tracker state if tracking is enabled
+ if self.enable_tracking:
+ print("Initializing tracker state...")
+ self._init_tracker_state(frame_height, frame_width)
+
+ # Create window
+ window_name = "SAM3 Live Segmentation"
+ cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
+ cv2.setMouseCallback(window_name, self._mouse_callback, (frame_height, frame_width))
+
+ print("\nStarting live segmentation...")
+ print("Controls:")
+ print(" Q/ESC: Quit")
+ print(" R: Reset segments")
+ print(" S: Save frame")
+ print(" P: Pause/resume")
+ print(" T: Enter new text prompt")
+ if self.interactive:
+ print(" Left click + drag: Draw box prompt")
+
+ frame_count = 0
+ try:
+ while True:
+ start_time = time.time()
+
+ # Capture frame
+ ret, frame = cap.read()
+ if not ret:
+ print("Failed to grab frame")
+ break
+
+ display_frame = frame.copy()
+ self.frame_count += 1
+
+ if not self.paused:
+ is_keyframe = self.frame_count % self.process_every_n_frames == 0
+
+ if is_keyframe:
+ # Full inference frame - run text detection
+ self._process_frame(frame)
+
+ # Store masks, boxes, scores, and labels for tracking
+ if self.state is not None:
+ self.last_masks = self.state.get("masks")
+ self.last_boxes = self.state.get("boxes")
+ self.last_scores = self.state.get("scores")
+ self.last_labels = self.state.get("labels")
+
+ # Add masks to tracker for memory-based propagation
+ if self.enable_tracking and self.last_masks is not None:
+ self._add_mask_to_tracker(self.last_masks, frame, self.frame_count)
+
+ elif self.enable_tracking and self.last_masks is not None:
+ # Intermediate frame - use tracker to propagate masks
+ tracked_masks = self._track_frame(frame, self.frame_count)
+ if tracked_masks is not None:
+ self.last_masks = tracked_masks
+ # Update state with tracked masks
+ if self.state is not None:
+ self.state["masks"] = tracked_masks
+ # else: Just reuse last masks (no tracking)
+
+ # Overlay results - use last_masks if tracking is enabled
+ masks_to_display = None
+ boxes_to_display = None
+ scores_to_display = None
+ labels_to_display = None
+
+ if self.enable_tracking:
+ masks_to_display = self.last_masks
+ boxes_to_display = self.last_boxes
+ scores_to_display = self.last_scores
+ labels_to_display = self.last_labels
+ elif self.state is not None:
+ masks_to_display = self.state.get("masks")
+ boxes_to_display = self.state.get("boxes")
+ scores_to_display = self.state.get("scores")
+ labels_to_display = self.state.get("labels")
+
+ if masks_to_display is not None:
+ display_frame = self._overlay_masks(
+ display_frame, masks_to_display,
+ boxes=boxes_to_display, scores=scores_to_display,
+ labels=labels_to_display
+ )
+ if boxes_to_display is not None:
+ display_frame = self._draw_boxes(display_frame, boxes_to_display, scores_to_display)
+
+ # Draw object info panel on the right
+ display_frame = self._draw_object_panel(
+ display_frame, masks_to_display, boxes_to_display, scores_to_display,
+ labels=labels_to_display
+ )
+
+ # Draw current box being drawn
+ if self.interactive:
+ display_frame = self._draw_current_box(display_frame)
+
+ # Calculate FPS
+ elapsed = time.time() - start_time
+ fps = 1.0 / elapsed if elapsed > 0 else 0
+ self.fps_history.append(fps)
+ avg_fps = sum(self.fps_history) / len(self.fps_history)
+
+ # Draw info overlay
+ num_objects = 0
+ if masks_to_display is not None:
+ num_objects = len(masks_to_display)
+ display_frame = self._draw_info(display_frame, avg_fps, num_objects)
+
+ # Show frame
+ cv2.imshow(window_name, display_frame)
+
+ # Handle keyboard input
+ key = cv2.waitKey(1) & 0xFF
+
+ if key == ord('q') or key == 27: # Q or ESC
+ print("Quitting...")
+ break
+
+ elif key == ord('r'): # Reset
+ print("Resetting segments...")
+ if self.state is not None:
+ self.processor.reset_all_prompts(self.state)
+ self.state = None
+ self.last_masks = None
+ self.last_boxes = None
+ self.last_scores = None
+ self.last_labels = None
+ # Reset tracker state
+ if self.enable_tracking:
+ self._init_tracker_state(frame_height, frame_width)
+
+ elif key == ord('s'): # Save
+ filename = f"sam3_capture_{frame_count}.png"
+ cv2.imwrite(filename, display_frame)
+ print(f"Saved frame to {filename}")
+
+ elif key == ord('p'): # Pause
+ self.paused = not self.paused
+ print("Paused" if self.paused else "Resumed")
+
+ elif key == ord('t'): # New text prompt
+ self.paused = True
+ new_prompt = input("Enter new text prompt: ").strip()
+ if new_prompt:
+ self.text_prompt = new_prompt
+ if self.state is not None:
+ self.processor.reset_all_prompts(self.state)
+ self.state = None
+ self.last_masks = None
+ self.last_boxes = None
+ self.last_scores = None
+ self.last_labels = None
+ # Reset tracker for new prompt
+ if self.enable_tracking:
+ self._init_tracker_state(frame_height, frame_width)
+ print(f"Text prompt set to: {self.text_prompt}")
+ self.paused = False
+
+ frame_count += 1
+
+ except KeyboardInterrupt:
+ print("\nInterrupted by user")
+
+ finally:
+ cap.release()
+ cv2.destroyAllWindows()
+ print("Cleanup complete")
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Live Camera Segmentation with SAM3",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog=__doc__,
+ )
+ parser.add_argument(
+ "--camera", "-c",
+ type=int,
+ default=0,
+ help="Camera device ID (default: 0)",
+ )
+ parser.add_argument(
+ "--device", "-d",
+ type=str,
+ default=None,
+ choices=["cuda", "mps", "cpu"],
+ help="Device to run on (default: auto-detect)",
+ )
+ parser.add_argument(
+ "--prompt",
+ type=str,
+ default="object",
+ help="Text prompt for detection (default: 'object')",
+ )
+ parser.add_argument(
+ "--threshold",
+ type=float,
+ default=0.3,
+ help="Detection confidence threshold (default: 0.3)",
+ )
+ parser.add_argument(
+ "--checkpoint",
+ type=str,
+ default=None,
+ help="Path to model checkpoint (default: download from HuggingFace)",
+ )
+ parser.add_argument(
+ "--interactive", "-i",
+ action="store_true",
+ help="Enable interactive box-based prompting",
+ )
+ parser.add_argument(
+ "--skip-frames",
+ type=int,
+ default=1,
+ help="Process every N frames (higher = faster, default: 1)",
+ )
+ parser.add_argument(
+ "--half",
+ action="store_true",
+ help="Use half precision (float16) for faster inference",
+ )
+ parser.add_argument(
+ "--track",
+ action="store_true",
+ help="Enable mask tracking between skipped frames (smoother results when using --skip-frames)",
+ )
+
+ args = parser.parse_args()
+
+ # Print device info
+ device = args.device or get_device_str()
+ print(f"SAM3 Live Camera Segmentation")
+ print(f"=" * 40)
+ print(f"Device: {device}")
+ print(f"Camera: {args.camera}")
+ print(f"Text prompt: {args.prompt}")
+ print(f"Threshold: {args.threshold}")
+ print(f"Interactive: {args.interactive}")
+ print(f"Skip frames: {args.skip_frames}")
+ print(f"Half precision: {args.half}")
+ print(f"Tracking: {args.track}")
+ print(f"=" * 40)
+
+ # Create and run segmenter
+ segmenter = LiveCameraSegmenter(
+ camera_id=args.camera,
+ device=args.device,
+ text_prompt=args.prompt,
+ confidence_threshold=args.threshold,
+ checkpoint_path=args.checkpoint,
+ interactive=args.interactive,
+ process_every_n_frames=args.skip_frames,
+ use_half_precision=args.half,
+ enable_tracking=args.track,
+ )
+ segmenter.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/web_command_center/app.py b/examples/web_command_center/app.py
new file mode 100644
index 00000000..a2f17424
--- /dev/null
+++ b/examples/web_command_center/app.py
@@ -0,0 +1,5374 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+"""
+SAM3 Web Command Center
+
+A Flask-based web interface for real-time object detection and tracking
+using SAM3. Features include:
+- Live camera feed with segmentation overlay
+- Multi-prompt detection configuration
+- Object count limits with show/hide functionality
+- Claude Vision API integration for detailed object analysis
+- Video tracking with memory (SAM3 tracker)
+- Multi-object tracking with persistent IDs
+- Mask refinement (fill holes, non-overlap)
+- Advanced detection controls (boundary/occlusion suppression, hotstart)
+- YOLO integration for classification and pose estimation
+- Command center style interface with verbose logging
+
+Usage:
+ python app.py --prompt "person, car" --camera 0
+
+Then open http://localhost:5000 in your browser.
+"""
+
+import argparse
+import base64
+import io
+import ipaddress
+import json
+import os
+import sqlite3
+import ssl
+import sys
+import threading
+import time
+import uuid
+from collections import deque
+from datetime import datetime
+from typing import Optional, Dict, List, Any, Tuple
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from flask import Flask, Response, render_template, request, jsonify
+from scipy import ndimage
+
+# Load environment variables from .env file if present
+try:
+ from dotenv import load_dotenv
+ # Look for .env in the web_command_center directory
+ env_path = os.path.join(os.path.dirname(__file__), '.env')
+ if os.path.exists(env_path):
+ load_dotenv(env_path)
+ print(f"Loaded environment from {env_path}")
+ else:
+ # Also check current working directory
+ load_dotenv()
+except ImportError:
+ pass # python-dotenv not installed, rely on system environment
+
+# Add parent directory to path for sam3 imports
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
+
+from sam3.utils.device import get_device, get_device_str, setup_device_optimizations, empty_cache
+
+app = Flask(__name__)
+
+# Global API key storage (can be set via CLI arg or environment)
+ANTHROPIC_API_KEY = os.environ.get('ANTHROPIC_API_KEY')
+
+
+# ===== SAM3 to COCO Label Mapping =====
+# Maps open-vocabulary SAM3 labels to COCO class indices for YOLO
+SAM3_TO_COCO = {
+ # Person variations -> COCO class 0
+ "person": 0, "human": 0, "man": 0, "woman": 0, "child": 0, "kid": 0,
+ "boy": 0, "girl": 0, "people": 0, "pedestrian": 0, "worker": 0,
+ "player": 0, "athlete": 0, "runner": 0, "cyclist": 0,
+
+ # Vehicles
+ "bicycle": 1, "bike": 1, "cycle": 1,
+ "car": 2, "automobile": 2, "vehicle": 2, "sedan": 2, "suv": 2,
+ "motorcycle": 3, "motorbike": 3, "scooter": 3,
+ "airplane": 4, "plane": 4, "aircraft": 4, "jet": 4,
+ "bus": 5, "coach": 5,
+ "train": 6, "locomotive": 6, "railway": 6,
+ "truck": 7, "lorry": 7, "pickup": 7,
+ "boat": 8, "ship": 8, "vessel": 8, "yacht": 8,
+
+ # Traffic
+ "traffic light": 9, "stoplight": 9,
+ "fire hydrant": 10, "hydrant": 10,
+ "stop sign": 11,
+ "parking meter": 12,
+
+ # Animals
+ "bird": 14, "sparrow": 14, "pigeon": 14, "crow": 14,
+ "cat": 15, "kitten": 15, "feline": 15, "kitty": 15,
+ "dog": 16, "puppy": 16, "canine": 16, "hound": 16,
+ "horse": 17, "pony": 17, "stallion": 17, "mare": 17,
+ "sheep": 18, "lamb": 18,
+ "cow": 19, "cattle": 19, "bull": 19,
+ "elephant": 20,
+ "bear": 21, "grizzly": 21,
+ "zebra": 22,
+ "giraffe": 23,
+
+ # Accessories
+ "backpack": 24, "bag": 24, "rucksack": 24,
+ "umbrella": 25, "parasol": 25,
+ "handbag": 26, "purse": 26,
+ "tie": 27, "necktie": 27,
+ "suitcase": 28, "luggage": 28,
+
+ # Sports
+ "frisbee": 29,
+ "skis": 30, "ski": 30,
+ "snowboard": 31,
+ "sports ball": 32, "ball": 32, "football": 32, "soccer ball": 32,
+ "kite": 33,
+ "baseball bat": 34, "bat": 34,
+ "baseball glove": 35, "glove": 35,
+ "skateboard": 36,
+ "surfboard": 37,
+ "tennis racket": 38, "racket": 38,
+
+ # Kitchen
+ "bottle": 39, "water bottle": 39,
+ "wine glass": 40, "glass": 40,
+ "cup": 41, "mug": 41, "coffee cup": 41,
+ "fork": 42,
+ "knife": 43,
+ "spoon": 44,
+ "bowl": 45,
+
+ # Food
+ "banana": 46,
+ "apple": 47,
+ "sandwich": 48,
+ "orange": 49,
+ "broccoli": 50,
+ "carrot": 51,
+ "hot dog": 52, "hotdog": 52,
+ "pizza": 53,
+ "donut": 54, "doughnut": 54,
+ "cake": 55,
+
+ # Furniture
+ "chair": 56, "seat": 56,
+ "couch": 57, "sofa": 57,
+ "potted plant": 58, "plant": 58, "houseplant": 58,
+ "bed": 59,
+ "dining table": 60, "table": 60, "desk": 60,
+ "toilet": 61,
+
+ # Electronics
+ "tv": 62, "television": 62, "monitor": 62, "screen": 62,
+ "laptop": 63, "notebook": 63, "computer": 63,
+ "mouse": 64, "computer mouse": 64,
+ "remote": 65, "remote control": 65,
+ "keyboard": 66,
+ "cell phone": 67, "phone": 67, "mobile": 67, "smartphone": 67,
+
+ # Appliances
+ "microwave": 68,
+ "oven": 69, "stove": 69,
+ "toaster": 70,
+ "sink": 71,
+ "refrigerator": 72, "fridge": 72,
+
+ # Other
+ "book": 73,
+ "clock": 74, "watch": 74,
+ "vase": 75,
+ "scissors": 76,
+ "teddy bear": 77, "stuffed animal": 77,
+ "hair drier": 78, "hairdryer": 78,
+ "toothbrush": 79,
+}
+
+# COCO class names (80 classes)
+COCO_CLASSES = [
+ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
+ 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
+ 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
+ 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
+ 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+ 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
+ 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
+ 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
+ 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator',
+ 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
+]
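+
+
+# Illustrative helper (hypothetical name, not used elsewhere in this file):
+# shows how the SAM3_TO_COCO mapping and COCO_CLASSES above are meant to be
+# combined when spoofing an open-vocabulary SAM3 label into a COCO class for
+# YOLO. Labels are assumed to be lowercased and stripped before lookup.
+def sam3_label_to_coco_name(label: str) -> Optional[str]:
+    """Return the canonical COCO class name for a SAM3 text label, if any."""
+    idx = SAM3_TO_COCO.get(label.lower().strip())
+    return COCO_CLASSES[idx] if idx is not None else None
+
+# Example: sam3_label_to_coco_name("Kitty") -> "cat"; unknown labels -> None.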
+
+# Pose keypoint names (COCO format - 17 keypoints)
+POSE_KEYPOINTS = [
+ 'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear',
+ 'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
+ 'left_wrist', 'right_wrist', 'left_hip', 'right_hip',
+ 'left_knee', 'right_knee', 'left_ankle', 'right_ankle'
+]
+
+# Skeleton connections for drawing
+SKELETON_CONNECTIONS = [
+ (0, 1), (0, 2), (1, 3), (2, 4), # Face
+ (5, 6), (5, 7), (7, 9), (6, 8), (8, 10), # Arms
+ (5, 11), (6, 12), (11, 12), # Torso
+ (11, 13), (13, 15), (12, 14), (14, 16) # Legs
+]
+
+# Keypoint colors (BGR)
+KEYPOINT_COLORS = {
+ 'face': (255, 200, 100), # Light blue for face points
+ 'left_arm': (0, 255, 0), # Green for left side
+ 'right_arm': (0, 0, 255), # Red for right side
+ 'left_leg': (0, 200, 0), # Darker green
+ 'right_leg': (0, 0, 200), # Darker red
+ 'torso': (255, 255, 0), # Cyan
+}
+
+
+def get_keypoint_color(idx: int) -> Tuple[int, int, int]:
+ """Get color for a keypoint based on its index."""
+ if idx <= 4:
+ return KEYPOINT_COLORS['face']
+ elif idx in [5, 7, 9]:
+ return KEYPOINT_COLORS['left_arm']
+ elif idx in [6, 8, 10]:
+ return KEYPOINT_COLORS['right_arm']
+ elif idx in [11, 13, 15]:
+ return KEYPOINT_COLORS['left_leg']
+ elif idx in [12, 14, 16]:
+ return KEYPOINT_COLORS['right_leg']
+ else:
+ return KEYPOINT_COLORS['torso']
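+
+
+# Minimal drawing sketch (assumption: the real pose overlay lives further down
+# in this file; the function name and the [17, 3] (x, y, conf) keypoint layout
+# are illustrative only). It shows how POSE_KEYPOINTS-ordered keypoints, the
+# SKELETON_CONNECTIONS pairs, and get_keypoint_color() fit together.
+def draw_pose_sketch(frame: np.ndarray, keypoints: np.ndarray, conf_thresh: float = 0.5) -> np.ndarray:
+    """Draw a 17-keypoint COCO skeleton onto a BGR frame in place."""
+    # Bones first so the keypoint dots render on top of the lines.
+    for a, b in SKELETON_CONNECTIONS:
+        if keypoints[a, 2] >= conf_thresh and keypoints[b, 2] >= conf_thresh:
+            pt_a = (int(keypoints[a, 0]), int(keypoints[a, 1]))
+            pt_b = (int(keypoints[b, 0]), int(keypoints[b, 1]))
+            cv2.line(frame, pt_a, pt_b, KEYPOINT_COLORS['torso'], 2)
+    for idx, (x, y, conf) in enumerate(keypoints):
+        if conf >= conf_thresh:
+            cv2.circle(frame, (int(x), int(y)), 4, get_keypoint_color(idx), -1)
+    return frame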
+
+
+# ===== DATABASE =====
+class Database:
+ """SQLite database for storing all command center data."""
+
+ def __init__(self, db_path: str = None):
+ if db_path is None:
+ db_path = os.path.join(os.path.dirname(__file__), 'command_center.db')
+ self.db_path = db_path
+ self.lock = threading.Lock()
+ self._init_db()
+
+ def _get_connection(self) -> sqlite3.Connection:
+ """Get a thread-local database connection."""
+ conn = sqlite3.connect(self.db_path, check_same_thread=False)
+ conn.row_factory = sqlite3.Row
+ return conn
+
+ def _init_db(self):
+ """Initialize database tables."""
+ with self._get_connection() as conn:
+ cursor = conn.cursor()
+
+ # Sessions table - tracks each app run
+ cursor.execute('''
+ CREATE TABLE IF NOT EXISTS sessions (
+ id TEXT PRIMARY KEY,
+ started_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ ended_at TIMESTAMP,
+ device TEXT,
+ prompts TEXT,
+ settings TEXT
+ )
+ ''')
+
+ # Detections table - all detected objects
+ cursor.execute('''
+ CREATE TABLE IF NOT EXISTS detections (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ session_id TEXT,
+ detection_id INTEGER,
+ persistent_id INTEGER,
+ label TEXT,
+ confidence REAL,
+ box TEXT,
+ mask_area INTEGER,
+ timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ frame_number INTEGER,
+ yolo_class TEXT,
+ yolo_confidence REAL,
+ FOREIGN KEY (session_id) REFERENCES sessions(id)
+ )
+ ''')
+
+ # Analysis results from Claude
+ cursor.execute('''
+ CREATE TABLE IF NOT EXISTS analysis_results (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ session_id TEXT,
+ detection_id INTEGER,
+ label TEXT,
+ analysis TEXT,
+ timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ image_data TEXT,
+ FOREIGN KEY (session_id) REFERENCES sessions(id)
+ )
+ ''')
+
+ # Location memory - where objects are typically found
+ cursor.execute('''
+ CREATE TABLE IF NOT EXISTS location_memory (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ label TEXT NOT NULL,
+ context TEXT,
+ position TEXT,
+ frequency INTEGER DEFAULT 1,
+ first_seen TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ last_seen TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ UNIQUE(label, context)
+ )
+ ''')
+
+ # Navigation sessions
+ cursor.execute('''
+ CREATE TABLE IF NOT EXISTS navigation_sessions (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ session_id TEXT,
+ target_label TEXT,
+ target_id INTEGER,
+ started_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ ended_at TIMESTAMP,
+ reached BOOLEAN DEFAULT FALSE,
+ path_history TEXT,
+ scene_context TEXT,
+ FOREIGN KEY (session_id) REFERENCES sessions(id)
+ )
+ ''')
+
+ # Obstacles detected during navigation
+ cursor.execute('''
+ CREATE TABLE IF NOT EXISTS obstacles (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ navigation_id INTEGER,
+ label TEXT,
+ obstacle_type TEXT,
+ box TEXT,
+ distance TEXT,
+ alert_sent BOOLEAN DEFAULT FALSE,
+ timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ FOREIGN KEY (navigation_id) REFERENCES navigation_sessions(id)
+ )
+ ''')
+
+ # Voice queries and results
+ cursor.execute('''
+ CREATE TABLE IF NOT EXISTS voice_queries (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ session_id TEXT,
+ query TEXT,
+ parsed_prompts TEXT,
+ was_search BOOLEAN,
+ was_describe BOOLEAN,
+ timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ FOREIGN KEY (session_id) REFERENCES sessions(id)
+ )
+ ''')
+
+ # General event log
+ cursor.execute('''
+ CREATE TABLE IF NOT EXISTS event_log (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ session_id TEXT,
+ event_type TEXT,
+ level TEXT DEFAULT 'INFO',
+ message TEXT,
+ data TEXT,
+ timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ FOREIGN KEY (session_id) REFERENCES sessions(id)
+ )
+ ''')
+
+ # Create indexes for common queries
+ cursor.execute('CREATE INDEX IF NOT EXISTS idx_detections_session ON detections(session_id)')
+ cursor.execute('CREATE INDEX IF NOT EXISTS idx_detections_label ON detections(label)')
+ cursor.execute('CREATE INDEX IF NOT EXISTS idx_location_label ON location_memory(label)')
+ cursor.execute('CREATE INDEX IF NOT EXISTS idx_obstacles_nav ON obstacles(navigation_id)')
+ cursor.execute('CREATE INDEX IF NOT EXISTS idx_events_session ON event_log(session_id)')
+
+ conn.commit()
+ print(f"Database initialized: {self.db_path}")
+
+ # ===== SESSION METHODS =====
+
+ def create_session(self, device: str, prompts: List[str], settings: Dict) -> str:
+ """Create a new session and return its ID."""
+ session_id = str(uuid.uuid4())
+ with self.lock:
+ with self._get_connection() as conn:
+ conn.execute(
+ 'INSERT INTO sessions (id, device, prompts, settings) VALUES (?, ?, ?, ?)',
+ (session_id, device, json.dumps(prompts), json.dumps(settings))
+ )
+ conn.commit()
+ return session_id
+
+ def end_session(self, session_id: str):
+ """Mark a session as ended."""
+ with self.lock:
+ with self._get_connection() as conn:
+ conn.execute(
+ 'UPDATE sessions SET ended_at = CURRENT_TIMESTAMP WHERE id = ?',
+ (session_id,)
+ )
+ conn.commit()
+
+ # ===== DETECTION METHODS =====
+
+ def save_detection(self, session_id: str, detection: Dict, frame_number: int):
+ """Save a detection to the database."""
+ with self.lock:
+ with self._get_connection() as conn:
+ conn.execute('''
+ INSERT INTO detections
+ (session_id, detection_id, persistent_id, label, confidence, box, mask_area, frame_number, yolo_class, yolo_confidence)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+ ''', (
+ session_id,
+ detection.get('id'),
+ detection.get('persistent_id'),
+ detection.get('label'),
+ detection.get('confidence'),
+ json.dumps(detection.get('box')),
+ detection.get('mask_area'),
+ frame_number,
+ detection.get('yolo_class'),
+ detection.get('yolo_confidence')
+ ))
+ conn.commit()
+
+ def save_detections_batch(self, session_id: str, detections: List[Dict], frame_number: int):
+ """Save multiple detections in a batch."""
+ if not detections:
+ return
+ with self.lock:
+ with self._get_connection() as conn:
+ conn.executemany('''
+ INSERT INTO detections
+ (session_id, detection_id, persistent_id, label, confidence, box, mask_area, frame_number, yolo_class, yolo_confidence)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+ ''', [(
+ session_id,
+ d.get('id'),
+ d.get('persistent_id'),
+ d.get('label'),
+ d.get('confidence'),
+ json.dumps(d.get('box')),
+ d.get('mask_area'),
+ frame_number,
+ d.get('yolo_class'),
+ d.get('yolo_confidence')
+ ) for d in detections])
+ conn.commit()
+
+ def get_detection_history(self, session_id: str = None, label: str = None, limit: int = 100) -> List[Dict]:
+ """Get detection history with optional filters."""
+ query = 'SELECT * FROM detections WHERE 1=1'
+ params = []
+
+ if session_id:
+ query += ' AND session_id = ?'
+ params.append(session_id)
+ if label:
+ query += ' AND label LIKE ?'
+ params.append(f'%{label}%')
+
+ query += ' ORDER BY timestamp DESC LIMIT ?'
+ params.append(limit)
+
+ with self._get_connection() as conn:
+ rows = conn.execute(query, params).fetchall()
+ return [dict(row) for row in rows]
+
+ # ===== ANALYSIS METHODS =====
+
+ def save_analysis(self, session_id: str, detection_id: int, label: str, analysis: str, image_data: str = None):
+ """Save Claude analysis result."""
+ with self.lock:
+ with self._get_connection() as conn:
+ conn.execute('''
+ INSERT INTO analysis_results (session_id, detection_id, label, analysis, image_data)
+ VALUES (?, ?, ?, ?, ?)
+ ''', (session_id, detection_id, label, analysis, image_data))
+ conn.commit()
+
+ def get_analysis_history(self, session_id: str = None, limit: int = 50) -> List[Dict]:
+ """Get analysis history."""
+ query = 'SELECT * FROM analysis_results'
+ params = []
+
+ if session_id:
+ query += ' WHERE session_id = ?'
+ params.append(session_id)
+
+ query += ' ORDER BY timestamp DESC LIMIT ?'
+ params.append(limit)
+
+ with self._get_connection() as conn:
+ rows = conn.execute(query, params).fetchall()
+ return [dict(row) for row in rows]
+
+ # ===== LOCATION MEMORY METHODS =====
+
+ def remember_location(self, label: str, context: str, position: Dict = None):
+ """Remember where an object was found."""
+ label_key = label.lower().strip()
+ context_key = context.lower().strip() if context else ""
+
+ with self.lock:
+ with self._get_connection() as conn:
+ # Try to update existing entry
+ cursor = conn.execute('''
+ UPDATE location_memory
+ SET frequency = frequency + 1,
+ last_seen = CURRENT_TIMESTAMP,
+ position = ?
+ WHERE label = ? AND context = ?
+ ''', (json.dumps(position) if position else None, label_key, context_key))
+
+ if cursor.rowcount == 0:
+ # Insert new entry
+ conn.execute('''
+ INSERT INTO location_memory (label, context, position, frequency)
+ VALUES (?, ?, ?, 1)
+ ''', (label_key, context_key, json.dumps(position) if position else None))
+
+ conn.commit()
+
+ def recall_location(self, label: str) -> Optional[Dict]:
+ """Recall where an object was typically found."""
+ label_key = label.lower().strip()
+
+ with self._get_connection() as conn:
+ row = conn.execute('''
+ SELECT * FROM location_memory
+ WHERE label = ?
+ ORDER BY frequency DESC, last_seen DESC
+ LIMIT 1
+ ''', (label_key,)).fetchone()
+
+ if row:
+ result = dict(row)
+ if result.get('position'):
+ result['position'] = json.loads(result['position'])
+ return result
+ return None
+
+ def get_all_location_memories(self) -> List[Dict]:
+ """Get all location memories."""
+ with self._get_connection() as conn:
+ rows = conn.execute('''
+ SELECT label, context, frequency, last_seen
+ FROM location_memory
+ ORDER BY frequency DESC, last_seen DESC
+ ''').fetchall()
+ return [dict(row) for row in rows]
+
+ def clear_location_memory(self, label: str = None):
+ """Clear location memory for a label or all."""
+ with self.lock:
+ with self._get_connection() as conn:
+ if label:
+ conn.execute('DELETE FROM location_memory WHERE label = ?', (label.lower().strip(),))
+ else:
+ conn.execute('DELETE FROM location_memory')
+ conn.commit()
+
+ # ===== NAVIGATION METHODS =====
+
+ def start_navigation_session(self, session_id: str, target_label: str, target_id: int = None) -> int:
+ """Start a new navigation session and return its ID."""
+ with self.lock:
+ with self._get_connection() as conn:
+ cursor = conn.execute('''
+ INSERT INTO navigation_sessions (session_id, target_label, target_id)
+ VALUES (?, ?, ?)
+ ''', (session_id, target_label, target_id))
+ conn.commit()
+ return cursor.lastrowid
+
+ def end_navigation_session(self, nav_id: int, reached: bool, path_history: List = None, scene_context: Dict = None):
+ """End a navigation session."""
+ with self.lock:
+ with self._get_connection() as conn:
+ conn.execute('''
+ UPDATE navigation_sessions
+ SET ended_at = CURRENT_TIMESTAMP,
+ reached = ?,
+ path_history = ?,
+ scene_context = ?
+ WHERE id = ?
+ ''', (reached, json.dumps(path_history), json.dumps(scene_context), nav_id))
+ conn.commit()
+
+ def save_obstacle(self, nav_id: int, label: str, obstacle_type: str, box: List, distance: str, alert_sent: bool = False):
+ """Save an obstacle detected during navigation."""
+ with self.lock:
+ with self._get_connection() as conn:
+ conn.execute('''
+ INSERT INTO obstacles (navigation_id, label, obstacle_type, box, distance, alert_sent)
+ VALUES (?, ?, ?, ?, ?, ?)
+ ''', (nav_id, label, obstacle_type, json.dumps(box), distance, alert_sent))
+ conn.commit()
+
+ def get_navigation_history(self, session_id: str = None, limit: int = 20) -> List[Dict]:
+ """Get navigation history."""
+ query = 'SELECT * FROM navigation_sessions'
+ params = []
+
+ if session_id:
+ query += ' WHERE session_id = ?'
+ params.append(session_id)
+
+ query += ' ORDER BY started_at DESC LIMIT ?'
+ params.append(limit)
+
+ with self._get_connection() as conn:
+ rows = conn.execute(query, params).fetchall()
+ return [dict(row) for row in rows]
+
+ # ===== VOICE QUERY METHODS =====
+
+ def save_voice_query(self, session_id: str, query: str, parsed_prompts: List[str],
+ was_search: bool = True, was_describe: bool = False):
+ """Save a voice query."""
+ with self.lock:
+ with self._get_connection() as conn:
+ conn.execute('''
+ INSERT INTO voice_queries (session_id, query, parsed_prompts, was_search, was_describe)
+ VALUES (?, ?, ?, ?, ?)
+ ''', (session_id, query, json.dumps(parsed_prompts), was_search, was_describe))
+ conn.commit()
+
+ # ===== EVENT LOG METHODS =====
+
+ def log_event(self, session_id: str, event_type: str, message: str, level: str = 'INFO', data: Dict = None):
+ """Log an event to the database."""
+ with self.lock:
+ with self._get_connection() as conn:
+ conn.execute('''
+ INSERT INTO event_log (session_id, event_type, level, message, data)
+ VALUES (?, ?, ?, ?, ?)
+ ''', (session_id, event_type, level, message, json.dumps(data) if data else None))
+ conn.commit()
+
+ def get_event_log(self, session_id: str = None, event_type: str = None, limit: int = 100) -> List[Dict]:
+ """Get event log with optional filters."""
+ query = 'SELECT * FROM event_log WHERE 1=1'
+ params = []
+
+ if session_id:
+ query += ' AND session_id = ?'
+ params.append(session_id)
+ if event_type:
+ query += ' AND event_type = ?'
+ params.append(event_type)
+
+ query += ' ORDER BY timestamp DESC LIMIT ?'
+ params.append(limit)
+
+ with self._get_connection() as conn:
+ rows = conn.execute(query, params).fetchall()
+ return [dict(row) for row in rows]
+
+ # ===== STATISTICS METHODS =====
+
+ def get_session_stats(self, session_id: str) -> Dict:
+ """Get statistics for a session."""
+ with self._get_connection() as conn:
+ stats = {}
+
+ # Detection count
+ row = conn.execute(
+ 'SELECT COUNT(*) as count FROM detections WHERE session_id = ?',
+ (session_id,)
+ ).fetchone()
+ stats['total_detections'] = row['count'] if row else 0
+
+ # Unique labels
+ rows = conn.execute(
+ 'SELECT DISTINCT label FROM detections WHERE session_id = ?',
+ (session_id,)
+ ).fetchall()
+ stats['unique_labels'] = [row['label'] for row in rows]
+ stats['unique_label_count'] = len(stats['unique_labels'])
+
+ # Analysis count
+ row = conn.execute(
+ 'SELECT COUNT(*) as count FROM analysis_results WHERE session_id = ?',
+ (session_id,)
+ ).fetchone()
+ stats['total_analyses'] = row['count'] if row else 0
+
+ # Navigation count
+ row = conn.execute(
+ 'SELECT COUNT(*) as count, SUM(CASE WHEN reached THEN 1 ELSE 0 END) as reached FROM navigation_sessions WHERE session_id = ?',
+ (session_id,)
+ ).fetchone()
+ stats['navigation_sessions'] = row['count'] if row else 0
+ stats['successful_navigations'] = row['reached'] if row and row['reached'] else 0
+
+ return stats
+
+ def migrate_from_json(self, location_memory_file: str):
+ """Migrate existing JSON location memory to SQLite."""
+ if not os.path.exists(location_memory_file):
+ return
+
+ try:
+ with open(location_memory_file, 'r') as f:
+ old_memory = json.load(f)
+
+ for label, entries in old_memory.items():
+ for entry in entries:
+ self.remember_location(
+ label=label,
+ context=entry.get('context', ''),
+ position=entry.get('position')
+ )
+ # Update frequency if specified
+ if entry.get('frequency', 1) > 1:
+ with self._get_connection() as conn:
+ conn.execute('''
+ UPDATE location_memory
+ SET frequency = ?
+ WHERE label = ? AND context = ?
+ ''', (entry['frequency'], label.lower(), entry.get('context', '').lower()))
+ conn.commit()
+
+ print(f"Migrated {len(old_memory)} items from JSON to SQLite")
+
+ # Optionally rename old file
+ backup_path = location_memory_file + '.bak'
+ os.rename(location_memory_file, backup_path)
+ print(f"Old JSON file backed up to {backup_path}")
+
+ except Exception as e:
+ print(f"Error migrating from JSON: {e}")
+
+
+# Global database instance
+db = Database()
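+
+# Usage sketch for the Database API above (illustrative values only; the Flask
+# routes and the CommandCenter loop later in this file are the real callers):
+#   session_id = db.create_session("mps", ["person", "car"], {"threshold": 0.3})
+#   db.remember_location("keys", context="kitchen counter", position={"cx": 0.4, "cy": 0.6})
+#   db.recall_location("keys")        # -> row dict with label/context/frequency, or None
+#   db.log_event(session_id, "startup", "Command center online")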
+
+
+# ===== SMART OBSTACLE DETECTION =====
+# Uses Claude AI to understand context and identify actual obstacles in the path
+
+def analyze_obstacles_with_claude(image_data: str, target_label: str, target_box: List = None) -> List[Dict]:
+ """
+ Use Claude to intelligently identify obstacles in the user's path.
+
+ This is smarter than a static list because Claude:
+ 1. Understands what the user is looking for (won't mark it as obstacle)
+ 2. Understands spatial relationships (what's actually in the path)
+ 3. Understands environmental context (room type, indoor/outdoor)
+ 4. Can identify hazards specific to the situation
+ """
+ if not ANTHROPIC_API_KEY:
+ return []
+
+ try:
+ from anthropic import Anthropic
+ client = Anthropic(api_key=ANTHROPIC_API_KEY)
+
+ # Build context about the target
+ target_context = f"The user is navigating to find: {target_label}"
+ if target_box:
+ # Describe where the target is in the frame
+ frame_center_x = 320 # Assuming 640 width
+ target_center_x = (target_box[0] + target_box[2]) / 2
+ if target_center_x < frame_center_x - 100:
+ target_position = "on the left side of the view"
+ elif target_center_x > frame_center_x + 100:
+ target_position = "on the right side of the view"
+ else:
+ target_position = "ahead in the center of the view"
+ target_context += f". The {target_label} is currently visible {target_position}."
+
+ prompt = f"""You are helping a visually impaired person navigate to an object. Analyze this image for obstacles.
+
+{target_context}
+
+IMPORTANT RULES:
+1. The {target_label} is NOT an obstacle - it's the destination
+2. Only identify objects that could physically block the path to the {target_label}
+3. Focus on objects between the camera/user and the target
+4. Consider floor-level hazards (cables, steps, rugs, wet surfaces)
+5. Consider objects at body height that could be walked into
+6. Ignore objects that are clearly not in the walking path
+
+For each obstacle you identify, provide:
+- name: What the obstacle is (be specific, e.g., "wooden chair" not just "furniture")
+- severity: "high" (could cause injury/fall), "medium" (could cause collision), or "low" (minor obstruction)
+- position: Where in the frame (left, center, right, floor, ahead)
+- distance: How close it appears (very_close, close, medium, far)
+- reason: Brief explanation of why it's an obstacle
+
+Respond in JSON format:
+{{
+ "environment": "brief description of the space (e.g., living room, hallway, outdoor path)",
+ "path_clear": true/false,
+ "obstacles": [
+ {{
+ "name": "obstacle name",
+ "severity": "high/medium/low",
+ "position": "left/center/right/floor",
+ "distance": "very_close/close/medium/far",
+ "reason": "why this is in the way"
+ }}
+ ],
+ "safe_direction": "suggestion for safest path if obstacles present"
+}}
+
+If the path appears clear, return an empty obstacles array.
+Only include obstacles that are genuinely in the way - don't over-report."""
+
+ response = client.messages.create(
+ model="claude-sonnet-4-20250514",
+ max_tokens=1000,
+ messages=[
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/jpeg",
+ "data": image_data
+ }
+ },
+ {
+ "type": "text",
+ "text": prompt
+ }
+ ]
+ }
+ ]
+ )
+
+ # Parse Claude's response
+ response_text = response.content[0].text
+
+ # Extract JSON from response
+ import re
+ json_match = re.search(r'\{[\s\S]*\}', response_text)
+ if json_match:
+ result = json.loads(json_match.group())
+
+ obstacles = []
+ for obs in result.get("obstacles", []):
+ obstacles.append({
+ "label": obs.get("name", "unknown obstacle"),
+ "type": obs.get("severity", "medium"),
+ "position": obs.get("position", "ahead"),
+ "distance": obs.get("distance", "medium"),
+ "reason": obs.get("reason", ""),
+ "from_claude": True
+ })
+
+ # Store environment info
+ if result.get("environment"):
+ cc.navigation_context = cc.navigation_context or {}
+ cc.navigation_context["environment"] = result.get("environment")
+ cc.navigation_context["path_clear"] = result.get("path_clear", True)
+ cc.navigation_context["safe_direction"] = result.get("safe_direction")
+
+ return obstacles
+
+ return []
+
+ except Exception as e:
+ cc.log(f"Claude obstacle analysis failed: {e}", "ERROR")
+ return []
+
+
+# Global state
+class CommandCenter:
+ """Global state manager for the command center."""
+
+ def __init__(self):
+ self.lock = threading.Lock()
+ self.running = False
+ self.paused = False
+
+ # Detection settings
+ self.prompts = ["object"]
+ self.confidence_threshold = 0.3
+ self.max_objects_per_prompt = {} # prompt -> max count (None = unlimited)
+ self.show_all_matches = {} # prompt -> bool (show all even if over limit)
+
+ # Current detection state
+ self.current_detections = [] # List of detection dicts
+ self.frame_count = 0
+ self.fps = 0.0
+ self.device_str = "cpu"
+
+ # Verbose log
+ self.log_entries = deque(maxlen=100)
+
+ # Claude analysis results
+ self.analysis_queue = [] # Objects waiting for analysis
+ self.analysis_results = deque(maxlen=20) # Recent analysis results
+ self.analyzing = False
+
+ # Frame for streaming
+ self.current_frame = None # Frame with overlays (for display)
+ self.current_raw_frame = None # Raw frame without overlays (for analysis)
+ self.current_frame_jpeg = None
+
+ # Camera and model
+ self.camera = None
+ self.processor = None
+ self.state = None
+ self.video_predictor = None # SAM3 video predictor for memory tracking
+
+ # Basic tracking state (optical flow)
+ self.enable_tracking = True
+ self.skip_frames = 3
+ self.last_masks = None
+ self.last_boxes = None
+ self.last_scores = None
+ self.last_labels = None
+ self.prev_gray = None
+
+ # ===== FEATURE TOGGLES =====
+
+ # Video Tracking with Memory (SAM3 tracker)
+ self.enable_memory_tracking = False
+ self.memory_bank = {} # object_id -> list of mask features
+ self.memory_max_frames = 10 # Max frames to keep in memory per object
+
+ # Multi-Object Tracking with Persistent IDs
+ self.enable_persistent_ids = False
+ self.object_registry = {} # object_id -> {label, first_seen, last_seen, color, ...}
+ self.next_object_id = 1
+ self.iou_threshold = 0.3 # IoU threshold for matching objects
+
+ # Multi-Object Video Tracking
+ self.tracked_objects = {} # object_id -> tracking state
+ self.object_colors = {} # object_id -> color
+
+ # Mask Refinement Options
+ self.enable_fill_holes = False
+ self.fill_hole_area = 100 # Max hole area to fill (pixels)
+ self.enable_non_overlap = False # Prevent mask overlaps
+ self.enable_smooth_edges = False
+ self.smooth_kernel_size = 5
+
+ # Advanced Detection Controls
+ self.enable_boundary_suppression = False
+ self.boundary_margin = 10 # Pixels from edge to suppress
+ self.enable_occlusion_suppression = False
+ self.occlusion_threshold = 0.5 # Overlap ratio to suppress
+ self.enable_hotstart = False
+ self.hotstart_frames = 5 # Frames before confirming new detection
+ self.pending_detections = {} # id -> {frames_seen, detection_data}
+
+ # ===== YOLO FEATURES =====
+ self.yolo_classify_model = None
+ self.yolo_pose_model = None
+ self.yolo_available = False
+
+ # YOLO Classification
+ self.enable_yolo_classify = False
+ self.yolo_classify_threshold = 0.3
+ self.yolo_classify_every_n = 1 # Run classification every N keyframes
+
+ # YOLO Pose Estimation
+ self.enable_yolo_pose = False
+ self.yolo_pose_threshold = 0.5
+ self.show_keypoint_labels = False
+ self.show_skeleton = True
+ self.keypoint_radius = 4
+ self.skeleton_thickness = 2
+
+ # Label spoofing (use SAM3->COCO mapping)
+ self.enable_label_spoofing = True
+
+ # Store pose results
+ self.last_poses = {} # object_id -> keypoints
+
+ # ===== VOICE SEARCH =====
+ self.voice_enabled = True
+ self.last_voice_query = ""
+ self.last_parsed_prompts = []
+ self.tts_enabled = True
+ self.tts_voice = "default"
+ self.voice_feedback_messages = deque(maxlen=10)
+
+ # ===== CAMERA SETTINGS =====
+ self.current_camera_id = 0
+ self.available_cameras = [] # List of {id, name, description}
+ self.flip_horizontal = False
+ self.flip_vertical = False
+
+ # ===== REFERENCE IMAGE SEARCH =====
+ self.clip_model = None
+ self.clip_processor = None
+ self.clip_available = False
+ self.reference_image = None # PIL Image
+ self.reference_embedding = None # CLIP embedding
+ self.reference_description = None # Text description from Claude
+ self.visual_match_threshold = 0.75 # Similarity threshold for CLIP matching
+ self.visual_match_enabled = False # Whether to use CLIP matching
+
+ # ===== GEOMETRIC PROMPTS (Draw to Search) =====
+ self.pending_box_prompt = None # (x1, y1, x2, y2) for box prompt
+ self.pending_point_prompt = None # (x, y) for point prompt
+ self.draw_mode = None # 'box' or 'point'
+
+ # ===== SESSION TRACKING =====
+ self.session_id = None # Current session ID for database
+
+ # ===== NAVIGATION SYSTEM (Accessibility) =====
+ self.navigation_active = False
+ self.navigation_target = None # Target object label
+ self.navigation_target_id = None # Target detection ID
+ self.navigation_db_id = None # Navigation session ID in database
+ self.navigation_start_time = None
+ self.navigation_last_seen = None # Last position of target
+ self.navigation_guidance_queue = deque(maxlen=10) # Pending guidance messages
+ self.navigation_last_guidance = None # Last spoken guidance
+ self.navigation_last_guidance_time = 0
+ self.navigation_guidance_interval = 1.5 # Seconds between guidance
+ self.navigation_reached = False # Whether target was reached
+ self.navigation_context = None # Scene context from Claude
+
+ # Navigation spatial tracking
+ self.navigation_target_history = [] # History of target positions
+ self.navigation_frame_center = (320, 240) # Frame center (updated dynamically)
+ self.navigation_proximity_threshold = 0.25 # Object covers 25% of frame = reachable
+ self.navigation_close_threshold = 0.15 # Getting close
+ self.navigation_direction_deadzone = 0.1 # Center deadzone
+
+ # ===== OBSTACLE DETECTION =====
+ self.obstacle_detection_active = False # Run obstacle detection during navigation
+ self.current_obstacles = [] # Currently detected obstacles
+ self.obstacle_alert_cooldown = {} # obstacle_label -> last_alert_time
+ self.obstacle_alert_interval = 3.0 # Seconds between repeated alerts for same obstacle
+ self.obstacle_masks = None # Masks for obstacles to render
+ self.obstacle_boxes = None # Boxes for obstacles
+
+ # ===== LOCATION MEMORY (Now uses SQLite) =====
+ self.location_memory_file = os.path.join(os.path.dirname(__file__), '.location_memory.json')
+ self._migrate_location_memory()
+
+ def _migrate_location_memory(self):
+ """Migrate old JSON location memory to SQLite if it exists."""
+ if os.path.exists(self.location_memory_file):
+ db.migrate_from_json(self.location_memory_file)
+
+ def remember_location(self, label: str, context: str, position: Dict = None):
+ """Remember where an object was found (uses SQLite)."""
+ db.remember_location(label, context, position)
+ self.log(f"Remembered: {label} found in {context}")
+
+ def recall_location(self, label: str) -> Optional[Dict]:
+ """Recall where an object was last found (uses SQLite)."""
+ return db.recall_location(label)
+
+ def get_all_location_memories(self) -> List[Dict]:
+ """Get all location memories from database."""
+ return db.get_all_location_memories()
+
+ def clear_location_memory(self, label: str = None):
+ """Clear location memory (uses SQLite)."""
+ db.clear_location_memory(label)
+ self.log(f"Cleared location memory" + (f" for {label}" if label else ""))
+
+ def _old_recall_location(self, label: str) -> Optional[Dict]:
+ """Old recall method - kept for reference."""
+ label_key = label.lower().strip()
+
+ if label_key not in self.location_memory:
+ return None
+
+ entries = self.location_memory[label_key]
+ if not entries:
+ return None
+
+ # Return most frequent location, or most recent
+ sorted_entries = sorted(entries, key=lambda x: (x.get("frequency", 1), x.get("timestamp", "")), reverse=True)
+ return sorted_entries[0]
+
+ def add_navigation_guidance(self, message: str, priority: int = 1):
+ """Add a guidance message to the queue."""
+ with self.lock:
+ self.navigation_guidance_queue.append({
+ "message": message,
+ "priority": priority,
+ "timestamp": time.time()
+ })
+
+ def get_pending_guidance(self) -> Optional[str]:
+ """Get the next pending guidance message."""
+ with self.lock:
+ if self.navigation_guidance_queue:
+ # Get highest priority message
+ sorted_queue = sorted(self.navigation_guidance_queue, key=lambda x: -x["priority"])
+ msg = sorted_queue[0]
+ self.navigation_guidance_queue.remove(msg)
+ return msg["message"]
+ return None
+
+ def add_voice_feedback(self, message: str, msg_type: str = "info"):
+ """Add a voice feedback message."""
+ with self.lock:
+ self.voice_feedback_messages.append({
+ "message": message,
+ "type": msg_type,
+ "timestamp": datetime.now().strftime("%H:%M:%S")
+ })
+
+ def log(self, message: str, level: str = "INFO"):
+ """Add a log entry."""
+ timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3]
+ entry = {
+ "timestamp": timestamp,
+ "level": level,
+ "message": message
+ }
+ with self.lock:
+ self.log_entries.append(entry)
+
+ def get_logs(self, limit: int = 50) -> List[Dict]:
+ """Get recent log entries."""
+ with self.lock:
+ return list(self.log_entries)[-limit:]
+
+ def add_detection(self, detection: Dict):
+ """Add a detection to the current list."""
+ with self.lock:
+ self.current_detections.append(detection)
+
+ def clear_detections(self):
+ """Clear all current detections."""
+ with self.lock:
+ self.current_detections = []
+
+ def get_filtered_detections(self) -> Tuple[List[Dict], Dict]:
+ """Get detections filtered by max count settings."""
+ with self.lock:
+ detections = self.current_detections.copy()
+
+ # Group by prompt
+ by_prompt = {}
+ for det in detections:
+ prompt = det.get("label", "unknown")
+ if prompt not in by_prompt:
+ by_prompt[prompt] = []
+ by_prompt[prompt].append(det)
+
+ # Apply filters
+ filtered = []
+ hidden_counts = {}
+
+ for prompt, dets in by_prompt.items():
+ max_count = self.max_objects_per_prompt.get(prompt)
+ show_all = self.show_all_matches.get(prompt, False)
+
+ if max_count is not None and not show_all:
+ dets_sorted = sorted(dets, key=lambda d: d.get("confidence", 0), reverse=True)
+ filtered.extend(dets_sorted[:max_count])
+ hidden = len(dets_sorted) - max_count
+ if hidden > 0:
+ hidden_counts[prompt] = hidden
+ else:
+ filtered.extend(dets)
+
+ return filtered, hidden_counts
+
+ def queue_analysis(self, detection_id: int, image_data: str):
+ """Queue an object for Claude analysis."""
+ with self.lock:
+ self.analysis_queue.append({
+ "id": detection_id,
+ "image_data": image_data,
+ "timestamp": datetime.now().isoformat()
+ })
+
+ def add_analysis_result(self, detection_id: int, result: str):
+ """Add a Claude analysis result."""
+ with self.lock:
+ self.analysis_results.append({
+ "id": detection_id,
+ "result": result,
+ "timestamp": datetime.now().strftime("%H:%M:%S")
+ })
+
+ def get_feature_status(self) -> Dict:
+ """Get status of all feature toggles."""
+ return {
+ "tracking": self.enable_tracking,
+ "memory_tracking": self.enable_memory_tracking,
+ "persistent_ids": self.enable_persistent_ids,
+ "fill_holes": self.enable_fill_holes,
+ "non_overlap": self.enable_non_overlap,
+ "smooth_edges": self.enable_smooth_edges,
+ "boundary_suppression": self.enable_boundary_suppression,
+ "occlusion_suppression": self.enable_occlusion_suppression,
+ "hotstart": self.enable_hotstart,
+ "yolo_classify": self.enable_yolo_classify,
+ "yolo_pose": self.enable_yolo_pose,
+ "show_keypoint_labels": self.show_keypoint_labels,
+ "show_skeleton": self.show_skeleton,
+ "label_spoofing": self.enable_label_spoofing,
+ }
+
+
+# Global command center instance
+cc = CommandCenter()
+
+
+# Color palette (BGR for OpenCV)
+COLORS = [
+ (255, 0, 0), # Blue
+ (0, 255, 0), # Green
+ (0, 0, 255), # Red
+ (255, 255, 0), # Cyan
+ (255, 0, 255), # Magenta
+ (0, 255, 255), # Yellow
+ (128, 0, 255), # Purple
+ (255, 128, 0), # Orange
+ (128, 255, 0), # Lime
+ (0, 128, 255), # Sky blue
+]
+
+
+def load_yolo_models():
+ """Load YOLO models for classification and pose estimation."""
+ global cc
+
+ try:
+ from ultralytics import YOLO
+
+ cc.log("Loading YOLO models...")
+
+ # Model priority: YOLO12 -> YOLO11 -> YOLOv8
+ # YOLO12 is newest (Feb 2025) but pretrained weights may not be available for all tasks
+ cls_models = ['yolo12n-cls.pt', 'yolo11n-cls.pt', 'yolov8n-cls.pt']
+ pose_models = ['yolo12n-pose.pt', 'yolo11n-pose.pt', 'yolov8n-pose.pt']
+
+ # Load classification model
+ cc.yolo_classify_model = None
+ for model_name in cls_models:
+ try:
+ cc.yolo_classify_model = YOLO(model_name)
+ cc.log(f"YOLO classification model loaded ({model_name})", "SUCCESS")
+ break
+ except Exception as e:
+ cc.log(f"Could not load {model_name}: {e}", "WARN")
+ continue
+
+ if cc.yolo_classify_model is None:
+ cc.log("No classification model available", "WARN")
+
+ # Load pose estimation model
+ cc.yolo_pose_model = None
+ for model_name in pose_models:
+ try:
+ cc.yolo_pose_model = YOLO(model_name)
+ cc.log(f"YOLO pose model loaded ({model_name})", "SUCCESS")
+ break
+ except Exception as e:
+ cc.log(f"Could not load {model_name}: {e}", "WARN")
+ continue
+
+ if cc.yolo_pose_model is None:
+ cc.log("No pose model available", "WARN")
+
+ cc.yolo_available = cc.yolo_classify_model is not None or cc.yolo_pose_model is not None
+
+ if cc.yolo_available:
+ cc.log("YOLO models ready", "SUCCESS")
+ else:
+ cc.log("No YOLO models available", "WARN")
+
+ except ImportError:
+ cc.log("ultralytics not installed. YOLO features disabled. Install with: pip install ultralytics", "WARN")
+ cc.yolo_available = False
+
+
+def load_clip_model():
+ """Load CLIP model for visual similarity matching."""
+ global cc
+
+ try:
+ from transformers import CLIPProcessor, CLIPModel
+
+ cc.log("Loading CLIP model for visual matching...")
+
+ # Use a smaller/faster CLIP model
+ model_name = "openai/clip-vit-base-patch32"
+
+ cc.clip_processor = CLIPProcessor.from_pretrained(model_name)
+ cc.clip_model = CLIPModel.from_pretrained(model_name)
+
+ # Move to appropriate device
+ device = get_device()
+ cc.clip_model = cc.clip_model.to(device)
+ cc.clip_model.eval()
+
+ cc.clip_available = True
+ cc.log("CLIP model loaded successfully", "SUCCESS")
+
+ except ImportError:
+ cc.log("transformers not installed. Visual matching disabled. Install with: pip install transformers", "WARN")
+ cc.clip_available = False
+ except Exception as e:
+ cc.log(f"Failed to load CLIP model: {e}", "ERROR")
+ cc.clip_available = False
+
+
+def get_clip_embedding(image: Image.Image) -> Optional[torch.Tensor]:
+ """Get CLIP embedding for an image."""
+ global cc
+
+ if not cc.clip_available or cc.clip_model is None:
+ return None
+
+ try:
+ device = get_device()
+ inputs = cc.clip_processor(images=image, return_tensors="pt")
+ inputs = {k: v.to(device) for k, v in inputs.items()}
+
+ with torch.no_grad():
+ embedding = cc.clip_model.get_image_features(**inputs)
+ # Normalize
+ embedding = embedding / embedding.norm(dim=-1, keepdim=True)
+
+ return embedding
+
+ except Exception as e:
+ cc.log(f"Failed to get CLIP embedding: {e}", "ERROR")
+ return None
+
+
+def compute_clip_similarity(embedding1: torch.Tensor, embedding2: torch.Tensor) -> float:
+ """Compute cosine similarity between two CLIP embeddings."""
+ if embedding1 is None or embedding2 is None:
+ return 0.0
+
+ with torch.no_grad():
+ similarity = torch.nn.functional.cosine_similarity(embedding1, embedding2)
+ return float(similarity.item())
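+
+
+# Illustrative sketch only (assumed wiring, not called anywhere in this file): the two
+# CLIP helpers above can be combined to decide whether a detection crop matches the
+# stored reference image. `crop_image` is a hypothetical PIL crop of a detection box;
+# the threshold reuses cc.visual_match_threshold from CommandCenter.
+def _example_visual_match(crop_image: Image.Image) -> bool:
+    """Return True if the crop is visually similar to the stored reference image."""
+    if not cc.visual_match_enabled or cc.reference_embedding is None:
+        return False
+    crop_embedding = get_clip_embedding(crop_image)
+    if crop_embedding is None:
+        return False
+    similarity = compute_clip_similarity(cc.reference_embedding, crop_embedding)
+    return similarity >= cc.visual_match_threshold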
+
+
+def describe_image_with_claude(image_data: str) -> Optional[str]:
+ """Use Claude to generate a detailed description of an image for search."""
+ global ANTHROPIC_API_KEY
+
+ if not ANTHROPIC_API_KEY:
+ return None
+
+ try:
+ import anthropic
+
+ client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)
+
+ message = client.messages.create(
+ model="claude-sonnet-4-20250514",
+ max_tokens=200,
+ messages=[
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/jpeg",
+ "data": image_data,
+ },
+ },
+ {
+ "type": "text",
+ "text": "Describe this object concisely for visual detection. Focus on: type of object, color, distinctive features, shape. Return ONLY the description phrase (e.g., 'red baseball cap with white Nike logo', 'black leather handbag with gold clasp'). No other text."
+ }
+ ],
+ }
+ ],
+ )
+
+ return message.content[0].text.strip()
+
+ except Exception as e:
+ cc.log(f"Failed to describe image with Claude: {e}", "ERROR")
+ return None
+
+
+# ===== NAVIGATION SYSTEM FUNCTIONS =====
+
+def analyze_scene_context(image_data: str) -> Optional[Dict]:
+ """
+ Use Claude to analyze the scene for navigation context.
+ Returns location type, obstacles, and spatial awareness info.
+ """
+ global ANTHROPIC_API_KEY
+
+ if not ANTHROPIC_API_KEY:
+ return None
+
+ try:
+ import anthropic
+
+ client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)
+
+ message = client.messages.create(
+ model="claude-sonnet-4-20250514",
+ max_tokens=300,
+ messages=[
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/jpeg",
+ "data": image_data,
+ },
+ },
+ {
+ "type": "text",
+ "text": """Analyze this scene for navigation assistance. Return JSON only:
+{
+ "location": "room type (kitchen, living room, bedroom, bathroom, office, hallway, outdoor, etc.)",
+ "obstacles": ["list of obstacles or hazards visible"],
+ "surfaces": ["tables, counters, shelves visible"],
+ "lighting": "bright/dim/dark",
+ "space": "open/cluttered/narrow",
+ "landmarks": ["notable items that help orient"]
+}"""
+ }
+ ],
+ }
+ ],
+ )
+
+ response_text = message.content[0].text.strip()
+
+ # Parse JSON
+ if "```json" in response_text:
+ response_text = response_text.split("```json")[1].split("```")[0].strip()
+ elif "```" in response_text:
+ response_text = response_text.split("```")[1].split("```")[0].strip()
+
+ return json.loads(response_text)
+
+ except Exception as e:
+ cc.log(f"Scene analysis failed: {e}", "WARN")
+ return None
+
+
+def compute_navigation_guidance(target_box: List[float], frame_shape: Tuple[int, int]) -> Dict:
+ """
+ Compute navigation guidance based on target position in frame.
+
+ Returns:
+ direction: 'left', 'right', 'center', 'up', 'down'
+ distance: 'far', 'medium', 'close', 'reachable'
+ guidance_text: Human-readable guidance
+ arrow_angle: Angle for AR arrow (degrees)
+ confidence: How confident we are in the guidance
+ """
+ global cc
+
+ if not target_box:
+ return {
+ "direction": "unknown",
+ "distance": "unknown",
+ "guidance_text": "Looking for target...",
+ "arrow_angle": 0,
+ "confidence": 0
+ }
+
+ h, w = frame_shape[:2]
+ x1, y1, x2, y2 = target_box
+
+ # Object center
+ obj_center_x = (x1 + x2) / 2
+ obj_center_y = (y1 + y2) / 2
+
+ # Frame center
+ frame_center_x = w / 2
+ frame_center_y = h / 2
+
+ # Normalized position (-1 to 1, 0 = center)
+ norm_x = (obj_center_x - frame_center_x) / (w / 2)
+ norm_y = (obj_center_y - frame_center_y) / (h / 2)
+
+ # Object size relative to frame
+ obj_width = x2 - x1
+ obj_height = y2 - y1
+ obj_area_ratio = (obj_width * obj_height) / (w * h)
+
+ # Determine direction
+ deadzone = cc.navigation_direction_deadzone
+
+ if abs(norm_x) < deadzone and abs(norm_y) < deadzone:
+ direction = "center"
+ h_dir = ""
+ elif abs(norm_x) > abs(norm_y):
+ direction = "right" if norm_x > 0 else "left"
+ h_dir = direction
+ else:
+ direction = "down" if norm_y > 0 else "up"
+ h_dir = ""
+
+ # Secondary direction
+ if direction in ["center"]:
+ secondary = ""
+ elif direction in ["left", "right"]:
+ if norm_y < -deadzone:
+ secondary = " and up"
+ elif norm_y > deadzone:
+ secondary = " and down"
+ else:
+ secondary = ""
+ else:
+ if norm_x < -deadzone:
+ secondary = " and left"
+ elif norm_x > deadzone:
+ secondary = " and right"
+ else:
+ secondary = ""
+
+ # Determine distance based on object size
+ if obj_area_ratio >= cc.navigation_proximity_threshold:
+ distance = "reachable"
+ elif obj_area_ratio >= cc.navigation_close_threshold:
+ distance = "close"
+ elif obj_area_ratio >= 0.05:
+ distance = "medium"
+ else:
+ distance = "far"
+
+ # Calculate arrow angle (0 = up, 90 = right, etc.)
+ import math
+ arrow_angle = math.degrees(math.atan2(norm_x, -norm_y))
+
+ # Generate guidance text
+ if distance == "reachable":
+ if direction == "center":
+ guidance_text = "Object is directly in front of you, within reach!"
+ else:
+ guidance_text = f"Object is within reach, slightly to the {direction}{secondary}"
+ elif distance == "close":
+ if direction == "center":
+ guidance_text = "Almost there! Object is straight ahead, getting close"
+ else:
+ guidance_text = f"Getting close! Turn {direction}{secondary}"
+ elif distance == "medium":
+ if direction == "center":
+ guidance_text = "Keep moving forward, object ahead"
+ else:
+ guidance_text = f"Object is to the {direction}{secondary}, move that way"
+ else: # far
+ if direction == "center":
+ guidance_text = "Object detected ahead, continue forward"
+ else:
+ guidance_text = f"Object is far to the {direction}{secondary}"
+
+ return {
+ "direction": direction,
+ "secondary": secondary.strip(),
+ "distance": distance,
+ "guidance_text": guidance_text,
+ "arrow_angle": arrow_angle,
+ "norm_x": norm_x,
+ "norm_y": norm_y,
+ "obj_area_ratio": obj_area_ratio,
+ "confidence": min(1.0, obj_area_ratio * 10 + 0.5) # Higher for larger objects
+ }
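+
+
+# Worked example (illustration only, assuming a 640x480 frame): a target box of
+# [400, 160, 560, 320] has its center at (480, 240), so norm_x = 0.5 and norm_y = 0.0,
+# which resolves to direction "right"; its area ratio is 160*160 / (640*480) ~= 0.083,
+# which lands in the "medium" band, and the AR arrow angle is atan2(0.5, 0) = 90 degrees.
+#
+#     guidance = compute_navigation_guidance([400, 160, 560, 320], (480, 640))
+#     # -> direction="right", distance="medium", arrow_angle=90.0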
+
+
+def get_navigation_status() -> Dict:
+ """Get current navigation status and guidance."""
+ global cc
+
+ if not cc.navigation_active:
+ return {
+ "active": False,
+ "target": None,
+ "guidance": None
+ }
+
+ # Find target in current detections
+ target_detection = None
+ for det in cc.current_detections:
+ if det.get("label", "").lower() == cc.navigation_target.lower():
+ target_detection = det
+ break
+ if cc.navigation_target_id is not None and det.get("id") == cc.navigation_target_id:
+ target_detection = det
+ break
+
+ if target_detection:
+ cc.navigation_last_seen = target_detection
+ box = target_detection.get("box")
+
+ if cc.current_raw_frame is not None:
+ frame_shape = cc.current_raw_frame.shape
+ else:
+ frame_shape = (480, 640)
+
+ guidance = compute_navigation_guidance(box, frame_shape)
+
+ # Check if reached
+ if guidance["distance"] == "reachable" and not cc.navigation_reached:
+ cc.navigation_reached = True
+ guidance["reached"] = True
+ guidance["guidance_text"] = f"You've reached the {cc.navigation_target}! It's right in front of you."
+
+ return {
+ "active": True,
+ "target": cc.navigation_target,
+ "target_visible": True,
+ "target_bbox": box, # For AR path rendering
+ "guidance": guidance,
+ "reached": cc.navigation_reached,
+ "context": cc.navigation_context,
+ "duration": time.time() - cc.navigation_start_time if cc.navigation_start_time else 0
+ }
+ else:
+ # Target not currently visible
+ last_guidance = None
+ if cc.navigation_last_seen:
+ box = cc.navigation_last_seen.get("box")
+ if box:
+ frame_shape = (480, 640)
+ if cc.current_raw_frame is not None:
+ frame_shape = cc.current_raw_frame.shape
+ last_guidance = compute_navigation_guidance(box, frame_shape)
+ last_guidance["guidance_text"] = f"Lost sight of {cc.navigation_target}. Last seen to the {last_guidance['direction']}"
+
+ return {
+ "active": True,
+ "target": cc.navigation_target,
+ "target_visible": False,
+ "guidance": last_guidance or {
+ "direction": "unknown",
+ "distance": "unknown",
+ "guidance_text": f"Looking for {cc.navigation_target}... Turn slowly to scan the area",
+ "arrow_angle": 0
+ },
+ "reached": False,
+ "context": cc.navigation_context,
+ "searching": True
+ }
+
+
+def get_coco_class_for_label(sam3_label: str) -> Optional[int]:
+ """Get COCO class ID for a SAM3 label using the mapping."""
+ label_lower = sam3_label.lower().strip()
+
+ # Direct lookup
+ if label_lower in SAM3_TO_COCO:
+ return SAM3_TO_COCO[label_lower]
+
+ # Try partial match
+ for key, coco_id in SAM3_TO_COCO.items():
+ if key in label_lower or label_lower in key:
+ return coco_id
+
+ return None
+
+
+def is_person_label(label: str) -> bool:
+ """Check if a label refers to a person."""
+ coco_id = get_coco_class_for_label(label)
+ return coco_id == 0
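+
+
+# Illustrative example (assuming SAM3_TO_COCO contains a "person" entry and the standard
+# COCO ordering where class 0 is "person"): get_coco_class_for_label("a person walking")
+# misses the direct lookup but hits the partial match on "person" and returns 0, so
+# is_person_label() treats that detection as a person and pose estimation can run on it.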
+
+
+def classify_region(frame: np.ndarray, box: List[float], sam3_label: str) -> Optional[Dict]:
+ """
+ Run YOLO classification on a detected region.
+
+ Returns dict with:
+ - yolo_class: Top predicted class name
+ - yolo_confidence: Confidence score
+ - top5_classes: List of top 5 (class, confidence) tuples
+ - matches_sam3: Whether YOLO agrees with SAM3 label
+ """
+ if cc.yolo_classify_model is None:
+ return None
+
+ try:
+ # Crop region from frame
+ x1, y1, x2, y2 = [int(v) for v in box]
+ h, w = frame.shape[:2]
+
+ # Add padding
+ pad = 10
+ x1 = max(0, x1 - pad)
+ y1 = max(0, y1 - pad)
+ x2 = min(w, x2 + pad)
+ y2 = min(h, y2 + pad)
+
+ crop = frame[y1:y2, x1:x2]
+
+ if crop.size == 0:
+ return None
+
+ # Run classification
+ results = cc.yolo_classify_model(crop, verbose=False)
+
+ if len(results) == 0:
+ return None
+
+ probs = results[0].probs
+
+ if probs is None:
+ return None
+
+ # Get top 5 predictions
+ top5_indices = probs.top5
+ top5_confs = probs.top5conf.cpu().numpy()
+
+ # Get class names from model
+ names = cc.yolo_classify_model.names
+
+ top5_classes = [(names[idx], float(conf)) for idx, conf in zip(top5_indices, top5_confs)]
+
+ top_class = top5_classes[0][0] if top5_classes else "unknown"
+ top_conf = top5_classes[0][1] if top5_classes else 0.0
+
+ # Check if YOLO agrees with SAM3
+ sam3_coco_id = get_coco_class_for_label(sam3_label)
+ matches = False
+
+ if sam3_coco_id is not None and sam3_coco_id < len(COCO_CLASSES):
+ sam3_coco_name = COCO_CLASSES[sam3_coco_id]
+ # Check if any top-5 class matches
+ for cls_name, conf in top5_classes:
+ if cls_name.lower() == sam3_coco_name.lower() or sam3_coco_name.lower() in cls_name.lower():
+ matches = True
+ break
+
+ return {
+ "yolo_class": top_class,
+ "yolo_confidence": top_conf,
+ "top5_classes": top5_classes,
+ "matches_sam3": matches
+ }
+
+ except Exception as e:
+ cc.log(f"YOLO classification error: {e}", "ERROR")
+ return None
+
+
+def estimate_pose(frame: np.ndarray, box: List[float]) -> Optional[Dict]:
+ """
+ Run YOLO pose estimation on a person region.
+
+ Returns dict with:
+ - keypoints: List of 17 (x, y, confidence) tuples
+ - keypoint_names: List of keypoint names
+ - confidence: Overall pose confidence
+ """
+ if cc.yolo_pose_model is None:
+ return None
+
+ try:
+ # Crop region from frame (with extra padding for pose)
+ x1, y1, x2, y2 = [int(v) for v in box]
+ h, w = frame.shape[:2]
+
+ # Add generous padding for pose estimation
+ box_w = x2 - x1
+ box_h = y2 - y1
+ pad_x = int(box_w * 0.2)
+ pad_y = int(box_h * 0.1)
+
+ x1 = max(0, x1 - pad_x)
+ y1 = max(0, y1 - pad_y)
+ x2 = min(w, x2 + pad_x)
+ y2 = min(h, y2 + pad_y)
+
+ crop = frame[y1:y2, x1:x2]
+
+ if crop.size == 0:
+ return None
+
+ # Run pose estimation
+ results = cc.yolo_pose_model(crop, verbose=False)
+
+ if len(results) == 0 or results[0].keypoints is None:
+ return None
+
+ keypoints_data = results[0].keypoints
+
+ if keypoints_data.xy is None or len(keypoints_data.xy) == 0:
+ return None
+
+ # Get first person's keypoints (we're analyzing one box at a time)
+ kpts = keypoints_data.xy[0].cpu().numpy() # Shape: (17, 2)
+ confs = keypoints_data.conf[0].cpu().numpy() if keypoints_data.conf is not None else np.ones(17)
+
+ # Convert coordinates back to full frame
+ keypoints = []
+ for i, (pt, conf) in enumerate(zip(kpts, confs)):
+ # Add offset back to get coordinates in original frame
+ frame_x = pt[0] + x1
+ frame_y = pt[1] + y1
+ keypoints.append((float(frame_x), float(frame_y), float(conf)))
+
+ # Calculate overall confidence
+ valid_confs = [c for x, y, c in keypoints if c > 0.1]
+ overall_conf = np.mean(valid_confs) if valid_confs else 0.0
+
+ return {
+ "keypoints": keypoints,
+ "keypoint_names": POSE_KEYPOINTS,
+ "confidence": float(overall_conf),
+ "box_offset": (x1, y1) # For reference
+ }
+
+ except Exception as e:
+ cc.log(f"YOLO pose estimation error: {e}", "ERROR")
+ return None
+
+
+def draw_pose_overlay(frame: np.ndarray, pose_data: Dict, object_id: int = None) -> np.ndarray:
+ """Draw pose keypoints and skeleton on frame."""
+ if pose_data is None or "keypoints" not in pose_data:
+ return frame
+
+ overlay = frame.copy()
+ keypoints = pose_data["keypoints"]
+
+ # Draw skeleton connections first (so points are on top)
+ if cc.show_skeleton:
+ for start_idx, end_idx in SKELETON_CONNECTIONS:
+ if start_idx < len(keypoints) and end_idx < len(keypoints):
+ x1, y1, c1 = keypoints[start_idx]
+ x2, y2, c2 = keypoints[end_idx]
+
+ # Only draw if both points have sufficient confidence
+ if c1 > cc.yolo_pose_threshold and c2 > cc.yolo_pose_threshold:
+ pt1 = (int(x1), int(y1))
+ pt2 = (int(x2), int(y2))
+
+ # Get color based on connection type
+ color = get_keypoint_color(start_idx)
+ cv2.line(overlay, pt1, pt2, color, cc.skeleton_thickness)
+
+ # Draw keypoints
+ for i, (x, y, conf) in enumerate(keypoints):
+ if conf > cc.yolo_pose_threshold:
+ pt = (int(x), int(y))
+ color = get_keypoint_color(i)
+
+ # Draw filled circle
+ cv2.circle(overlay, pt, cc.keypoint_radius, color, -1)
+ # Draw outline
+ cv2.circle(overlay, pt, cc.keypoint_radius, (255, 255, 255), 1)
+
+ # Draw label if enabled
+ if cc.show_keypoint_labels and i < len(POSE_KEYPOINTS):
+ label = POSE_KEYPOINTS[i].replace('_', ' ')
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ font_scale = 0.35
+ (tw, th), _ = cv2.getTextSize(label, font, font_scale, 1)
+
+ # Position label above point
+ label_x = int(x - tw/2)
+ label_y = int(y - cc.keypoint_radius - 3)
+
+ # Background
+ cv2.rectangle(overlay,
+ (label_x - 1, label_y - th - 1),
+ (label_x + tw + 1, label_y + 1),
+ (0, 0, 0), -1)
+
+ # Text
+ cv2.putText(overlay, label, (label_x, label_y),
+ font, font_scale, (255, 255, 255), 1)
+
+ return overlay
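+
+
+# Illustrative sketch of the expected keyframe wiring (assumed, not shown here): gate
+# pose estimation on person-like labels, cache the result per object id, and let the
+# overlay pass draw it. `frame`, `det`, and `obj_id` are hypothetical loop variables.
+#
+#     if cc.enable_yolo_pose and is_person_label(det["label"]):
+#         pose = estimate_pose(frame, det["box"])
+#         if pose is not None:
+#             cc.last_poses[obj_id] = pose
+#             frame = draw_pose_overlay(frame, pose, obj_id)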
+
+
+def load_model(checkpoint_path: Optional[str] = None):
+ """Load the SAM3 model."""
+ from sam3.model_builder import build_sam3_image_model
+ from sam3.model.sam3_image_processor import Sam3Processor
+
+ cc.log("Loading SAM3 model...")
+ cc.device_str = get_device_str()
+
+ # Setup device-specific optimizations (MPS memory, CUDA TF32, etc.)
+ setup_device_optimizations()
+ cc.log(f"Device optimizations enabled for {cc.device_str}")
+
+ model = build_sam3_image_model(
+ device=cc.device_str,
+ checkpoint_path=checkpoint_path,
+ load_from_HF=checkpoint_path is None,
+ eval_mode=True,
+ enable_segmentation=True,
+ )
+
+ cc.processor = Sam3Processor(
+ model=model,
+ resolution=1008,
+ device=cc.device_str,
+ confidence_threshold=cc.confidence_threshold,
+ )
+
+ cc.log(f"Model loaded on {cc.device_str}", "SUCCESS")
+
+ # Load YOLO models
+ load_yolo_models()
+
+ # Load CLIP model for visual matching (optional)
+ load_clip_model()
+
+
+# ===== CAMERA FUNCTIONS =====
+
+def detect_available_cameras(max_cameras: int = 10) -> List[Dict]:
+ """
+ Detect available cameras on the system.
+
+ Returns list of dicts with:
+ - id: Camera index
+ - name: Camera name/description
+ - resolution: "WIDTHxHEIGHT" string, plus fps, backend, and a display description
+ """
+ cameras = []
+
+ for i in range(max_cameras):
+ cap = cv2.VideoCapture(i)
+ if cap.isOpened():
+ # Get camera properties
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ fps = cap.get(cv2.CAP_PROP_FPS)
+
+ # Try to get backend name
+ backend = cap.getBackendName()
+
+ # Create descriptive name
+ if i == 0:
+ name = "Default Camera"
+ else:
+ name = f"Camera {i}"
+
+ # Add platform-specific hints
+ import platform
+ if platform.system() == "Darwin": # macOS
+ if i == 0:
+ name = "FaceTime HD Camera (Built-in)"
+ elif i == 1:
+ name = "External Camera"
+ elif platform.system() == "Linux":
+ # Try to read device name from v4l2
+ try:
+ import subprocess
+ result = subprocess.run(
+ ['v4l2-ctl', '--device', f'/dev/video{i}', '--info'],
+ capture_output=True, text=True, timeout=1
+ )
+ for line in result.stdout.split('\n'):
+ if 'Card type' in line:
+ name = line.split(':')[1].strip()
+ break
+ except Exception:
+ pass
+
+ cameras.append({
+ "id": i,
+ "name": name,
+ "resolution": f"{width}x{height}",
+ "fps": fps,
+ "backend": backend,
+ "description": f"{name} ({width}x{height} @ {fps:.0f}fps)"
+ })
+
+ cap.release()
+
+ return cameras
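+
+
+# Illustrative usage (e.g. at startup or from a settings endpoint): cache the probe
+# results so the UI can populate a camera picker without re-opening devices.
+#
+#     cc.available_cameras = detect_available_cameras()
+#     for cam in cc.available_cameras:
+#         cc.log(f"Found {cam['description']}")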
+
+
+def switch_camera(camera_id: int) -> bool:
+ """Switch to a different camera and reset detection state."""
+ global cc
+
+ cc.log(f"Switching to camera {camera_id}...")
+
+ # Release current camera
+ if cc.camera is not None:
+ cc.camera.release()
+ cc.camera = None
+
+ # Open new camera
+ new_camera = cv2.VideoCapture(camera_id)
+
+ if not new_camera.isOpened():
+ cc.log(f"Failed to open camera {camera_id}", "ERROR")
+ # Try to reopen previous camera
+ cc.camera = cv2.VideoCapture(cc.current_camera_id)
+ return False
+
+ cc.camera = new_camera
+ cc.current_camera_id = camera_id
+
+ # Get camera info
+ width = int(cc.camera.get(cv2.CAP_PROP_FRAME_WIDTH))
+ height = int(cc.camera.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+ # Reset detection state
+ reset_detection_state()
+
+ cc.log(f"Switched to camera {camera_id} ({width}x{height})", "SUCCESS")
+ return True
+
+
+def reset_detection_state():
+ """Reset all detection state for a fresh start."""
+ global cc
+
+ cc.state = None
+ cc.last_masks = None
+ cc.last_boxes = None
+ cc.last_scores = None
+ cc.last_labels = None
+ cc.tracked_objects = {}
+ cc.memory_bank = {}
+ cc.object_colors = {}
+ cc.next_object_id = 1
+ cc.pending_detections = {}
+ cc.last_poses = {}
+ cc.prev_gray = None
+ cc.current_detections = []
+ cc.frame_count = 0
+
+
+# ===== MASK REFINEMENT FUNCTIONS =====
+
+def fill_holes_in_mask(mask: np.ndarray, max_hole_area: int = 100) -> np.ndarray:
+ """Fill small holes in a binary mask."""
+ mask_bool = mask.astype(bool)
+ # Find holes (inverted connected components)
+ inverted = ~mask_bool
+ labeled, num_features = ndimage.label(inverted)
+
+ # Fill small holes
+ for i in range(1, num_features + 1):
+ hole = labeled == i
+ if hole.sum() <= max_hole_area:
+ mask_bool[hole] = True
+
+ return mask_bool.astype(np.float32)
+
+
+def smooth_mask_edges(mask: np.ndarray, kernel_size: int = 5) -> np.ndarray:
+ """Smooth mask edges using morphological operations."""
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
+ # Close then open to smooth
+ smoothed = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
+ smoothed = cv2.morphologyEx(smoothed, cv2.MORPH_OPEN, kernel)
+ return smoothed.astype(np.float32)
+
+
+def remove_mask_overlaps(masks: List[np.ndarray], scores: List[float]) -> List[np.ndarray]:
+ """Remove overlapping regions, keeping higher confidence masks."""
+ if len(masks) <= 1:
+ return masks
+
+ # Sort by score (highest first)
+ sorted_indices = np.argsort(scores)[::-1]
+ result_masks = [None] * len(masks)
+ occupied = np.zeros_like(masks[0], dtype=bool)
+
+ for idx in sorted_indices:
+ mask = masks[idx].astype(bool)
+ # Remove already occupied regions
+ mask = mask & ~occupied
+ result_masks[idx] = mask.astype(np.float32)
+ occupied |= mask
+
+ return result_masks
+
+
+# ===== DETECTION CONTROL FUNCTIONS =====
+
+def is_near_boundary(box: List[float], frame_shape: Tuple[int, int], margin: int = 10) -> bool:
+ """Check if a bounding box is near the frame boundary."""
+ h, w = frame_shape[:2]
+ x1, y1, x2, y2 = box
+ return x1 < margin or y1 < margin or x2 > w - margin or y2 > h - margin
+
+
+def calculate_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
+ """Calculate Intersection over Union between two masks."""
+ intersection = np.logical_and(mask1, mask2).sum()
+ union = np.logical_or(mask1, mask2).sum()
+ return intersection / union if union > 0 else 0
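+
+
+# Worked example: two 100-pixel masks that share 60 pixels have intersection = 60 and
+# union = 100 + 100 - 60 = 140, so IoU = 60 / 140 ~= 0.43, which clears the default
+# cc.iou_threshold of 0.3 and would be treated as the same tracked object.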
+
+
+def match_detection_to_object(mask: np.ndarray, existing_masks: Dict[int, np.ndarray],
+ threshold: float = 0.3) -> Optional[int]:
+ """Match a detection to an existing tracked object by IoU."""
+ best_match = None
+ best_iou = threshold
+
+ for obj_id, existing_mask in existing_masks.items():
+ iou = calculate_iou(mask, existing_mask)
+ if iou > best_iou:
+ best_iou = iou
+ best_match = obj_id
+
+ return best_match
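+
+
+# Minimal sketch of how the IoU matcher above supports persistent IDs (assumed wiring;
+# the real keyframe loop may differ). It reuses cc.object_registry, cc.next_object_id,
+# and cc.iou_threshold from CommandCenter; the function itself is illustrative only.
+def _example_assign_persistent_id(mask: np.ndarray, label: str,
+                                  existing_masks: Dict[int, np.ndarray]) -> int:
+    """Return the matched object id, or register a brand-new one."""
+    obj_id = match_detection_to_object(mask, existing_masks, cc.iou_threshold)
+    if obj_id is None:
+        obj_id = cc.next_object_id
+        cc.next_object_id += 1
+        cc.object_registry[obj_id] = {"label": label, "first_seen": time.time()}
+    cc.object_registry.setdefault(obj_id, {})["last_seen"] = time.time()
+    return obj_id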
+
+
+def get_bounding_box_from_mask(mask: np.ndarray) -> Optional[List[float]]:
+ """Extract bounding box from a binary mask."""
+ if mask is None or mask.sum() == 0:
+ return None
+
+ rows = np.any(mask, axis=1)
+ cols = np.any(mask, axis=0)
+
+ if not rows.any() or not cols.any():
+ return None
+
+ y_min, y_max = np.where(rows)[0][[0, -1]]
+ x_min, x_max = np.where(cols)[0][[0, -1]]
+
+ return [float(x_min), float(y_min), float(x_max), float(y_max)]
+
+
+def is_mask_valid(mask: np.ndarray, frame_shape: Tuple[int, int], min_area: int = 50,
+ boundary_margin: int = 5) -> bool:
+ """
+ Check if a tracked mask is still valid.
+ Returns False if:
+ - Mask is too small (object left frame)
+ - Mask is mostly outside the frame boundaries
+ """
+ if mask is None:
+ return False
+
+ mask_area = mask.sum()
+ if mask_area < min_area:
+ return False
+
+ h, w = frame_shape[:2]
+
+ # Check if mask is mostly within frame bounds
+ if mask.shape != (h, w):
+ return False
+
+ # Get bounding box
+ box = get_bounding_box_from_mask(mask)
+ if box is None:
+ return False
+
+ x1, y1, x2, y2 = box
+
+ # Check if box is mostly outside frame
+ if x2 < boundary_margin or x1 > w - boundary_margin:
+ return False
+ if y2 < boundary_margin or y1 > h - boundary_margin:
+ return False
+
+ return True
+
+
+def update_detections_from_tracked_masks(tracked_masks: torch.Tensor, frame_shape: Tuple[int, int]):
+ """
+ Update current_detections based on tracked masks.
+ Removes detections for masks that are no longer valid (left frame).
+ Updates bounding boxes for masks that moved.
+ """
+ global cc
+
+ if tracked_masks is None or len(cc.current_detections) == 0:
+ return
+
+ h, w = frame_shape[:2]
+ masks_np = tracked_masks.squeeze(1).cpu().numpy()
+
+ updated_detections = []
+ valid_mask_indices = []
+
+ for i, det in enumerate(cc.current_detections):
+ if i >= len(masks_np):
+ break
+
+ mask = masks_np[i]
+ if mask.shape != (h, w):
+ mask = cv2.resize(mask.astype(np.float32), (w, h)) > 0.5
+
+ # Check if mask is still valid
+ if is_mask_valid(mask, frame_shape):
+ # Update bounding box from tracked mask
+ new_box = get_bounding_box_from_mask(mask)
+ if new_box:
+ det = det.copy() # Don't modify original
+ det["box"] = new_box
+ det["tracked"] = True # Mark as being tracked (not fresh detection)
+ updated_detections.append(det)
+ valid_mask_indices.append(i)
+ else:
+ # Object has left the frame or tracking failed
+ label = det.get("label", "object")
+ obj_id = det.get("id", i)
+ cc.log(f"Object #{obj_id} ({label}) left frame", "INFO")
+
+ # Update global state
+ with cc.lock:
+ cc.current_detections = updated_detections
+
+ return valid_mask_indices
+
+
+# ===== MEMORY TRACKING FUNCTIONS =====
+
+def update_memory_bank(object_id: int, mask_features: torch.Tensor):
+ """Update memory bank for an object."""
+ if object_id not in cc.memory_bank:
+ cc.memory_bank[object_id] = []
+
+ cc.memory_bank[object_id].append(mask_features)
+
+ # Keep only recent frames
+ if len(cc.memory_bank[object_id]) > cc.memory_max_frames:
+ cc.memory_bank[object_id].pop(0)
+
+
+# ===== ADVANCED MONOCULAR DEPTH ESTIMATION =====
+# Proprietary: LIDAR-like depth from single RGB camera using AI
+
+_depth_model = None
+_depth_transform = None
+_depth_available = False
+_depth_device = None
+
+# Optical flow state for motion-based collision detection
+_prev_flow_frame = None
+_obstacle_tracking = {} # Track obstacles over time for approach detection
+
+
+def load_depth_model():
+ """
+ Load monocular depth estimation model.
+ Provides LIDAR-like depth perception from a single RGB camera.
+
+ Tries models in order of quality:
+ 1. Depth Anything (state-of-the-art 2024)
+ 2. MiDaS (widely compatible)
+ """
+ global _depth_model, _depth_transform, _depth_available, _depth_device
+
+ if _depth_available:
+ return True
+
+ _depth_device = torch.device("cuda" if torch.cuda.is_available() else
+ "mps" if torch.backends.mps.is_available() else "cpu")
+
+ # Try Depth Anything first (best quality, 2024 state-of-the-art)
+ try:
+ from transformers import pipeline
+ _depth_model = pipeline("depth-estimation",
+ model="LiheYoung/depth-anything-small-hf",
+ device=0 if torch.cuda.is_available() else -1)
+ _depth_available = True
+ print(f"✓ Loaded Depth Anything for monocular depth estimation")
+ return True
+ except Exception as e:
+ print(f" Depth Anything not available: {e}")
+
+ # Try MiDaS (more compatible)
+ try:
+ _depth_model = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", trust_repo=True)
+ _depth_model.to(_depth_device)
+ _depth_model.eval()
+
+ midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True)
+ _depth_transform = midas_transforms.small_transform
+
+ _depth_available = True
+ print(f"✓ Loaded MiDaS for monocular depth estimation on {_depth_device}")
+ return True
+ except Exception as e:
+ print(f" MiDaS not available: {e}")
+
+ print(" No depth model available - using edge-based detection only")
+ return False
+
+
+def estimate_depth(frame: np.ndarray) -> Optional[np.ndarray]:
+ """
+ Estimate depth map from a single RGB image.
+
+ Returns depth map where HIGHER values = CLOSER to camera.
+ This mimics LIDAR point cloud distance measurement.
+ """
+ global _depth_model, _depth_transform, _depth_available, _depth_device
+
+ if not _depth_available or _depth_model is None:
+ return None
+
+ try:
+ # Depth Anything (pipeline-based)
+ if hasattr(_depth_model, '__call__') and hasattr(_depth_model, 'task'):
+ pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+ result = _depth_model(pil_image)
+ depth = np.array(result["depth"])
+
+ # Resize to match frame
+ depth = cv2.resize(depth, (frame.shape[1], frame.shape[0]))
+
+ # Normalize and invert (so closer = higher value)
+ depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6)
+ depth = 1.0 - depth # Invert
+ depth = (depth * 255).astype(np.uint8)
+
+ return depth
+
+ # MiDaS model
+ else:
+ img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ input_batch = _depth_transform(img_rgb).to(_depth_device)
+
+ with torch.no_grad():
+ prediction = _depth_model(input_batch)
+ prediction = torch.nn.functional.interpolate(
+ prediction.unsqueeze(1),
+ size=frame.shape[:2],
+ mode="bicubic",
+ align_corners=False,
+ ).squeeze()
+
+ depth = prediction.cpu().numpy()
+
+ # Normalize (MiDaS: higher = further, so we invert)
+ depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6)
+ depth = 1.0 - depth # Invert so closer = higher
+ depth = (depth * 255).astype(np.uint8)
+
+ return depth
+
+ except Exception as e:
+ cc.log(f"Depth estimation error: {e}", "ERROR")
+ return None
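+
+
+# Illustrative debug helper (not part of the pipeline): colorize the normalized 0-255
+# depth map so brighter/hotter pixels read as closer, matching the inversion above.
+# Assumes `frame` is a BGR frame from the camera loop.
+def _example_depth_debug_view(frame: np.ndarray) -> Optional[np.ndarray]:
+    """Return a color-mapped depth visualization, or None if no depth model is loaded."""
+    depth = estimate_depth(frame)
+    if depth is None:
+        return None
+    return cv2.applyColorMap(depth, cv2.COLORMAP_JET)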
+
+
+def detect_obstacles_depth(frame: np.ndarray, depth_map: np.ndarray) -> List[Dict]:
+ """
+ Detect obstacles using AI-generated depth map.
+
+ This is MORE ACCURATE than edge detection because it knows actual distance,
+ not just "there's something there".
+ """
+ if depth_map is None:
+ return []
+
+ h, w = frame.shape[:2]
+ obstacles = []
+
+ try:
+ # Focus on walking path (center and bottom of frame)
+ path_mask = np.zeros_like(depth_map)
+ path_mask[h // 3:, w // 6:5 * w // 6] = 1
+
+ path_depth = depth_map * path_mask
+
+ # Thresholds for proximity (calibrated for normalized 0-255 depth)
+ very_close_thresh = 200 # Within arm's reach
+ close_thresh = 150 # Few steps away
+ medium_thresh = 100 # Room distance
+
+ # Find very close obstacles
+ very_close_mask = (path_depth > very_close_thresh).astype(np.uint8) * 255
+ close_mask = ((path_depth > close_thresh) & (path_depth <= very_close_thresh)).astype(np.uint8) * 255
+
+ # Morphological cleanup
+ kernel = np.ones((7, 7), np.uint8)
+ very_close_mask = cv2.morphologyEx(very_close_mask, cv2.MORPH_CLOSE, kernel)
+ very_close_mask = cv2.morphologyEx(very_close_mask, cv2.MORPH_OPEN, kernel)
+
+ # Find contours for very close obstacles
+ contours, _ = cv2.findContours(very_close_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+ min_area = (h * w) * 0.01 # 1% of frame minimum
+
+ for contour in contours:
+ area = cv2.contourArea(contour)
+ if area < min_area:
+ continue
+
+ x, y, cw, ch = cv2.boundingRect(contour)
+
+ # Get depth stats for this region
+ region_depth = depth_map[y:y+ch, x:x+cw]
+ avg_depth = np.mean(region_depth)
+ max_depth = np.max(region_depth)
+
+ # Classify severity
+ if max_depth > very_close_thresh:
+ severity = "high"
+ distance = "very_close"
+ elif max_depth > close_thresh:
+ severity = "medium"
+ distance = "close"
+ else:
+ severity = "low"
+ distance = "medium"
+
+ # Position classification
+ center_x = x + cw // 2
+ if center_x < w // 3:
+ position = "left"
+ elif center_x > 2 * w // 3:
+ position = "right"
+ else:
+ position = "center"
+
+ # Track this obstacle over time
+ obstacle_id = f"depth_{position}_{int(avg_depth)}"
+ approach_info = track_obstacle_approach(obstacle_id, avg_depth, position)
+
+ obstacles.append({
+ "label": "obstacle (depth)",
+ "type": severity,
+ "position": position,
+ "distance": distance,
+ "box": [x, y, x + cw, y + ch],
+ "depth_value": float(avg_depth),
+ "max_depth": float(max_depth),
+ "area_pct": float(area / (h * w) * 100),
+ "approaching": approach_info.get("approaching", False),
+ "approach_rate": approach_info.get("rate", 0),
+ "time_to_collision": approach_info.get("ttc"),
+ "source": "depth_ai",
+ "reason": f"Depth AI: {avg_depth:.0f}/255 proximity"
+ })
+
+ except Exception as e:
+ cc.log(f"Depth obstacle detection error: {e}", "ERROR")
+
+ return obstacles
+
+
+def detect_collision_optical_flow(frame: np.ndarray) -> List[Dict]:
+ """
+ Detect approaching obstacles using optical flow expansion.
+
+ PROPRIETARY TECHNIQUE: Objects approaching you EXPAND in the frame.
+ This mimics how flying insects detect and avoid collisions!
+
+ Physics: An object moving toward you at constant speed will appear to
+ grow larger. The rate of expansion indicates approach speed.
+ """
+ global _prev_flow_frame, _obstacle_tracking
+
+ h, w = frame.shape[:2]
+ obstacles = []
+
+ try:
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+ gray = cv2.GaussianBlur(gray, (5, 5), 0)
+
+ if _prev_flow_frame is None:
+ _prev_flow_frame = gray
+ return []
+
+ # Dense optical flow (Farneback method)
+ flow = cv2.calcOpticalFlowFarneback(
+ _prev_flow_frame, gray, None,
+ pyr_scale=0.5, levels=3, winsize=15,
+ iterations=3, poly_n=5, poly_sigma=1.2, flags=0
+ )
+
+ _prev_flow_frame = gray
+
+ # Analyze expansion in different regions
+ regions = {
+ "left": (0, h // 4, w // 3, 3 * h // 4),
+ "center": (w // 3, h // 4, 2 * w // 3, 3 * h // 4),
+ "right": (2 * w // 3, h // 4, w, 3 * h // 4),
+ "floor": (w // 4, 2 * h // 3, 3 * w // 4, h),
+ }
+
+ for region_name, (x1, y1, x2, y2) in regions.items():
+ region_flow = flow[y1:y2, x1:x2]
+ rh, rw = region_flow.shape[:2]
+
+ if rh < 10 or rw < 10:
+ continue
+
+ fx = region_flow[:, :, 0]
+ fy = region_flow[:, :, 1]
+
+ # Flow magnitude
+ magnitude = np.sqrt(fx**2 + fy**2)
+ avg_magnitude = np.mean(magnitude)
+
+ # Skip if no significant motion
+ if avg_magnitude < 1.0:
+ continue
+
+ # Calculate EXPANSION: do flow vectors point outward from center?
+ # This is the key insight - approaching objects expand!
+ center_y, center_x = rh // 2, rw // 2
+ y_coords, x_coords = np.meshgrid(np.arange(rh) - center_y,
+ np.arange(rw) - center_x, indexing='ij')
+
+ # Outward direction from center
+ dist = np.sqrt(x_coords**2 + y_coords**2) + 1e-6
+ out_x = x_coords / dist
+ out_y = y_coords / dist
+
+ # Dot product: positive = expanding (approaching)
+ expansion = fx * out_x + fy * out_y
+ avg_expansion = np.mean(expansion)
+
+ # Temporal smoothing
+ key = f"flow_{region_name}"
+ if key not in _obstacle_tracking:
+ _obstacle_tracking[key] = []
+ _obstacle_tracking[key].append(avg_expansion)
+ _obstacle_tracking[key] = _obstacle_tracking[key][-15:]
+
+ smoothed = np.mean(_obstacle_tracking[key])
+
+ # Threshold for collision warning
+ if smoothed > 0.8 and avg_magnitude > 1.5:
+ if smoothed > 2.0:
+ severity = "high"
+ distance = "very_close"
+ elif smoothed > 1.2:
+ severity = "medium"
+ distance = "close"
+ else:
+ severity = "low"
+ distance = "medium"
+
+ obstacles.append({
+ "label": "approaching",
+ "type": severity,
+ "position": region_name,
+ "distance": distance,
+ "box": [x1, y1, x2, y2],
+ "expansion_rate": float(smoothed),
+ "flow_magnitude": float(avg_magnitude),
+ "source": "optical_flow",
+ "reason": f"Motion expansion: {smoothed:.1f}x (collision trajectory)"
+ })
+
+ except Exception as e:
+ cc.log(f"Optical flow error: {e}", "ERROR")
+
+ return obstacles
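+
+
+# Self-contained sanity check for the expansion metric above (illustration only, not
+# called anywhere): a synthetic "zoom" flow field, where every vector points away from
+# the region center, must give a positive mean expansion, while a pure left-to-right
+# pan should average out to roughly zero.
+def _example_expansion_metric(size: int = 64) -> Tuple[float, float]:
+    """Return (mean expansion of a zooming field, mean expansion of a panning field)."""
+    center = size // 2
+    y, x = np.meshgrid(np.arange(size) - center, np.arange(size) - center, indexing="ij")
+    dist = np.sqrt(x**2 + y**2) + 1e-6
+    out_x, out_y = x / dist, y / dist
+    zoom_fx, zoom_fy = 0.1 * x, 0.1 * y  # vectors point outward: approaching object
+    pan_fx = np.full_like(zoom_fx, 2.0)  # uniform lateral motion
+    pan_fy = np.zeros_like(zoom_fy)
+    zoom_expansion = float(np.mean(zoom_fx * out_x + zoom_fy * out_y))
+    pan_expansion = float(np.mean(pan_fx * out_x + pan_fy * out_y))
+    return zoom_expansion, pan_expansion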
+
+
+def segment_walkable_ground(frame: np.ndarray, depth_map: Optional[np.ndarray] = None) -> Dict:
+ """
+ Segment walkable floor area from obstacles.
+
+ PROPRIETARY: Combines color consistency + depth + geometry to find safe walking path.
+ """
+ h, w = frame.shape[:2]
+
+ try:
+ # Sample ground color from bottom center (assumed floor)
+ sample = frame[h - 60:h - 20, w // 3:2 * w // 3]
+ ground_mean = np.mean(sample, axis=(0, 1))
+ ground_std = np.std(sample, axis=(0, 1))
+
+ # Color-based ground mask
+ lower = np.clip(ground_mean - 2.5 * ground_std, 0, 255).astype(np.uint8)
+ upper = np.clip(ground_mean + 2.5 * ground_std, 0, 255).astype(np.uint8)
+ color_mask = cv2.inRange(frame, lower, upper)
+
+ # If depth available, refine with depth consistency
+ if depth_map is not None:
+ ground_depth = np.median(depth_map[3 * h // 4:, w // 3:2 * w // 3])
+ depth_tolerance = 40
+ depth_mask = np.abs(depth_map.astype(float) - ground_depth) < depth_tolerance
+ combined_mask = cv2.bitwise_and(color_mask, depth_mask.astype(np.uint8) * 255)
+ else:
+ combined_mask = color_mask
+
+ # Morphological cleanup
+ kernel = np.ones((7, 7), np.uint8)
+ combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_CLOSE, kernel)
+ combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_OPEN, kernel)
+
+ # Analyze walkability per region
+ left_walk = np.mean(combined_mask[h // 2:, :w // 3] > 0)
+ center_walk = np.mean(combined_mask[h // 2:, w // 3:2 * w // 3] > 0)
+ right_walk = np.mean(combined_mask[h // 2:, 2 * w // 3:] > 0)
+
+ paths = {"left": left_walk, "center": center_walk, "right": right_walk}
+ best = max(paths, key=paths.get)
+ worst = min(paths, key=paths.get)
+
+ return {
+ "walkable_mask": combined_mask,
+ "left": float(left_walk),
+ "center": float(center_walk),
+ "right": float(right_walk),
+ "best_path": best,
+ "blocked_path": worst,
+ "confidence": float(max(paths.values()))
+ }
+
+ except Exception as e:
+ cc.log(f"Ground segmentation error: {e}", "ERROR")
+ return {"best_path": "center", "confidence": 0.0}
+
+
+def track_obstacle_approach(obstacle_id: str, current_depth: float, position: str) -> Dict:
+ """
+ Track obstacle over time to detect if it's getting closer.
+
+ Returns approach rate and estimated time-to-collision.
+ """
+ global _obstacle_tracking
+
+ current_time = time.time()
+ key = f"approach_{obstacle_id}"
+
+ if key not in _obstacle_tracking:
+ _obstacle_tracking[key] = {"history": [], "first_seen": current_time}
+
+ _obstacle_tracking[key]["history"].append({
+ "time": current_time,
+ "depth": current_depth
+ })
+
+ # Keep last 30 readings
+ _obstacle_tracking[key]["history"] = _obstacle_tracking[key]["history"][-30:]
+
+ history = _obstacle_tracking[key]["history"]
+
+ if len(history) < 5:
+ return {"approaching": False, "rate": 0, "ttc": None}
+
+ # Calculate approach rate
+ time_span = history[-1]["time"] - history[0]["time"]
+ if time_span < 0.1:
+ return {"approaching": False, "rate": 0, "ttc": None}
+
+ depth_change = history[-1]["depth"] - history[0]["depth"]
+ rate = depth_change / time_span
+
+ # Positive rate = getting closer (depth increasing)
+ approaching = rate > 3
+
+ # Time to collision estimate
+ ttc = None
+ if approaching and rate > 0:
+ remaining = 255 - history[-1]["depth"]
+ ttc = remaining / rate if rate > 0 else None
+
+ return {
+ "approaching": approaching,
+ "rate": float(rate),
+ "ttc": float(ttc) if ttc and ttc > 0 else None
+ }
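+
+
+# Worked example (illustrative numbers): if an obstacle's averaged depth value rose from
+# 100 to 160 over the last 2.0 seconds of history, the approach rate is
+# (160 - 100) / 2.0 = 30 depth-units/s (> 3, so "approaching"), and the estimated
+# time-to-collision is (255 - 160) / 30 ~= 3.2 s before the depth scale saturates.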
+
+
+# ===== OBSTACLE DETECTION =====
+
+# Timing control for Claude obstacle analysis (don't call too frequently)
+_last_obstacle_analysis_time = 0
+_obstacle_analysis_interval = 3.0 # Seconds between Claude calls
+_cached_obstacles = []
+_cached_opencv_obstacles = []
+
+
+# ===== OPENCV REAL-TIME OBSTACLE DETECTION =====
+# Fast, runs every frame - detects "something is there"
+# Complements Claude AI which understands "what is it and is it dangerous"
+
+def detect_obstacles_opencv(frame: np.ndarray) -> List[Dict]:
+ """
+ Real-time obstacle detection using OpenCV techniques.
+
+ This runs every frame and detects:
+ 1. Large objects/edges in the path (via Canny edge detection)
+ 2. Proximity based on edge density in regions
+ 3. Floor-level obstacles (via bottom-region analysis)
+
+ This is FAST but DUMB - it detects "something is there" but doesn't know what.
+ Claude AI provides the smart context about whether it's actually dangerous.
+ """
+ global cc
+
+ h, w = frame.shape[:2]
+ obstacles = []
+
+ try:
+ # Convert to grayscale
+ gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+
+ # Apply bilateral filter to reduce noise while preserving edges
+ # This is key for obstacle detection - keeps edges sharp
+ filtered = cv2.bilateralFilter(gray, 9, 75, 75)
+
+ # Canny edge detection
+ edges = cv2.Canny(filtered, 50, 150)
+
+ # Dilate edges to connect nearby edges
+ kernel = np.ones((3, 3), np.uint8)
+ edges_dilated = cv2.dilate(edges, kernel, iterations=2)
+
+ # Define regions of interest (ROI) for obstacle detection
+ # Focus on center and bottom of frame (where obstacles matter for walking)
+ regions = {
+ "center_close": (w // 4, h // 2, 3 * w // 4, h - 50), # Center, bottom half
+ "left_path": (0, h // 2, w // 4, h - 50), # Left side path
+ "right_path": (3 * w // 4, h // 2, w, h - 50), # Right side path
+ "floor_immediate": (w // 6, 2 * h // 3, 5 * w // 6, h), # Immediate floor area
+ }
+
+ for region_name, (x1, y1, x2, y2) in regions.items():
+ # Extract region
+ roi = edges_dilated[y1:y2, x1:x2]
+
+ if roi.size == 0:
+ continue
+
+ # Calculate edge density in this region
+ edge_density = np.sum(roi > 0) / roi.size
+
+ # Find contours in this region
+ contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+ # Filter significant contours (large enough to be obstacles)
+ min_contour_area = (x2 - x1) * (y2 - y1) * 0.05 # At least 5% of region
+ significant_contours = [c for c in contours if cv2.contourArea(c) > min_contour_area]
+
+ # Determine if this region has an obstacle
+ # High edge density + significant contours = likely obstacle
+ if edge_density > 0.15 and len(significant_contours) > 0:
+ # Estimate proximity based on position in frame
+ # Objects lower in frame = closer
+ vertical_position = (y1 + y2) / 2 / h
+
+ if vertical_position > 0.8: # Very low in frame
+ distance = "very_close"
+ severity = "high"
+ elif vertical_position > 0.65:
+ distance = "close"
+ severity = "medium"
+ else:
+ distance = "medium"
+ severity = "low"
+
+ # Map region to position
+ if "left" in region_name:
+ position = "left"
+ elif "right" in region_name:
+ position = "right"
+ elif "floor" in region_name:
+ position = "floor"
+ else:
+ position = "center"
+
+ # Get the largest contour for this obstacle
+ largest_contour = max(significant_contours, key=cv2.contourArea)
+ contour_box = cv2.boundingRect(largest_contour)
+
+ # Adjust box coordinates to full frame
+ box = [
+ x1 + contour_box[0],
+ y1 + contour_box[1],
+ x1 + contour_box[0] + contour_box[2],
+ y1 + contour_box[1] + contour_box[3]
+ ]
+
+ obstacles.append({
+ "label": f"obstacle ({position})",
+ "type": severity,
+ "position": position,
+ "distance": distance,
+ "box": box,
+ "edge_density": edge_density,
+ "contour_count": len(significant_contours),
+ "source": "opencv",
+ "reason": f"Edge detection: {edge_density:.0%} density"
+ })
+
+ # Also check for sudden large objects in center (collision imminent)
+ center_roi = edges_dilated[h // 3:, w // 4:3 * w // 4]
+ center_density = np.sum(center_roi > 0) / center_roi.size
+
+ if center_density > 0.25: # Very high edge density in center
+ # This suggests a large object directly ahead
+ obstacles.append({
+ "label": "large obstacle ahead",
+ "type": "high",
+ "position": "center",
+ "distance": "close",
+ "box": [w // 4, h // 3, 3 * w // 4, h],
+ "edge_density": center_density,
+ "source": "opencv",
+ "reason": "High edge density directly ahead - possible collision"
+ })
+
+ except Exception as e:
+ cc.log(f"OpenCV obstacle detection error: {e}", "ERROR")
+
+ return obstacles
+
+
+def analyze_floor_clearance(frame: np.ndarray) -> Dict:
+ """
+ Analyze if the immediate floor area is clear for walking.
+
+ Uses color consistency and edge analysis of the floor region
+ to detect trip hazards, steps, or objects on the ground.
+ """
+ h, w = frame.shape[:2]
+
+ # Focus on bottom third of frame (floor area)
+ floor_region = frame[2 * h // 3:, :]
+
+ # Convert to grayscale
+ gray = cv2.cvtColor(floor_region, cv2.COLOR_BGR2GRAY)
+
+ # Calculate standard deviation - uniform floor has low std dev
+ std_dev = np.std(gray)
+
+ # Edge detection on floor
+ edges = cv2.Canny(gray, 30, 100)
+ edge_ratio = np.sum(edges > 0) / edges.size
+
+ # Analyze left, center, right paths
+ third = w // 3
+ left_edges = np.sum(edges[:, :third] > 0) / (edges[:, :third].size + 1)
+ center_edges = np.sum(edges[:, third:2*third] > 0) / (edges[:, third:2*third].size + 1)
+ right_edges = np.sum(edges[:, 2*third:] > 0) / (edges[:, 2*third:].size + 1)
+
+ # Determine clearest path
+ paths = {"left": left_edges, "center": center_edges, "right": right_edges}
+ clearest = min(paths, key=paths.get)
+
+ return {
+ "floor_uniformity": 1.0 - min(std_dev / 80, 1.0), # Higher = more uniform
+ "edge_ratio": edge_ratio,
+ "path_analysis": paths,
+ "suggested_path": clearest,
+ "floor_clear": edge_ratio < 0.1 and std_dev < 40
+ }
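+ # Illustrative example of the returned dict (values are made up): a cluttered
+ # hallway might yield {"floor_uniformity": 0.4, "edge_ratio": 0.18,
+ # "path_analysis": {"left": 0.05, "center": 0.22, "right": 0.12},
+ # "suggested_path": "left", "floor_clear": False}.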
+
+
+def detect_obstacles(frame: np.ndarray, pil_image: Image.Image) -> List[Dict]:
+ """
+ FOUR-LAYER OBSTACLE DETECTION SYSTEM
+
+ Combines multiple techniques for comprehensive obstacle detection
+ using only a single RGB camera (no LIDAR/radar needed):
+
+ Layer 1: OpenCV Edge Detection (every frame, ~20ms)
+ - Canny edges, contours, bilateral filtering
+ - Immediate response for sudden obstacles
+
+ Layer 2: AI Depth Estimation (every frame if available, ~50ms)
+ - MiDaS or Depth Anything for dense monocular (LIDAR-like) depth
+ - Estimates relative distance, not just "something is there"
+
+ Layer 3: Optical Flow Collision Detection (every frame, ~30ms)
+ - Detects APPROACHING objects via motion expansion
+ - Biomimetic: same technique insects use!
+
+ Layer 4: Claude AI Analysis (every 3 seconds, ~1-2s)
+ - Semantic understanding of obstacles
+ - Knows target is NOT an obstacle
+ - Explains WHY something is dangerous
+
+ Plus: Ground Plane Segmentation, Temporal Tracking, Time-to-Collision
+ """
+ global cc, _last_obstacle_analysis_time, _cached_obstacles, _cached_opencv_obstacles, _depth_available
+
+ if not cc.obstacle_detection_active:
+ return []
+
+ current_time = time.time()
+ all_obstacles = []
+
+ # ===== LAYER 1: OpenCV Edge Detection (every frame) =====
+ # Fast, detects "something is there"
+ opencv_obstacles = detect_obstacles_opencv(frame)
+
+ for obs in opencv_obstacles:
+ if obs["distance"] in ["very_close", "close"] and obs.get("edge_density", 0) > 0.2:
+ obs["timestamp"] = current_time
+ cooldown_key = f"opencv_{obs['position']}_{obs['distance']}"
+ last_alert = cc.obstacle_alert_cooldown.get(cooldown_key, 0)
+
+ if current_time - last_alert > 2.0:
+ obs["should_alert"] = True
+ cc.obstacle_alert_cooldown[cooldown_key] = current_time
+ else:
+ obs["should_alert"] = False
+
+ all_obstacles.append(obs)
+
+ _cached_opencv_obstacles = opencv_obstacles
+
+ # ===== LAYER 2: AI Depth Estimation (LIDAR-like) =====
+ # Uses MiDaS or Depth Anything for real distance measurement
+ depth_map = None
+ if _depth_available:
+ depth_map = estimate_depth(frame)
+ if depth_map is not None:
+ depth_obstacles = detect_obstacles_depth(frame, depth_map)
+
+ for obs in depth_obstacles:
+ obs["timestamp"] = current_time
+ cooldown_key = f"depth_{obs['position']}_{obs['distance']}"
+ last_alert = cc.obstacle_alert_cooldown.get(cooldown_key, 0)
+
+ # Depth is more reliable, use slightly shorter cooldown
+ if current_time - last_alert > 1.5:
+ obs["should_alert"] = True
+ cc.obstacle_alert_cooldown[cooldown_key] = current_time
+
+ # Alert on approaching objects with TTC
+ if obs.get("approaching") and obs.get("time_to_collision"):
+ ttc = obs["time_to_collision"]
+ if ttc < 2.0:
+ obs["type"] = "high"
+ obs["reason"] = f"Approaching! {ttc:.1f}s to collision"
+ cc.log(f"COLLISION WARNING: {obs['label']} in {ttc:.1f}s", "ERROR")
+ else:
+ obs["should_alert"] = False
+
+ all_obstacles.append(obs)
+
+ # ===== LAYER 3: Optical Flow Collision Detection =====
+ # Biomimetic: detects approaching objects via expansion
+ flow_obstacles = detect_collision_optical_flow(frame)
+
+ for obs in flow_obstacles:
+ obs["timestamp"] = current_time
+ cooldown_key = f"flow_{obs['position']}"
+ last_alert = cc.obstacle_alert_cooldown.get(cooldown_key, 0)
+
+ if current_time - last_alert > 1.5 and obs.get("expansion_rate", 0) > 1.0:
+ obs["should_alert"] = True
+ cc.obstacle_alert_cooldown[cooldown_key] = current_time
+ cc.log(f"MOTION: {obs['label']} expanding at {obs['expansion_rate']:.1f}x", "WARN")
+ else:
+ obs["should_alert"] = False
+
+ all_obstacles.append(obs)
+
+ # ===== Ground Plane & Walkable Path Analysis =====
+ walkable = segment_walkable_ground(frame, depth_map)
+ cc.navigation_context = cc.navigation_context or {}
+ cc.navigation_context["walkable"] = walkable
+ cc.navigation_context["best_path"] = walkable.get("best_path", "center")
+
+ # Also run simpler floor analysis
+ floor_analysis = analyze_floor_clearance(frame)
+ if not floor_analysis["floor_clear"]:
+ cc.navigation_context["floor_analysis"] = floor_analysis
+ cc.navigation_context["suggested_path"] = floor_analysis["suggested_path"]
+
+ # ===== LAYER 4: Claude AI Analysis (every few seconds) =====
+ # Smart contextual understanding
+ if current_time - _last_obstacle_analysis_time >= _obstacle_analysis_interval:
+ try:
+ # Encode frame for Claude
+ _, buffer = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 70])
+ image_data = base64.b64encode(buffer).decode('utf-8')
+
+ # Get target box if available (for spatial context)
+ target_box = None
+ if cc.navigation_target:
+ for det in cc.current_detections:
+ if det.get("label", "").lower() == cc.navigation_target.lower():
+ target_box = det.get("box")
+ break
+
+ # Call Claude for intelligent obstacle analysis
+ claude_obstacles = analyze_obstacles_with_claude(
+ image_data,
+ cc.navigation_target or "the object",
+ target_box
+ )
+
+ _last_obstacle_analysis_time = current_time
+
+ # Process Claude's obstacles
+ for obs in claude_obstacles:
+ obstacle = {
+ "label": obs["label"],
+ "type": obs["type"],
+ "position": obs.get("position", "ahead"),
+ "distance": obs["distance"],
+ "reason": obs.get("reason", ""),
+ "timestamp": current_time,
+ "box": None,
+ "mask": None,
+ "source": "claude"
+ }
+
+ # Check cooldown for alerts
+ cooldown_key = f"claude_{obs['label']}_{obs['distance']}"
+ last_alert = cc.obstacle_alert_cooldown.get(cooldown_key, 0)
+
+ if current_time - last_alert > cc.obstacle_alert_interval:
+ obstacle["should_alert"] = True
+ cc.obstacle_alert_cooldown[cooldown_key] = current_time
+
+ # Log the obstacle with reason
+ cc.log(f"OBSTACLE: {obs['label']} ({obs['distance']}) - {obs.get('reason', '')}", "WARN")
+
+ # Save to database
+ if cc.navigation_db_id:
+ db.save_obstacle(
+ cc.navigation_db_id,
+ obs["label"],
+ obs["type"],
+ [],
+ obs["distance"],
+ alert_sent=True
+ )
+ else:
+ obstacle["should_alert"] = False
+
+ all_obstacles.append(obstacle)
+
+ # Log safe direction if available
+ if cc.navigation_context and cc.navigation_context.get("safe_direction"):
+ cc.log(f"Safe path: {cc.navigation_context['safe_direction']}", "INFO")
+
+ _cached_obstacles = [o for o in all_obstacles if o.get("source") == "claude"]
+
+ except Exception as e:
+ cc.log(f"Claude obstacle analysis error: {e}", "ERROR")
+ # Fall back to cached Claude results
+ all_obstacles.extend(_cached_obstacles)
+
+ else:
+ # Use cached Claude results between API calls
+ all_obstacles.extend(_cached_obstacles)
+
+ return all_obstacles
+
+
+def get_obstacle_segmentation(frame: np.ndarray, obstacle_label: str) -> Optional[np.ndarray]:
+ """
+ Optionally get a SAM3 segmentation mask for an obstacle identified by Claude,
+ e.g. to visually highlight it in the overlay.
+ """
+ global cc
+
+ if cc.processor is None:
+ return None
+
+ try:
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ pil_image = Image.fromarray(frame_rgb)
+
+ state = cc.processor.set_image(pil_image, {})
+ state = cc.processor.set_text_prompt(obstacle_label, state)
+
+ masks = state.get("masks")
+ if masks is not None and masks.numel() > 0:
+ return masks[0].squeeze().cpu().numpy()
+
+ except Exception as e:
+ cc.log(f"Obstacle segmentation failed: {e}", "ERROR")
+
+ return None
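+ # Illustrative usage (hypothetical label; the returned mask may need resizing
+ # to the frame's height/width before overlaying):
+ #   mask = get_obstacle_segmentation(frame, "chair")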
+
+
+def overlay_obstacles(display: np.ndarray, obstacles: List[Dict]) -> np.ndarray:
+ """
+ Overlay obstacle alerts on the display frame.
+
+ Handles both:
+ - OpenCV obstacles (have precise bounding boxes from edge detection)
+ - Claude obstacles (have position-based info like left/center/right)
+ """
+ if not obstacles:
+ return display
+
+ h, w = display.shape[:2]
+
+ # Obstacle color (orange/red based on severity)
+ colors = {
+ "high": (0, 0, 255), # Red
+ "medium": (0, 165, 255), # Orange
+ "low": (0, 255, 255) # Yellow
+ }
+
+ # Position to screen region mapping (for Claude obstacles without boxes)
+ position_regions = {
+ "left": (10, h // 3, w // 3, 2 * h // 3),
+ "center": (w // 3, h // 3, 2 * w // 3, 2 * h // 3),
+ "right": (2 * w // 3, h // 3, w - 10, 2 * h // 3),
+ "floor": (w // 4, 2 * h // 3, 3 * w // 4, h - 10),
+ "ahead": (w // 4, h // 4, 3 * w // 4, 3 * h // 4),
+ }
+
+ for i, obstacle in enumerate(obstacles):
+ severity = obstacle.get("type", "medium")
+ label = obstacle.get("label", "Obstacle")
+ distance = obstacle.get("distance", "medium")
+ position = obstacle.get("position", "ahead")
+ reason = obstacle.get("reason", "")
+ source = obstacle.get("source", "unknown")
+ box = obstacle.get("box")
+
+ color = colors.get(severity, (0, 165, 255))
+
+ # Determine region - use box if available (OpenCV), otherwise use position (Claude)
+ if box and len(box) == 4 and all(v is not None for v in box):
+ # OpenCV obstacle with precise box
+ rx1, ry1, rx2, ry2 = [int(v) for v in box]
+
+ # Draw bounding box with dashed lines for OpenCV detections
+ if source == "opencv":
+ # Dashed rectangle effect
+ for j in range(rx1, rx2, 10):
+ cv2.line(display, (j, ry1), (min(j + 5, rx2), ry1), color, 2)
+ cv2.line(display, (j, ry2), (min(j + 5, rx2), ry2), color, 2)
+ for j in range(ry1, ry2, 10):
+ cv2.line(display, (rx1, j), (rx1, min(j + 5, ry2)), color, 2)
+ cv2.line(display, (rx2, j), (rx2, min(j + 5, ry2)), color, 2)
+ else:
+ cv2.rectangle(display, (rx1, ry1), (rx2, ry2), color, 2)
+ else:
+ # Claude obstacle - use position-based region
+ region = position_regions.get(position, position_regions["ahead"])
+ rx1, ry1, rx2, ry2 = region
+
+ # Draw semi-transparent warning zone for close obstacles
+ if distance in ["very_close", "close"]:
+ overlay = display.copy()
+ alpha = 0.25 if severity == "high" else 0.15
+ cv2.rectangle(overlay, (rx1, ry1), (rx2, ry2), color, -1)
+ display = cv2.addWeighted(overlay, alpha, display, 1 - alpha, 0)
+
+ # Draw thick border
+ cv2.rectangle(display, (rx1, ry1), (rx2, ry2), color, 3)
+
+ # Draw warning icon
+ icon_size = 35 if severity == "high" else 25
+ icon_x = (rx1 + rx2) // 2 - icon_size // 2
+ icon_y = max(ry1 - icon_size - 5, 5)
+
+ # Warning triangle
+ triangle = np.array([
+ [icon_x + icon_size // 2, icon_y],
+ [icon_x, icon_y + icon_size],
+ [icon_x + icon_size, icon_y + icon_size]
+ ], np.int32)
+ cv2.fillPoly(display, [triangle], color)
+ cv2.polylines(display, [triangle], True, (0, 0, 0), 2)
+
+ # Exclamation mark
+ cv2.line(display, (icon_x + icon_size // 2, icon_y + 8),
+ (icon_x + icon_size // 2, icon_y + icon_size - 12), (0, 0, 0), 2)
+ cv2.circle(display, (icon_x + icon_size // 2, icon_y + icon_size - 6), 2, (0, 0, 0), -1)
+
+ # Label text
+ if distance == "very_close":
+ label_text = f"STOP! {label}"
+ elif distance == "close":
+ label_text = f"WARNING: {label}"
+ else:
+ label_text = f"CAUTION: {label}"
+
+ # Add source indicator for debugging
+ if source == "opencv":
+ label_text += " [CV]"
+
+ text_x = rx1 + 5
+ text_y = ry2 + 20 if ry2 + 25 < h else ry1 - 40
+
+ # Text with background
+ (text_w, text_h), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 2)
+ cv2.rectangle(display, (text_x - 2, text_y - text_h - 3),
+ (text_x + text_w + 2, text_y + 3), (0, 0, 0), -1)
+ cv2.putText(display, label_text, (text_x, text_y),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.55, color, 2)
+
+ # Distance text
+ text_y += 18
+ distance_text = distance.replace("_", " ").upper()
+ cv2.putText(display, distance_text, (text_x, text_y),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.45, (255, 255, 255), 1)
+
+ # Reason (if from Claude)
+ if reason and source == "claude" and len(reason) < 40:
+ text_y += 16
+ cv2.putText(display, reason, (text_x, text_y),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.35, (200, 200, 200), 1)
+
+ # Draw path status indicator
+ if cc.navigation_context:
+ if cc.navigation_context.get("path_clear"):
+ cv2.putText(display, "PATH CLEAR", (w // 2 - 60, 30),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
+ elif cc.navigation_context.get("safe_direction"):
+ safe_text = f"Go: {cc.navigation_context['safe_direction']}"
+ cv2.putText(display, safe_text, (10, h - 20),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
+ elif cc.navigation_context.get("suggested_path"):
+ # From floor analysis
+ path_text = f"Clearest path: {cc.navigation_context['suggested_path']}"
+ cv2.putText(display, path_text, (10, h - 20),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 200, 255), 1)
+
+ return display
+
+
+# ===== FRAME PROCESSING =====
+
+def process_frame(frame: np.ndarray) -> np.ndarray:
+ """Process a frame through SAM3 and overlay results."""
+ global cc
+
+ cc.frame_count += 1
+ is_keyframe = cc.frame_count % cc.skip_frames == 0
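+ # Three mutually exclusive paths per frame:
+ #   1) a pending box/point prompt  -> one-off geometric segmentation
+ #   2) keyframe and not paused     -> full SAM3 text-prompt inference
+ #   3) otherwise, if tracking is on -> propagate last masks via optical flow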
+
+ # Handle geometric prompts (draw to search)
+ if cc.pending_box_prompt is not None or cc.pending_point_prompt is not None:
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ pil_image = Image.fromarray(frame_rgb)
+
+ cc.state = cc.processor.set_image(pil_image, cc.state)
+
+ if cc.pending_box_prompt is not None:
+ # Box prompt
+ x1, y1, x2, y2 = cc.pending_box_prompt
+ cc.state["geometric_prompt"] = {
+ "type": "box",
+ "box": [x1, y1, x2, y2]
+ }
+ cc.log(f"Processing box prompt: ({x1:.0f},{y1:.0f}) to ({x2:.0f},{y2:.0f})")
+
+ elif cc.pending_point_prompt is not None:
+ # Point prompt
+ x, y = cc.pending_point_prompt
+ cc.state["geometric_prompt"] = {
+ "type": "point",
+ "point": [x, y],
+ "label": 1 # 1 = foreground, 0 = background
+ }
+ cc.log(f"Processing point prompt: ({x:.0f},{y:.0f})")
+
+ # Get mask from geometric prompt
+ try:
+ # Use the processor's segment method with geometric prompt
+ masks = cc.state.get("masks")
+ if masks is not None and len(masks) > 0:
+ mask_np = masks[0].squeeze().cpu().numpy()
+ box = get_bounding_box_from_mask(mask_np)
+
+ cc.last_masks = masks[:1]
+ cc.last_boxes = torch.tensor([box]) if box else None
+ cc.last_scores = torch.tensor([1.0])
+ cc.last_labels = ["selected object"]
+
+ # Add to detections
+ with cc.lock:
+ cc.current_detections = [{
+ "id": 0,
+ "label": "selected object",
+ "confidence": 1.0,
+ "box": box,
+ "tracked": False,
+ }]
+
+ cc.log("Object selected via drawing", "SUCCESS")
+
+ except Exception as e:
+ cc.log(f"Geometric prompt failed: {e}", "ERROR")
+
+ # Clear the pending prompts
+ cc.pending_box_prompt = None
+ cc.pending_point_prompt = None
+ cc.draw_mode = None
+ cc.prev_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+
+ elif is_keyframe and not cc.paused:
+ # Full inference
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ pil_image = Image.fromarray(frame_rgb)
+
+ cc.state = cc.processor.set_image(pil_image, cc.state)
+
+ # Build new detections list (don't clear until we have new ones)
+ new_detections = []
+ cc.last_poses = {}
+
+ all_masks = []
+ all_boxes = []
+ all_scores = []
+ all_labels = []
+ all_object_ids = []
+
+ for prompt in cc.prompts:
+ if "geometric_prompt" in cc.state:
+ del cc.state["geometric_prompt"]
+
+ cc.state = cc.processor.set_text_prompt(prompt.strip(), cc.state)
+
+ masks = cc.state.get("masks")
+ boxes = cc.state.get("boxes")
+ scores = cc.state.get("scores")
+
+ if masks is not None and masks.numel() > 0:
+ for i in range(len(masks)):
+ mask_np = masks[i].squeeze().cpu().numpy()
+ box = boxes[i].cpu().numpy().tolist() if boxes is not None and i < len(boxes) else None
+ score = float(scores[i].cpu()) if scores is not None and i < len(scores) else 0.0
+
+ # Boundary suppression
+ if cc.enable_boundary_suppression and box:
+ if is_near_boundary(box, frame.shape, cc.boundary_margin):
+ continue
+
+ # Hotstart
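+ # Require the same prompt + rough location to appear on several keyframes
+ # (cc.hotstart_frames) before accepting it; this suppresses one-frame false positives.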
+ if cc.enable_hotstart:
+ det_hash = f"{prompt}_{int(box[0]) if box else 0}_{int(box[1]) if box else 0}"
+ if det_hash not in cc.pending_detections:
+ cc.pending_detections[det_hash] = {"frames": 1, "data": None}
+ continue
+ else:
+ cc.pending_detections[det_hash]["frames"] += 1
+ if cc.pending_detections[det_hash]["frames"] < cc.hotstart_frames:
+ continue
+ del cc.pending_detections[det_hash]
+
+ # Fill holes
+ if cc.enable_fill_holes:
+ mask_np = fill_holes_in_mask(mask_np, cc.fill_hole_area)
+
+ # Smooth edges
+ if cc.enable_smooth_edges:
+ mask_np = smooth_mask_edges(mask_np, cc.smooth_kernel_size)
+
+ # Persistent object IDs
+ object_id = len(all_masks)
+ if cc.enable_persistent_ids:
+ if cc.tracked_objects:
+ match_id = match_detection_to_object(
+ mask_np,
+ {oid: obj["last_mask"] for oid, obj in cc.tracked_objects.items()
+ if "last_mask" in obj},
+ cc.iou_threshold
+ )
+ if match_id is not None:
+ object_id = match_id
+ else:
+ object_id = cc.next_object_id
+ cc.next_object_id += 1
+
+ if object_id not in cc.tracked_objects:
+ cc.tracked_objects[object_id] = {
+ "label": prompt.strip(),
+ "first_seen": cc.frame_count,
+ "color": COLORS[object_id % len(COLORS)],
+ }
+ cc.object_colors[object_id] = COLORS[object_id % len(COLORS)]
+
+ cc.tracked_objects[object_id]["last_seen"] = cc.frame_count
+ cc.tracked_objects[object_id]["last_mask"] = mask_np
+ cc.tracked_objects[object_id]["confidence"] = score
+
+ # Memory tracking
+ if cc.enable_memory_tracking:
+ mask_tensor = torch.from_numpy(mask_np).unsqueeze(0)
+ update_memory_bank(object_id, mask_tensor)
+
+ # ===== YOLO INTEGRATION =====
+ yolo_info = {}
+
+ # YOLO Classification
+ if cc.enable_yolo_classify and box and cc.yolo_classify_model is not None:
+ if cc.frame_count % cc.yolo_classify_every_n == 0:
+ classify_result = classify_region(frame, box, prompt.strip())
+ if classify_result:
+ yolo_info["classify"] = classify_result
+ if classify_result["yolo_confidence"] >= cc.yolo_classify_threshold:
+ cc.log(f"YOLO: {classify_result['yolo_class']} ({classify_result['yolo_confidence']:.0%})")
+
+ # YOLO Pose Estimation (only for person-like labels)
+ if cc.enable_yolo_pose and box and cc.yolo_pose_model is not None:
+ if is_person_label(prompt.strip()):
+ pose_result = estimate_pose(frame, box)
+ if pose_result and pose_result["confidence"] >= cc.yolo_pose_threshold:
+ yolo_info["pose"] = pose_result
+ cc.last_poses[object_id] = pose_result
+ cc.log(f"Pose detected for {prompt} (conf: {pose_result['confidence']:.0%})")
+
+ detection = {
+ "id": object_id,
+ "label": prompt.strip(),
+ "confidence": score,
+ "box": box,
+ "persistent_id": object_id if cc.enable_persistent_ids else None,
+ "yolo": yolo_info if yolo_info else None,
+ "tracked": False, # Fresh detection from SAM3
+ }
+ new_detections.append(detection)
+
+ all_masks.append(mask_np)
+ all_object_ids.append(object_id)
+ if box:
+ all_boxes.append(box)
+ all_scores.append(score)
+ all_labels.append(prompt.strip())
+
+ # Remove overlapping masks
+ if cc.enable_non_overlap and len(all_masks) > 1:
+ all_masks = remove_mask_overlaps(all_masks, all_scores)
+
+ # Occlusion suppression
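+ # Overlap is measured as intersection over mask_i's own area (not IoU); if
+ # most of a lower-scoring mask lies inside a higher-scoring one, it is treated
+ # as occluded and dropped.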
+ if cc.enable_occlusion_suppression and len(all_masks) > 1:
+ keep_indices = []
+ for i, mask_i in enumerate(all_masks):
+ is_occluded = False
+ for j, mask_j in enumerate(all_masks):
+ if i != j and all_scores[j] > all_scores[i]:
+ overlap = np.logical_and(mask_i, mask_j).sum() / (mask_i.sum() + 1e-6)
+ if overlap > cc.occlusion_threshold:
+ is_occluded = True
+ break
+ if not is_occluded:
+ keep_indices.append(i)
+
+ all_masks = [all_masks[i] for i in keep_indices]
+ all_boxes = [all_boxes[i] for i in keep_indices if i < len(all_boxes)]
+ all_scores = [all_scores[i] for i in keep_indices]
+ all_labels = [all_labels[i] for i in keep_indices]
+ all_object_ids = [all_object_ids[i] for i in keep_indices]
+
+ # Store for tracking
+ if all_masks:
+ cc.last_masks = torch.stack([torch.from_numpy(m).unsqueeze(0) for m in all_masks])
+ cc.last_boxes = torch.tensor(all_boxes) if all_boxes else None
+ cc.last_scores = torch.tensor(all_scores) if all_scores else None
+ cc.last_labels = all_labels
+ cc.prev_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+ else:
+ cc.last_masks = None
+ cc.last_boxes = None
+ cc.last_scores = None
+ cc.last_labels = None
+
+ # CLIP-based visual matching filter
+ if cc.visual_match_enabled and cc.reference_embedding is not None and new_detections:
+ matched_detections = []
+ matched_indices = []
+
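+ # Each detection crop is embedded with CLIP and compared against the reference
+ # embedding; only crops whose similarity clears visual_match_threshold are kept.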
+ for i, det in enumerate(new_detections):
+ box = det.get("box")
+ if box:
+ # Crop the detected region
+ x1, y1, x2, y2 = [int(v) for v in box]
+ h, w = frame.shape[:2]
+ x1, y1 = max(0, x1), max(0, y1)
+ x2, y2 = min(w, x2), min(h, y2)
+
+ if x2 > x1 and y2 > y1:
+ crop = frame[y1:y2, x1:x2]
+ crop_pil = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
+
+ # Get CLIP embedding
+ crop_embedding = get_clip_embedding(crop_pil)
+ if crop_embedding is not None:
+ similarity = compute_clip_similarity(cc.reference_embedding, crop_embedding)
+ det["clip_similarity"] = similarity
+
+ if similarity >= cc.visual_match_threshold:
+ matched_detections.append(det)
+ matched_indices.append(i)
+ cc.log(f"Visual match: {det['label']} (sim: {similarity:.2f})")
+
+ if matched_detections:
+ new_detections = matched_detections
+ # Also filter masks
+ if all_masks and matched_indices:
+ all_masks = [all_masks[i] for i in matched_indices if i < len(all_masks)]
+ cc.last_masks = torch.stack([torch.from_numpy(m).unsqueeze(0) for m in all_masks]) if all_masks else None
+ else:
+ cc.log("No visual matches found", "WARN")
+ new_detections = []
+
+ # Atomically update detections (only update if we have new detections,
+ # otherwise keep the existing tracked detections)
+ if new_detections:
+ with cc.lock:
+ cc.current_detections = new_detections
+ # Note: If SAM3 found nothing but we have tracked objects, keep them
+ # They will be removed by tracking when they actually leave the frame
+
+ if all_labels:
+ cc.log(f"Detected: {', '.join(all_labels)}")
+
+ elif cc.enable_tracking and cc.last_masks is not None and not cc.paused:
+ # Track with optical flow
+ tracked = track_frame(frame)
+ if tracked is not None:
+ cc.last_masks = tracked
+
+ # Update detections based on tracked masks and remove objects that left frame
+ valid_indices = update_detections_from_tracked_masks(tracked, frame.shape)
+
+ # If some masks were invalidated, update the mask list too
+ if valid_indices is not None and len(valid_indices) < len(tracked):
+ # Keep only valid masks
+ valid_masks = [tracked[i] for i in valid_indices]
+ if valid_masks:
+ cc.last_masks = torch.stack(valid_masks)
+ else:
+ cc.last_masks = None
+
+ # Also update labels, scores, boxes to stay in sync
+ if cc.last_labels:
+ cc.last_labels = [cc.last_labels[i] for i in valid_indices if i < len(cc.last_labels)]
+ if cc.last_scores is not None and len(cc.last_scores) > 0:
+ try:
+ idx_tensor = torch.tensor(valid_indices, dtype=torch.long)
+ cc.last_scores = cc.last_scores[idx_tensor] if len(valid_indices) > 0 else None
+ except Exception:
+ cc.last_scores = None
+ if cc.last_boxes is not None and len(cc.last_boxes) > 0:
+ try:
+ idx_tensor = torch.tensor(valid_indices, dtype=torch.long)
+ cc.last_boxes = cc.last_boxes[idx_tensor] if len(valid_indices) > 0 else None
+ except Exception:
+ cc.last_boxes = None
+
+ # Overlay masks on frame
+ display = frame.copy()
+ if cc.last_masks is not None:
+ display = overlay_masks(display, cc.last_masks, cc.last_boxes, cc.last_scores, cc.last_labels)
+
+ # Draw pose overlays
+ if cc.enable_yolo_pose and cc.last_poses:
+ for obj_id, pose_data in cc.last_poses.items():
+ display = draw_pose_overlay(display, pose_data, obj_id)
+
+ # Obstacle detection during navigation (run on keyframes)
+ if cc.obstacle_detection_active and is_keyframe and not cc.paused:
+ try:
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ pil_image = Image.fromarray(frame_rgb)
+ obstacles = detect_obstacles(frame, pil_image)
+
+ if obstacles:
+ cc.current_obstacles = obstacles
+ display = overlay_obstacles(display, obstacles)
+
+ # Log high-severity obstacles that should alert
+ for obs in obstacles:
+ if obs.get("should_alert") and obs.get("type") in ["high", "medium"]:
+ cc.log(f"OBSTACLE: {obs['label']} ({obs['distance']})", "WARN")
+ except Exception as e:
+ cc.log(f"Obstacle overlay error: {e}", "ERROR")
+
+ return display
+
+
+def track_frame(frame: np.ndarray) -> Optional[torch.Tensor]:
+ """Track masks using optical flow."""
+ if cc.last_masks is None or cc.prev_gray is None:
+ return None
+
+ try:
+ curr_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+
+ flow = cv2.calcOpticalFlowFarneback(
+ cc.prev_gray, curr_gray, None,
+ pyr_scale=0.5, levels=3, winsize=15,
+ iterations=3, poly_n=5, poly_sigma=1.2, flags=0
+ )
+
+ h, w = curr_gray.shape
+ flow_map_x = np.arange(w).reshape(1, -1).repeat(h, axis=0).astype(np.float32)
+ flow_map_y = np.arange(h).reshape(-1, 1).repeat(w, axis=1).astype(np.float32)
+ flow_map_x += flow[:, :, 0]
+ flow_map_y += flow[:, :, 1]
+
+ tracked_masks = []
+ for mask in cc.last_masks:
+ if isinstance(mask, torch.Tensor):
+ mask_np = mask.cpu().numpy().squeeze()
+ else:
+ mask_np = mask.squeeze()
+
+ if mask_np.shape != (h, w):
+ mask_np = cv2.resize(mask_np.astype(np.float32), (w, h))
+
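+ # remap() samples the previous mask at (x + flow_x, y + flow_y), a cheap
+ # approximation that carries the last keyframe's masks into the current frame
+ # along the dense Farneback flow field.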
+ warped = cv2.remap(
+ mask_np.astype(np.float32),
+ flow_map_x, flow_map_y,
+ interpolation=cv2.INTER_LINEAR,
+ borderMode=cv2.BORDER_CONSTANT,
+ borderValue=0
+ )
+ warped = (warped > 0.5).astype(np.float32)
+
+ if cc.enable_fill_holes:
+ warped = fill_holes_in_mask(warped, cc.fill_hole_area)
+ if cc.enable_smooth_edges:
+ warped = smooth_mask_edges(warped, cc.smooth_kernel_size)
+
+ tracked_masks.append(torch.from_numpy(warped).unsqueeze(0).to(cc.device_str))
+
+ cc.prev_gray = curr_gray
+
+ if tracked_masks:
+ return torch.stack(tracked_masks)
+
+ except Exception as e:
+ cc.log(f"Tracking error: {e}", "ERROR")
+
+ return None
+
+
+def overlay_masks(frame: np.ndarray, masks: torch.Tensor, boxes=None, scores=None, labels=None, alpha=0.5) -> np.ndarray:
+ """Overlay masks on frame."""
+ if masks is None or masks.numel() == 0:
+ return frame
+
+ overlay = frame.copy()
+ h, w = frame.shape[:2]
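+ # masks is expected to have shape (N, 1, H, W); squeeze(1) leaves one (H, W)
+ # mask per detected object.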
+ masks_np = masks.squeeze(1).cpu().numpy()
+
+ scores_np = scores.cpu().numpy() if scores is not None and isinstance(scores, torch.Tensor) else scores
+
+ for i, mask in enumerate(masks_np):
+ if mask.shape != (h, w):
+ mask = cv2.resize(mask.astype(np.float32), (w, h)) > 0.5
+
+ # Use persistent color if available
+ if cc.enable_persistent_ids and i < len(cc.current_detections):
+ det = cc.current_detections[i]
+ obj_id = det.get("persistent_id")
+ color = cc.object_colors.get(obj_id, COLORS[i % len(COLORS)])
+ else:
+ color = COLORS[i % len(COLORS)]
+
+ mask_region = mask.astype(bool)
+ overlay[mask_region] = (
+ overlay[mask_region] * (1 - alpha) + np.array(color) * alpha
+ ).astype(np.uint8)
+
+ # Draw contour
+ contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ cv2.drawContours(overlay, contours, -1, color, 2)
+
+ # Draw label
+ if len(contours) > 0:
+ largest = max(contours, key=cv2.contourArea)
+ x, y, cw, ch = cv2.boundingRect(largest)
+
+ label = labels[i] if labels and i < len(labels) else "object"
+ conf = scores_np[i] if scores_np is not None and i < len(scores_np) else 0.0
+
+ # Add persistent ID and YOLO info to label
+ text_parts = []
+ if cc.enable_persistent_ids and i < len(cc.current_detections):
+ obj_id = cc.current_detections[i].get("persistent_id")
+ text_parts.append(f"#{obj_id}")
+
+ text_parts.append(f"{label} {conf:.0%}")
+
+ # Add YOLO classification if available
+ if i < len(cc.current_detections):
+ det = cc.current_detections[i]
+ yolo_info = det.get("yolo")
+ if yolo_info and "classify" in yolo_info:
+ yolo_class = yolo_info["classify"]["yolo_class"]
+ yolo_conf = yolo_info["classify"]["yolo_confidence"]
+ text_parts.append(f"[{yolo_class} {yolo_conf:.0%}]")
+
+ text = " ".join(text_parts)
+
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ (tw, th), _ = cv2.getTextSize(text, font, 0.5, 1)
+
+ cv2.rectangle(overlay, (x, y - th - 4), (x + tw + 4, y), color, -1)
+ cv2.putText(overlay, text, (x + 2, y - 2), font, 0.5, (255, 255, 255), 1)
+
+ return overlay
+
+
+def generate_frames():
+ """Generator for video streaming."""
+ global cc
+
+ while cc.running:
+ if cc.camera is None or not cc.camera.isOpened():
+ time.sleep(0.1)
+ continue
+
+ ret, frame = cc.camera.read()
+ if not ret:
+ time.sleep(0.1)
+ continue
+
+ # Apply flip transformations
+ if cc.flip_horizontal and cc.flip_vertical:
+ frame = cv2.flip(frame, -1) # Flip both
+ elif cc.flip_horizontal:
+ frame = cv2.flip(frame, 1) # Flip horizontally (mirror)
+ elif cc.flip_vertical:
+ frame = cv2.flip(frame, 0) # Flip vertically
+
+ start = time.time()
+
+ # Store raw frame (without overlays) for Claude analysis
+ cc.current_raw_frame = frame.copy()
+
+ # Process frame (adds overlays)
+ display = process_frame(frame)
+
+ # Calculate FPS
+ elapsed = time.time() - start
+ cc.fps = 1.0 / elapsed if elapsed > 0 else 0
+
+ # Encode to JPEG
+ _, buffer = cv2.imencode('.jpg', display, [cv2.IMWRITE_JPEG_QUALITY, 85])
+ cc.current_frame = display
+ cc.current_frame_jpeg = buffer.tobytes()
+
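+ # Each yielded chunk is one JPEG part of a multipart/x-mixed-replace stream,
+ # which browsers render as continuously updating video.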
+ yield (b'--frame\r\n'
+ b'Content-Type: image/jpeg\r\n\r\n' + cc.current_frame_jpeg + b'\r\n')
+
+
+def analyze_with_claude(image_data: str, label: str) -> str:
+ """Send image to Claude for analysis."""
+ global ANTHROPIC_API_KEY
+
+ if not ANTHROPIC_API_KEY:
+ return "Error: ANTHROPIC_API_KEY not set. Set it via environment variable or --api-key argument."
+
+ try:
+ import anthropic
+
+ client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)
+
+ if image_data.startswith("data:"):
+ image_data = image_data.split(",", 1)[1]
+
+ message = client.messages.create(
+ model="claude-sonnet-4-20250514",
+ max_tokens=500,
+ messages=[
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/jpeg",
+ "data": image_data,
+ },
+ },
+ {
+ "type": "text",
+ "text": f"This is a cropped image of a detected '{label}'. Please provide a brief, detailed description of what you see. Focus on: appearance, distinctive features, actions/pose, and any notable details. Keep it concise (2-3 sentences)."
+ }
+ ],
+ }
+ ],
+ )
+
+ return message.content[0].text
+
+ except Exception as e:
+ return f"Analysis error: {str(e)}"
+
+
+def analysis_worker():
+ """Background worker for Claude analysis."""
+ global cc
+
+ while cc.running:
+ if cc.analysis_queue:
+ with cc.lock:
+ if cc.analysis_queue:
+ item = cc.analysis_queue.pop(0)
+ cc.analyzing = True
+ else:
+ item = None
+
+ if item:
+ cc.log(f"Analyzing object #{item['id']}...", "INFO")
+
+ detections = cc.current_detections
+ label = "object"
+ for det in detections:
+ if det.get("id") == item["id"]:
+ label = det.get("label", "object")
+ break
+
+ result = analyze_with_claude(item["image_data"], label)
+ cc.add_analysis_result(item["id"], result)
+ cc.log(f"Analysis complete for #{item['id']}", "SUCCESS")
+ cc.analyzing = False
+ else:
+ time.sleep(0.5)
+
+
+# ===== FLASK ROUTES =====
+
+@app.route('/')
+def index():
+ """Main command center page."""
+ return render_template('index.html',
+ prompts=cc.prompts,
+ threshold=cc.confidence_threshold,
+ skip_frames=cc.skip_frames,
+ tracking=cc.enable_tracking,
+ features=cc.get_feature_status(),
+ yolo_available=cc.yolo_available)
+
+
+@app.route('/video_feed')
+def video_feed():
+ """Video streaming route."""
+ return Response(generate_frames(),
+ mimetype='multipart/x-mixed-replace; boundary=frame')
+
+
+@app.route('/api/status')
+def api_status():
+ """Get current status."""
+ filtered, hidden = cc.get_filtered_detections()
+ return jsonify({
+ "running": cc.running,
+ "paused": cc.paused,
+ "fps": round(cc.fps, 1),
+ "frame_count": cc.frame_count,
+ "device": cc.device_str,
+ "detections": filtered,
+ "hidden_counts": hidden,
+ "prompts": cc.prompts,
+ "max_objects": cc.max_objects_per_prompt,
+ "show_all": cc.show_all_matches,
+ "analyzing": cc.analyzing,
+ "analysis_queue_size": len(cc.analysis_queue),
+ "features": cc.get_feature_status(),
+ "tracked_objects_count": len(cc.tracked_objects),
+ "memory_bank_size": len(cc.memory_bank),
+ "yolo_available": cc.yolo_available,
+ "poses_count": len(cc.last_poses),
+ })
+
+
+@app.route('/api/logs')
+def api_logs():
+ """Get recent logs."""
+ return jsonify({"logs": cc.get_logs()})
+
+
+@app.route('/api/analysis_results')
+def api_analysis_results():
+ """Get analysis results."""
+ with cc.lock:
+ results = list(cc.analysis_results)
+ return jsonify({"results": results})
+
+
+@app.route('/api/set_prompts', methods=['POST'])
+def api_set_prompts():
+ """Set detection prompts."""
+ data = request.json
+ prompts_str = data.get("prompts", "object")
+ cc.prompts = [p.strip() for p in prompts_str.split(",") if p.strip()]
+ cc.state = None
+ cc.last_masks = None
+ cc.last_boxes = None
+ cc.last_scores = None
+ cc.last_labels = None
+ cc.tracked_objects = {}
+ cc.memory_bank = {}
+ cc.last_poses = {}
+ cc.log(f"Prompts updated: {', '.join(cc.prompts)}")
+ return jsonify({"success": True, "prompts": cc.prompts})
+
+
+@app.route('/api/set_limit', methods=['POST'])
+def api_set_limit():
+ """Set max objects limit for a prompt."""
+ data = request.json
+ prompt = data.get("prompt")
+ limit = data.get("limit")
+
+ if limit is not None:
+ cc.max_objects_per_prompt[prompt] = int(limit)
+ elif prompt in cc.max_objects_per_prompt:
+ del cc.max_objects_per_prompt[prompt]
+
+ cc.log(f"Limit for '{prompt}': {limit if limit else 'unlimited'}")
+ return jsonify({"success": True})
+
+
+@app.route('/api/toggle_show_all', methods=['POST'])
+def api_toggle_show_all():
+ """Toggle show all matches for a prompt."""
+ data = request.json
+ prompt = data.get("prompt")
+ cc.show_all_matches[prompt] = not cc.show_all_matches.get(prompt, False)
+ cc.log(f"Show all for '{prompt}': {cc.show_all_matches[prompt]}")
+ return jsonify({"success": True, "show_all": cc.show_all_matches[prompt]})
+
+
+@app.route('/api/toggle_pause', methods=['POST'])
+def api_toggle_pause():
+ """Toggle pause state."""
+ cc.paused = not cc.paused
+ cc.log("Paused" if cc.paused else "Resumed")
+ return jsonify({"success": True, "paused": cc.paused})
+
+
+@app.route('/api/reset', methods=['POST'])
+def api_reset():
+ """Reset detection state."""
+ cc.state = None
+ cc.last_masks = None
+ cc.last_boxes = None
+ cc.last_scores = None
+ cc.last_labels = None
+ cc.tracked_objects = {}
+ cc.memory_bank = {}
+ cc.object_colors = {}
+ cc.next_object_id = 1
+ cc.pending_detections = {}
+ cc.last_poses = {}
+ cc.clear_detections()
+ cc.log("Detection state reset")
+ return jsonify({"success": True})
+
+
+@app.route('/api/set_threshold', methods=['POST'])
+def api_set_threshold():
+ """Set confidence threshold."""
+ data = request.json
+ cc.confidence_threshold = float(data.get("threshold", 0.3))
+ if cc.processor:
+ cc.processor.confidence_threshold = cc.confidence_threshold
+ cc.log(f"Threshold set to {cc.confidence_threshold:.2f}")
+ return jsonify({"success": True})
+
+
+@app.route('/api/set_skip_frames', methods=['POST'])
+def api_set_skip_frames():
+ """Set skip frames value."""
+ data = request.json
+ cc.skip_frames = max(1, int(data.get("skip_frames", 3)))
+ cc.log(f"Skip frames set to {cc.skip_frames}")
+ return jsonify({"success": True})
+
+
+# ===== FEATURE TOGGLE ROUTES =====
+
+@app.route('/api/toggle_feature', methods=['POST'])
+def api_toggle_feature():
+ """Toggle a feature on/off."""
+ data = request.json
+ feature = data.get("feature")
+
+ feature_map = {
+ "tracking": "enable_tracking",
+ "memory_tracking": "enable_memory_tracking",
+ "persistent_ids": "enable_persistent_ids",
+ "fill_holes": "enable_fill_holes",
+ "non_overlap": "enable_non_overlap",
+ "smooth_edges": "enable_smooth_edges",
+ "boundary_suppression": "enable_boundary_suppression",
+ "occlusion_suppression": "enable_occlusion_suppression",
+ "hotstart": "enable_hotstart",
+ "yolo_classify": "enable_yolo_classify",
+ "yolo_pose": "enable_yolo_pose",
+ "show_keypoint_labels": "show_keypoint_labels",
+ "show_skeleton": "show_skeleton",
+ "label_spoofing": "enable_label_spoofing",
+ }
+
+ if feature in feature_map:
+ attr = feature_map[feature]
+ current = getattr(cc, attr)
+ setattr(cc, attr, not current)
+ new_val = getattr(cc, attr)
+ cc.log(f"{feature}: {'ON' if new_val else 'OFF'}")
+ return jsonify({"success": True, "feature": feature, "enabled": new_val})
+
+ return jsonify({"success": False, "error": "Unknown feature"})
+
+
+@app.route('/api/set_feature_param', methods=['POST'])
+def api_set_feature_param():
+ """Set a feature parameter value."""
+ data = request.json
+ param = data.get("param")
+ value = data.get("value")
+
+ param_map = {
+ "fill_hole_area": ("fill_hole_area", int),
+ "smooth_kernel_size": ("smooth_kernel_size", int),
+ "boundary_margin": ("boundary_margin", int),
+ "occlusion_threshold": ("occlusion_threshold", float),
+ "hotstart_frames": ("hotstart_frames", int),
+ "iou_threshold": ("iou_threshold", float),
+ "memory_max_frames": ("memory_max_frames", int),
+ "yolo_classify_threshold": ("yolo_classify_threshold", float),
+ "yolo_pose_threshold": ("yolo_pose_threshold", float),
+ "yolo_classify_every_n": ("yolo_classify_every_n", int),
+ "keypoint_radius": ("keypoint_radius", int),
+ "skeleton_thickness": ("skeleton_thickness", int),
+ }
+
+ if param in param_map:
+ attr, type_fn = param_map[param]
+ setattr(cc, attr, type_fn(value))
+ cc.log(f"{param} set to {value}")
+ return jsonify({"success": True})
+
+ return jsonify({"success": False, "error": "Unknown parameter"})
+
+
+@app.route('/api/analyze_object', methods=['POST'])
+def api_analyze_object():
+ """Queue an object for Claude analysis with mask-based cropping."""
+ data = request.json
+ detection_id = data.get("detection_id")
+ box = data.get("box")
+ mask_index = data.get("mask_index") # Index into cc.last_masks
+
+ # Use raw frame (without overlays) for analysis
+ if cc.current_raw_frame is None:
+ return jsonify({"success": False, "error": "No frame available"})
+
+ try:
+ frame = cc.current_raw_frame.copy()
+ h, w = frame.shape[:2]
+
+ # Try to use mask for better cropping
+ mask = None
+ if mask_index is not None and cc.last_masks is not None:
+ try:
+ if mask_index < len(cc.last_masks):
+ mask = cc.last_masks[mask_index].squeeze().cpu().numpy()
+ if mask.shape != (h, w):
+ mask = cv2.resize(mask.astype(np.float32), (w, h)) > 0.5
+ except Exception as e:
+ cc.log(f"Could not get mask: {e}", "WARN")
+ mask = None
+
+ if mask is not None and mask.sum() > 0:
+ # Use mask to create a clean crop with a white background
+ # Get bounding box from mask
+ rows = np.any(mask, axis=1)
+ cols = np.any(mask, axis=0)
+ y_min, y_max = np.where(rows)[0][[0, -1]]
+ x_min, x_max = np.where(cols)[0][[0, -1]]
+
+ # Add padding
+ pad = 15
+ x1 = max(0, x_min - pad)
+ y1 = max(0, y_min - pad)
+ x2 = min(w, x_max + pad)
+ y2 = min(h, y_max + pad)
+
+ # Crop the region
+ crop = frame[y1:y2, x1:x2].copy()
+ mask_crop = mask[y1:y2, x1:x2]
+
+ # Apply mask - set background to white for cleaner analysis
+ mask_3ch = np.stack([mask_crop] * 3, axis=-1)
+ crop = np.where(mask_3ch, crop, 255).astype(np.uint8)
+
+ elif box:
+ # Fallback to box-based cropping
+ x1, y1, x2, y2 = [int(v) for v in box]
+ pad = 20
+ x1 = max(0, x1 - pad)
+ y1 = max(0, y1 - pad)
+ x2 = min(w, x2 + pad)
+ y2 = min(h, y2 + pad)
+ crop = frame[y1:y2, x1:x2]
+ else:
+ crop = frame
+
+ _, buffer = cv2.imencode('.jpg', crop, [cv2.IMWRITE_JPEG_QUALITY, 90])
+ image_data = base64.b64encode(buffer).decode('utf-8')
+
+ cc.queue_analysis(detection_id, image_data)
+ cc.log(f"Queued object #{detection_id} for analysis (mask-cropped: {mask is not None})")
+
+ return jsonify({"success": True})
+
+ except Exception as e:
+ cc.log(f"Failed to queue analysis: {e}", "ERROR")
+ return jsonify({"success": False, "error": str(e)})
+
+
+@app.route('/api/describe_scene', methods=['POST'])
+def api_describe_scene():
+ """Send full scene to Claude for description."""
+ global ANTHROPIC_API_KEY
+
+ if not ANTHROPIC_API_KEY:
+ return jsonify({"success": False, "error": "ANTHROPIC_API_KEY not set"})
+
+ if cc.current_raw_frame is None:
+ return jsonify({"success": False, "error": "No frame available"})
+
+ try:
+ import anthropic
+
+ frame = cc.current_raw_frame.copy()
+ _, buffer = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
+ image_data = base64.b64encode(buffer).decode('utf-8')
+
+ cc.log("Analyzing full scene with Claude...")
+
+ client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)
+
+ message = client.messages.create(
+ model="claude-sonnet-4-20250514",
+ max_tokens=800,
+ messages=[
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "source": {
+ "type": "base64",
+ "media_type": "image/jpeg",
+ "data": image_data,
+ },
+ },
+ {
+ "type": "text",
+ "text": "Please describe this scene in detail. Include: the setting/environment, all visible objects and people, their positions and relationships, any activities or actions taking place, lighting conditions, and any notable details. Be comprehensive but concise (3-5 sentences)."
+ }
+ ],
+ }
+ ],
+ )
+
+ result = message.content[0].text
+ cc.log("Scene analysis complete", "SUCCESS")
+
+ # Add to analysis results
+ cc.add_analysis_result(-1, f"[SCENE] {result}")
+
+ return jsonify({
+ "success": True,
+ "description": result
+ })
+
+ except Exception as e:
+ cc.log(f"Scene analysis failed: {e}", "ERROR")
+ return jsonify({"success": False, "error": str(e)})
+
+
+@app.route('/api/tracked_objects')
+def api_tracked_objects():
+ """Get list of tracked objects with persistent IDs."""
+ objects = []
+ for obj_id, data in cc.tracked_objects.items():
+ objects.append({
+ "id": obj_id,
+ "label": data.get("label"),
+ "first_seen": data.get("first_seen"),
+ "last_seen": data.get("last_seen"),
+ "confidence": data.get("confidence", 0),
+ "frames_tracked": data.get("last_seen", 0) - data.get("first_seen", 0),
+ })
+ return jsonify({"objects": objects})
+
+
+@app.route('/api/poses')
+def api_poses():
+ """Get current pose data for all detected persons."""
+ poses = []
+ for obj_id, pose_data in cc.last_poses.items():
+ poses.append({
+ "object_id": obj_id,
+ "confidence": pose_data.get("confidence", 0),
+ "keypoints": [
+ {"name": name, "x": kp[0], "y": kp[1], "confidence": kp[2]}
+ for name, kp in zip(POSE_KEYPOINTS, pose_data.get("keypoints", []))
+ ]
+ })
+ return jsonify({"poses": poses})
+
+
+@app.route('/api/coco_mapping')
+def api_coco_mapping():
+ """Get SAM3 to COCO label mapping."""
+ return jsonify({
+ "mapping": SAM3_TO_COCO,
+ "coco_classes": COCO_CLASSES
+ })
+
+
+# ===== VOICE SEARCH ROUTES =====
+
+def parse_voice_query_with_claude(voice_text: str) -> Dict:
+ """
+ Use Claude to parse a voice query into search prompts.
+
+ Handles queries like:
+ - "help me find a red car"
+ - "can you search for a person and a dog"
+ - "look for my phone, keys, and wallet"
+ - "find the blue cup on the table"
+
+ Returns dict with:
+ - prompts: List of parsed object prompt strings
+ - is_multi: Whether multiple objects were requested
+ - feedback: Human-readable feedback message
+ """
+ global ANTHROPIC_API_KEY
+
+ if not ANTHROPIC_API_KEY:
+ return {
+ "success": False,
+ "error": "ANTHROPIC_API_KEY not set",
+ "prompts": [voice_text], # Fallback: use raw text
+ "feedback": f"API key not set. Searching for: {voice_text}"
+ }
+
+ try:
+ import anthropic
+
+ client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY)
+
+ message = client.messages.create(
+ model="claude-sonnet-4-20250514",
+ max_tokens=300,
+ messages=[
+ {
+ "role": "user",
+ "content": f"""Parse this voice command for an object detection system. Extract the objects the user wants to find.
+
+Voice command: "{voice_text}"
+
+Rules:
+1. Extract object names/descriptions that can be detected visually
+2. If multiple objects are mentioned, list them all
+3. Include color/size descriptors if mentioned (e.g., "red car", "large dog")
+4. Ignore filler words like "help me find", "can you search for", "look for"
+5. Return ONLY a JSON object, no other text
+
+Return JSON format:
+{{"prompts": ["object1", "object2"], "feedback": "Searching for object1 and object2"}}
+
+Examples:
+- "help me find a red car" -> {{"prompts": ["red car"], "feedback": "Searching for red car"}}
+- "search for people and dogs" -> {{"prompts": ["person", "dog"], "feedback": "Searching for person and dog"}}
+- "find my phone and keys" -> {{"prompts": ["phone", "keys"], "feedback": "Searching for phone and keys"}}
+- "look for a blue cup" -> {{"prompts": ["blue cup"], "feedback": "Searching for blue cup"}}"""
+ }
+ ],
+ )
+
+ response_text = message.content[0].text.strip()
+
+ # Parse JSON from response
+ # Handle potential markdown code blocks
+ if "```json" in response_text:
+ response_text = response_text.split("```json")[1].split("```")[0].strip()
+ elif "```" in response_text:
+ response_text = response_text.split("```")[1].split("```")[0].strip()
+
+ result = json.loads(response_text)
+
+ prompts = result.get("prompts", [])
+ feedback = result.get("feedback", f"Searching for {', '.join(prompts)}")
+
+ return {
+ "success": True,
+ "prompts": prompts,
+ "is_multi": len(prompts) > 1,
+ "feedback": feedback,
+ "raw_query": voice_text
+ }
+
+ except json.JSONDecodeError as e:
+ cc.log(f"Failed to parse Claude response as JSON: {e}", "ERROR")
+ # Fallback: just use the voice text directly
+ return {
+ "success": True,
+ "prompts": [voice_text],
+ "is_multi": False,
+ "feedback": f"Searching for {voice_text}",
+ "raw_query": voice_text
+ }
+ except Exception as e:
+ cc.log(f"Voice query parsing error: {e}", "ERROR")
+ return {
+ "success": False,
+ "error": str(e),
+ "prompts": [],
+ "feedback": "Failed to parse voice command"
+ }
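+ # Illustrative only: for the query "find my phone and keys" a successful parse
+ # would look roughly like {"success": True, "prompts": ["phone", "keys"],
+ # "is_multi": True, "feedback": "Searching for phone and keys",
+ # "raw_query": "find my phone and keys"}.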
+
+
+def check_describe_command(voice_text: str) -> Optional[Dict]:
+ """
+ Check if voice command is a describe command.
+ Returns dict with action info if it's a describe command, None otherwise.
+
+ Handles:
+ - "describe scene" / "describe the scene" / "what do you see"
+ - "describe object 1" / "describe the first object" / "tell me about object 2"
+ - "analyze object 3" / "what is object 1"
+ """
+ text_lower = voice_text.lower().strip()
+
+ # Scene describe patterns
+ scene_patterns = [
+ "describe scene", "describe the scene", "describe this scene",
+ "what do you see", "what's in the scene", "describe everything",
+ "describe the view", "describe what you see", "analyze scene",
+ "tell me about the scene", "what's happening"
+ ]
+
+ for pattern in scene_patterns:
+ if pattern in text_lower:
+ return {
+ "action": "describe_scene",
+ "feedback": "Describing the scene..."
+ }
+
+ # Object describe patterns - extract object number
+ import re
+
+ # Patterns like "describe object 1", "analyze object 2", "tell me about object 3"
+ object_patterns = [
+ r"describe (?:the )?(?:object|item|thing) (\d+)",
+ r"analyze (?:the )?(?:object|item|thing) (\d+)",
+ r"what is (?:object|item|thing) (\d+)",
+ r"tell me about (?:object|item|thing) (\d+)",
+ r"describe (?:the )?(\d+)(?:st|nd|rd|th)? (?:object|item|thing)",
+ r"describe number (\d+)",
+ r"object (\d+) describe",
+ ]
+
+ for pattern in object_patterns:
+ match = re.search(pattern, text_lower)
+ if match:
+ obj_num = int(match.group(1))
+ return {
+ "action": "describe_object",
+ "object_id": obj_num,
+ "feedback": f"Describing object {obj_num}..."
+ }
+
+ # Ordinal patterns like "describe the first object", "analyze the second item"
+ ordinals = {
+ "first": 0, "1st": 0,
+ "second": 1, "2nd": 1,
+ "third": 2, "3rd": 2,
+ "fourth": 3, "4th": 3,
+ "fifth": 4, "5th": 4,
+ }
+
+ for ordinal, idx in ordinals.items():
+ if ordinal in text_lower and ("object" in text_lower or "item" in text_lower or "thing" in text_lower):
+ if "describe" in text_lower or "analyze" in text_lower or "tell me" in text_lower or "what is" in text_lower:
+ return {
+ "action": "describe_object",
+ "object_index": idx,
+ "feedback": f"Describing the {ordinal} object..."
+ }
+
+ return None
+
+
+@app.route('/api/voice_search', methods=['POST'])
+def api_voice_search():
+ """Process a voice search query through Claude and set prompts."""
+ data = request.json
+ voice_text = data.get("text", "").strip()
+
+ if not voice_text:
+ return jsonify({"success": False, "error": "No voice text provided"})
+
+ cc.log(f"Voice query received: '{voice_text}'", "INFO")
+ cc.last_voice_query = voice_text
+
+ # First check for describe commands
+ describe_cmd = check_describe_command(voice_text)
+ if describe_cmd:
+ cc.add_voice_feedback(describe_cmd["feedback"], "info")
+
+ if describe_cmd["action"] == "describe_scene":
+ return jsonify({
+ "success": True,
+ "action": "describe_scene",
+ "feedback": describe_cmd["feedback"],
+ "tts_message": describe_cmd["feedback"]
+ })
+
+ elif describe_cmd["action"] == "describe_object":
+ # Find the object to describe
+ obj_id = describe_cmd.get("object_id")
+ obj_index = describe_cmd.get("object_index")
+
+ detections = cc.current_detections
+
+ if not detections:
+ return jsonify({
+ "success": False,
+ "error": "No objects detected",
+ "feedback": "No objects are currently detected"
+ })
+
+ # Find the detection
+ target_det = None
+ target_index = None
+
+ if obj_id is not None:
+ # Look for object with this ID
+ for i, det in enumerate(detections):
+ if det.get("id") == obj_id:
+ target_det = det
+ target_index = i
+ break
+ elif obj_index is not None:
+ # Use index directly
+ if obj_index < len(detections):
+ target_det = detections[obj_index]
+ target_index = obj_index
+
+ if target_det is None:
+ return jsonify({
+ "success": False,
+ "error": f"Object not found",
+ "feedback": f"Could not find the specified object"
+ })
+
+ return jsonify({
+ "success": True,
+ "action": "describe_object",
+ "detection": target_det,
+ "mask_index": target_index,
+ "feedback": describe_cmd["feedback"],
+ "tts_message": describe_cmd["feedback"]
+ })
+
+ # Parse the voice query with Claude for search
+ result = parse_voice_query_with_claude(voice_text)
+
+ if result["success"] and result["prompts"]:
+ # Update prompts
+ cc.prompts = result["prompts"]
+ cc.last_parsed_prompts = result["prompts"]
+
+ # Reset detection state for new search
+ cc.state = None
+ cc.last_masks = None
+ cc.last_boxes = None
+ cc.last_scores = None
+ cc.last_labels = None
+ cc.tracked_objects = {}
+ cc.memory_bank = {}
+ cc.last_poses = {}
+
+ prompt_str = ", ".join(result["prompts"])
+ cc.log(f"Voice search: {prompt_str}", "SUCCESS")
+ cc.add_voice_feedback(result["feedback"], "success")
+
+ return jsonify({
+ "success": True,
+ "action": "search",
+ "prompts": result["prompts"],
+ "prompt_string": prompt_str,
+ "is_multi": result["is_multi"],
+ "feedback": result["feedback"],
+ "tts_message": result["feedback"]
+ })
+ else:
+ error_msg = result.get("error", "Could not understand the voice command")
+ cc.add_voice_feedback(f"Error: {error_msg}", "error")
+ return jsonify({
+ "success": False,
+ "error": error_msg,
+ "feedback": result.get("feedback", "Failed to process voice command")
+ })
+
+
+@app.route('/api/voice_feedback')
+def api_voice_feedback():
+ """Get recent voice feedback messages."""
+ with cc.lock:
+ messages = list(cc.voice_feedback_messages)
+ return jsonify({
+ "messages": messages,
+ "last_query": cc.last_voice_query,
+ "last_prompts": cc.last_parsed_prompts
+ })
+
+
+@app.route('/api/toggle_voice', methods=['POST'])
+def api_toggle_voice():
+ """Toggle voice features."""
+ data = request.json
+ feature = data.get("feature", "voice")
+
+ if feature == "voice":
+ cc.voice_enabled = not cc.voice_enabled
+ cc.log(f"Voice input: {'ON' if cc.voice_enabled else 'OFF'}")
+ return jsonify({"success": True, "enabled": cc.voice_enabled})
+ elif feature == "tts":
+ cc.tts_enabled = not cc.tts_enabled
+ cc.log(f"TTS output: {'ON' if cc.tts_enabled else 'OFF'}")
+ return jsonify({"success": True, "enabled": cc.tts_enabled})
+
+ return jsonify({"success": False, "error": "Unknown feature"})
+
+
+# ===== CAMERA ROUTES =====
+
+@app.route('/api/cameras')
+def api_cameras():
+ """Get list of available cameras."""
+ cameras = detect_available_cameras()
+ cc.available_cameras = cameras
+ return jsonify({
+ "cameras": cameras,
+ "current_camera": cc.current_camera_id,
+ "flip_horizontal": cc.flip_horizontal,
+ "flip_vertical": cc.flip_vertical
+ })
+
+
+@app.route('/api/switch_camera', methods=['POST'])
+def api_switch_camera():
+ """Switch to a different camera."""
+ data = request.json
+ camera_id = data.get("camera_id")
+
+ if camera_id is None:
+ return jsonify({"success": False, "error": "No camera_id provided"})
+
+ camera_id = int(camera_id)
+
+ success = switch_camera(camera_id)
+
+ return jsonify({
+ "success": success,
+ "current_camera": cc.current_camera_id,
+ "message": f"Switched to camera {camera_id}" if success else f"Failed to switch to camera {camera_id}"
+ })
+
+
+@app.route('/api/flip_camera', methods=['POST'])
+def api_flip_camera():
+ """Toggle camera flip (horizontal/vertical)."""
+ data = request.json
+ direction = data.get("direction", "horizontal")
+
+ if direction == "horizontal":
+ cc.flip_horizontal = not cc.flip_horizontal
+ cc.log(f"Horizontal flip: {'ON' if cc.flip_horizontal else 'OFF'}")
+ # Reset detection state when flip changes
+ reset_detection_state()
+ return jsonify({
+ "success": True,
+ "flip_horizontal": cc.flip_horizontal,
+ "flip_vertical": cc.flip_vertical
+ })
+ elif direction == "vertical":
+ cc.flip_vertical = not cc.flip_vertical
+ cc.log(f"Vertical flip: {'ON' if cc.flip_vertical else 'OFF'}")
+ # Reset detection state when flip changes
+ reset_detection_state()
+ return jsonify({
+ "success": True,
+ "flip_horizontal": cc.flip_horizontal,
+ "flip_vertical": cc.flip_vertical
+ })
+ elif direction == "both":
+ cc.flip_horizontal = not cc.flip_horizontal
+ cc.flip_vertical = not cc.flip_vertical
+ cc.log(f"Flip both: H={'ON' if cc.flip_horizontal else 'OFF'}, V={'ON' if cc.flip_vertical else 'OFF'}")
+ reset_detection_state()
+ return jsonify({
+ "success": True,
+ "flip_horizontal": cc.flip_horizontal,
+ "flip_vertical": cc.flip_vertical
+ })
+
+ return jsonify({"success": False, "error": "Invalid direction"})
+
+
+@app.route('/api/set_flip', methods=['POST'])
+def api_set_flip():
+ """Set flip state explicitly."""
+ data = request.json
+ flip_h = data.get("flip_horizontal")
+ flip_v = data.get("flip_vertical")
+
+ changed = False
+
+ if flip_h is not None and flip_h != cc.flip_horizontal:
+ cc.flip_horizontal = bool(flip_h)
+ changed = True
+
+ if flip_v is not None and flip_v != cc.flip_vertical:
+ cc.flip_vertical = bool(flip_v)
+ changed = True
+
+ if changed:
+ cc.log(f"Flip set: H={'ON' if cc.flip_horizontal else 'OFF'}, V={'ON' if cc.flip_vertical else 'OFF'}")
+ reset_detection_state()
+
+ return jsonify({
+ "success": True,
+ "flip_horizontal": cc.flip_horizontal,
+ "flip_vertical": cc.flip_vertical
+ })
+
+
+# ===== REFERENCE IMAGE SEARCH API =====
+
+@app.route('/api/upload_reference', methods=['POST'])
+def api_upload_reference():
+ """
+ Upload a reference image for search.
+ Modes:
+ - 'description': Use Claude to describe, then search by text
+ - 'visual': Use CLIP for visual similarity matching
+ """
+ global cc
+
+ if 'image' not in request.files:
+ return jsonify({"success": False, "error": "No image provided"})
+
+ mode = request.form.get('mode', 'description') # 'description' or 'visual'
+
+ try:
+ file = request.files['image']
+ image_data = file.read()
+
+ # Convert to PIL Image
+ pil_image = Image.open(io.BytesIO(image_data)).convert('RGB')
+ cc.reference_image = pil_image
+
+ # Get base64 for Claude
+ buffered = io.BytesIO()
+ pil_image.save(buffered, format="JPEG", quality=90)
+ base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
+
+ if mode == 'description':
+ # Use Claude to describe the image
+ cc.log("Analyzing reference image with Claude...")
+ description = describe_image_with_claude(base64_image)
+
+ if description:
+ cc.reference_description = description
+ cc.visual_match_enabled = False
+
+ # Set as prompt
+ cc.prompts = [description]
+ cc.state = None
+ cc.last_masks = None
+ reset_detection_state()
+
+ cc.log(f"Reference search: '{description}'", "SUCCESS")
+
+ return jsonify({
+ "success": True,
+ "mode": "description",
+ "description": description,
+ "prompt": description
+ })
+ else:
+ return jsonify({"success": False, "error": "Failed to describe image"})
+
+ elif mode == 'visual':
+ # Use CLIP for visual matching
+ if not cc.clip_available:
+ return jsonify({
+ "success": False,
+ "error": "CLIP not available. Install with: pip install transformers"
+ })
+
+ cc.log("Computing CLIP embedding for reference image...")
+ embedding = get_clip_embedding(pil_image)
+
+ if embedding is not None:
+ cc.reference_embedding = embedding
+ cc.visual_match_enabled = True
+
+ # Also get a description for display
+ description = describe_image_with_claude(base64_image)
+ cc.reference_description = description or "Visual reference"
+
+ # Set a generic prompt to detect objects
+ cc.prompts = ["object"]
+ cc.state = None
+ cc.last_masks = None
+ reset_detection_state()
+
+ cc.log(f"Visual matching enabled for: {cc.reference_description}", "SUCCESS")
+
+ return jsonify({
+ "success": True,
+ "mode": "visual",
+ "description": cc.reference_description,
+ "message": "Visual matching enabled"
+ })
+ else:
+ return jsonify({"success": False, "error": "Failed to compute CLIP embedding"})
+
+ else:
+ return jsonify({"success": False, "error": f"Unknown mode: {mode}"})
+
+ except Exception as e:
+ cc.log(f"Reference upload failed: {e}", "ERROR")
+ return jsonify({"success": False, "error": str(e)})
+
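+# Illustrative client usage for reference image search (a sketch; the file name and
+# base URL are assumptions, and verify=False is only needed for the self-signed cert):
+#
+#   import requests
+#   with open("reference.jpg", "rb") as f:
+#       resp = requests.post(
+#           "https://localhost:5000/api/upload_reference",
+#           files={"image": f},
+#           data={"mode": "visual"},  # or "description" for Claude-based text search
+#           verify=False,
+#       )
+#   print(resp.json())  # e.g. {"success": True, "mode": "visual", "description": ...}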
+
+@app.route('/api/clear_reference', methods=['POST'])
+def api_clear_reference():
+ """Clear the reference image."""
+ global cc
+
+ cc.reference_image = None
+ cc.reference_embedding = None
+ cc.reference_description = None
+ cc.visual_match_enabled = False
+
+ cc.log("Reference image cleared")
+
+ return jsonify({"success": True})
+
+
+@app.route('/api/reference_status')
+def api_reference_status():
+ """Get reference image status."""
+ return jsonify({
+ "has_reference": cc.reference_image is not None,
+ "description": cc.reference_description,
+ "visual_match_enabled": cc.visual_match_enabled,
+ "clip_available": cc.clip_available,
+ "threshold": cc.visual_match_threshold
+ })
+
+
+# ===== GEOMETRIC PROMPTS (DRAW TO SEARCH) API =====
+
+@app.route('/api/draw_prompt', methods=['POST'])
+def api_draw_prompt():
+ """
+ Set a geometric prompt (box or point) from user drawing.
+ This will be processed on the next frame.
+ """
+ global cc
+
+ data = request.json
+ prompt_type = data.get('type', 'box') # 'box' or 'point'
+
+ if prompt_type == 'box':
+ x1 = data.get('x1')
+ y1 = data.get('y1')
+ x2 = data.get('x2')
+ y2 = data.get('y2')
+
+ if all(v is not None for v in [x1, y1, x2, y2]):
+ cc.pending_box_prompt = (float(x1), float(y1), float(x2), float(y2))
+ cc.pending_point_prompt = None
+ cc.draw_mode = 'box'
+ cc.log(f"Box prompt set: ({x1:.0f}, {y1:.0f}) to ({x2:.0f}, {y2:.0f})")
+
+ return jsonify({
+ "success": True,
+ "type": "box",
+ "box": [x1, y1, x2, y2]
+ })
+ else:
+ return jsonify({"success": False, "error": "Invalid box coordinates"})
+
+ elif prompt_type == 'point':
+ x = data.get('x')
+ y = data.get('y')
+
+ if x is not None and y is not None:
+ cc.pending_point_prompt = (float(x), float(y))
+ cc.pending_box_prompt = None
+ cc.draw_mode = 'point'
+ cc.log(f"Point prompt set: ({x:.0f}, {y:.0f})")
+
+ return jsonify({
+ "success": True,
+ "type": "point",
+ "point": [x, y]
+ })
+ else:
+ return jsonify({"success": False, "error": "Invalid point coordinates"})
+
+ else:
+ return jsonify({"success": False, "error": f"Unknown prompt type: {prompt_type}"})
+
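+# Illustrative client usage for geometric prompts (a sketch; coordinates are example
+# pixel values in the camera frame):
+#
+#   import requests
+#   base = "https://localhost:5000"
+#   # Box prompt
+#   requests.post(f"{base}/api/draw_prompt",
+#                 json={"type": "box", "x1": 120, "y1": 80, "x2": 360, "y2": 300},
+#                 verify=False)
+#   # Point prompt
+#   requests.post(f"{base}/api/draw_prompt",
+#                 json={"type": "point", "x": 240, "y": 190}, verify=False)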
+
+@app.route('/api/clear_draw_prompt', methods=['POST'])
+def api_clear_draw_prompt():
+ """Clear any pending geometric prompts."""
+ global cc
+
+ cc.pending_box_prompt = None
+ cc.pending_point_prompt = None
+ cc.draw_mode = None
+
+ cc.log("Draw prompt cleared")
+
+ return jsonify({"success": True})
+
+
+# ===== NAVIGATION SYSTEM API =====
+
+@app.route('/api/navigation/start', methods=['POST'])
+def api_navigation_start():
+ """Start navigation to a detected object."""
+ global cc
+
+ data = request.json
+ target_label = data.get("target_label") or data.get("label")
+ target_id = data.get("target_id") or data.get("detection_id")
+
+ if not target_label and target_id is None:
+ return jsonify({"success": False, "error": "No target specified"})
+
+ # Check for location memory first (from SQLite)
+ memory = cc.recall_location(target_label) if target_label else None
+ memory_hint = None
+ if memory:
+ memory_hint = f"I remember finding {target_label} in the {memory.get('context', 'unknown location')} before."
+
+ cc.navigation_active = True
+ cc.navigation_target = target_label
+ cc.navigation_target_id = target_id
+ cc.navigation_start_time = time.time()
+ cc.navigation_last_seen = None
+ cc.navigation_reached = False
+ cc.navigation_target_history = []
+
+ # Start obstacle detection
+ cc.obstacle_detection_active = True
+ cc.current_obstacles = []
+ cc.obstacle_masks = None
+ cc.obstacle_boxes = None
+
+ # Create navigation session in database
+ if cc.session_id:
+ cc.navigation_db_id = db.start_navigation_session(cc.session_id, target_label, target_id)
+ db.log_event(cc.session_id, "navigation_start", f"Started navigation to {target_label}",
+ data={"target_label": target_label, "target_id": target_id})
+
+ # Analyze scene context
+ if cc.current_raw_frame is not None:
+ try:
+ _, buffer = cv2.imencode('.jpg', cc.current_raw_frame, [cv2.IMWRITE_JPEG_QUALITY, 70])
+ image_data = base64.b64encode(buffer).decode('utf-8')
+ cc.navigation_context = analyze_scene_context(image_data)
+ except Exception as e:
+ cc.log(f"Scene context analysis failed: {e}", "WARN")
+ cc.navigation_context = None
+
+ cc.log(f"Navigation started: looking for '{target_label}'", "SUCCESS")
+
+ # Initial message
+ location = cc.navigation_context.get("location", "this area") if cc.navigation_context else "this area"
+ initial_message = f"Starting navigation to find {target_label}. You appear to be in {location}."
+ if memory_hint:
+ initial_message += f" {memory_hint}"
+
+ return jsonify({
+ "success": True,
+ "target": target_label,
+ "initial_message": initial_message,
+ "memory_hint": memory_hint,
+ "context": cc.navigation_context
+ })
+
+
+@app.route('/api/navigation/stop', methods=['POST'])
+def api_navigation_stop():
+ """Stop navigation."""
+ global cc
+
+ was_active = cc.navigation_active
+ target = cc.navigation_target
+ reached = cc.navigation_reached
+
+ # If we reached the target, remember its location (in SQLite)
+ if reached and cc.navigation_context and target:
+ location = cc.navigation_context.get("location", "unknown location")
+ cc.remember_location(target, location)
+
+ # End navigation session in database
+ if cc.navigation_db_id:
+ db.end_navigation_session(
+ cc.navigation_db_id,
+ reached=reached,
+ path_history=cc.navigation_target_history,
+ scene_context=cc.navigation_context
+ )
+ if cc.session_id:
+ db.log_event(cc.session_id, "navigation_stop",
+ f"Navigation to {target} {'reached' if reached else 'cancelled'}",
+ data={"target": target, "reached": reached})
+
+ # Stop obstacle detection
+ cc.obstacle_detection_active = False
+ cc.current_obstacles = []
+ cc.obstacle_masks = None
+ cc.obstacle_boxes = None
+
+ cc.navigation_active = False
+ cc.navigation_target = None
+ cc.navigation_target_id = None
+ cc.navigation_db_id = None
+ cc.navigation_start_time = None
+ cc.navigation_last_seen = None
+ cc.navigation_reached = False
+ cc.navigation_context = None
+ cc.navigation_target_history = []
+
+ if was_active:
+ cc.log(f"Navigation ended for '{target}'")
+
+ return jsonify({
+ "success": True,
+ "reached": reached,
+ "show_post_nav_dialog": was_active # Tell UI to show continue/pause dialog
+ })
+
+
+@app.route('/api/navigation/status')
+def api_navigation_status():
+ """Get current navigation status and guidance."""
+ status = get_navigation_status()
+
+ # Add TTS guidance if needed
+ if status.get("active") and status.get("guidance"):
+ current_time = time.time()
+ guidance_text = status["guidance"].get("guidance_text", "")
+
+ # Only speak if enough time has passed and guidance changed
+ if (current_time - cc.navigation_last_guidance_time > cc.navigation_guidance_interval and
+ guidance_text != cc.navigation_last_guidance):
+ status["speak_guidance"] = True
+ cc.navigation_last_guidance = guidance_text
+ cc.navigation_last_guidance_time = current_time
+ else:
+ status["speak_guidance"] = False
+
+ # Add obstacle alerts with position for AR path routing
+ if cc.current_obstacles:
+ obstacles_for_alert = []
+ for obs in cc.current_obstacles:
+ if obs.get("should_alert"):
+ obstacles_for_alert.append({
+ "label": obs["label"],
+ "type": obs["type"],
+ "distance": obs["distance"],
+ "position": obs.get("position", "center"), # For AR path routing
+ "reason": obs.get("reason", ""),
+ "alert_text": f"Watch out! {obs['label']} {obs['distance'].replace('_', ' ')}"
+ })
+ status["obstacles"] = obstacles_for_alert
+
+ return jsonify(status)
+
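+# Illustrative end-to-end navigation flow from a client (a sketch; the target label
+# and one-second polling interval are arbitrary example choices):
+#
+#   import time, requests
+#   base = "https://localhost:5000"
+#   requests.post(f"{base}/api/navigation/start", json={"target_label": "keys"}, verify=False)
+#   for _ in range(30):
+#       status = requests.get(f"{base}/api/navigation/status", verify=False).json()
+#       if status.get("speak_guidance") and status.get("guidance"):
+#           print(status["guidance"].get("guidance_text", ""))
+#       for obs in status.get("obstacles", []):
+#           print(obs["alert_text"])
+#       time.sleep(1)
+#   requests.post(f"{base}/api/navigation/stop", verify=False)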
+
+@app.route('/api/navigation/analyze_scene', methods=['POST'])
+def api_navigation_analyze_scene():
+ """Analyze current scene for navigation context."""
+ global cc
+
+ if cc.current_raw_frame is None:
+ return jsonify({"success": False, "error": "No frame available"})
+
+ try:
+ _, buffer = cv2.imencode('.jpg', cc.current_raw_frame, [cv2.IMWRITE_JPEG_QUALITY, 70])
+ image_data = base64.b64encode(buffer).decode('utf-8')
+ context = analyze_scene_context(image_data)
+
+ if context:
+ cc.navigation_context = context
+ return jsonify({"success": True, "context": context})
+ else:
+ return jsonify({"success": False, "error": "Analysis failed"})
+
+ except Exception as e:
+ return jsonify({"success": False, "error": str(e)})
+
+
+@app.route('/api/location_memory')
+def api_location_memory():
+ """Get stored location memory (from SQLite)."""
+ memories = cc.get_all_location_memories()
+ return jsonify({
+ "success": True,
+ "memory": memories
+ })
+
+
+@app.route('/api/location_memory/recall', methods=['POST'])
+def api_recall_location():
+ """Recall where an object was last found (from SQLite)."""
+ data = request.json
+ label = data.get("label", "")
+
+ memory = cc.recall_location(label)
+
+ if memory:
+ return jsonify({
+ "success": True,
+ "found": True,
+ "label": label,
+ "location": memory.get("context"),
+ "frequency": memory.get("frequency", 1),
+ "last_seen": memory.get("last_seen")
+ })
+ else:
+ return jsonify({
+ "success": True,
+ "found": False,
+ "label": label,
+ "message": f"No memory of where {label} was found"
+ })
+
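+# Illustrative recall of a remembered object location (a sketch; "keys" is an example label):
+#
+#   import requests
+#   info = requests.post("https://localhost:5000/api/location_memory/recall",
+#                        json={"label": "keys"}, verify=False).json()
+#   if info.get("found"):
+#       print(f"Last seen in the {info['location']} ({info['frequency']} time(s))")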
+
+@app.route('/api/location_memory/clear', methods=['POST'])
+def api_clear_location_memory():
+ """Clear location memory."""
+ data = request.json or {}
+ label = data.get("label")
+
+ cc.clear_location_memory(label)
+
+ return jsonify({
+ "success": True,
+ "message": f"Cleared location memory" + (f" for {label}" if label else "")
+ })
+
+
+# ===== OBSTACLE DETECTION API =====
+
+@app.route('/api/navigation/obstacles')
+def api_navigation_obstacles():
+ """Get current obstacles detected during navigation."""
+ return jsonify({
+ "success": True,
+ "obstacles": cc.current_obstacles,
+ "active": cc.obstacle_detection_active
+ })
+
+
+# ===== DATABASE HISTORY API =====
+
+@app.route('/api/history/detections')
+def api_history_detections():
+ """Get detection history from database."""
+ label = request.args.get('label')
+ limit = int(request.args.get('limit', 100))
+
+ history = db.get_detection_history(session_id=cc.session_id, label=label, limit=limit)
+
+ return jsonify({
+ "success": True,
+ "detections": history,
+ "count": len(history)
+ })
+
+
+@app.route('/api/history/analysis')
+def api_history_analysis():
+ """Get analysis history from database."""
+ limit = int(request.args.get('limit', 50))
+
+ history = db.get_analysis_history(session_id=cc.session_id, limit=limit)
+
+ return jsonify({
+ "success": True,
+ "analyses": history,
+ "count": len(history)
+ })
+
+
+@app.route('/api/history/navigation')
+def api_history_navigation():
+ """Get navigation history from database."""
+ limit = int(request.args.get('limit', 20))
+
+ history = db.get_navigation_history(session_id=cc.session_id, limit=limit)
+
+ return jsonify({
+ "success": True,
+ "navigations": history,
+ "count": len(history)
+ })
+
+
+@app.route('/api/session/stats')
+def api_session_stats():
+ """Get statistics for the current session."""
+ if not cc.session_id:
+ return jsonify({"success": False, "error": "No active session"})
+
+ stats = db.get_session_stats(cc.session_id)
+
+ return jsonify({
+ "success": True,
+ "session_id": cc.session_id,
+ "stats": stats
+ })
+
+
+def generate_self_signed_cert(cert_dir: str = None) -> Tuple[str, str]:
+ """Generate a self-signed SSL certificate for HTTPS."""
+ try:
+ from cryptography import x509
+ from cryptography.x509.oid import NameOID
+ from cryptography.hazmat.primitives import hashes
+ from cryptography.hazmat.backends import default_backend
+ from cryptography.hazmat.primitives.asymmetric import rsa
+ from cryptography.hazmat.primitives import serialization
+ import datetime
+
+ if cert_dir is None:
+ cert_dir = os.path.join(os.path.dirname(__file__), '.ssl')
+
+ os.makedirs(cert_dir, exist_ok=True)
+
+ key_path = os.path.join(cert_dir, 'key.pem')
+ cert_path = os.path.join(cert_dir, 'cert.pem')
+
+ # Check if certs already exist
+ if os.path.exists(key_path) and os.path.exists(cert_path):
+ print(f"Using existing SSL certificates from {cert_dir}")
+ return cert_path, key_path
+
+ print("Generating self-signed SSL certificate...")
+
+ # Generate private key
+ key = rsa.generate_private_key(
+ public_exponent=65537,
+ key_size=2048,
+ backend=default_backend()
+ )
+
+ # Generate certificate
+ subject = issuer = x509.Name([
+ x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
+ x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"),
+ x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"),
+ x509.NameAttribute(NameOID.ORGANIZATION_NAME, "SAM3 Command Center"),
+ x509.NameAttribute(NameOID.COMMON_NAME, "localhost"),
+ ])
+
+ cert = x509.CertificateBuilder().subject_name(
+ subject
+ ).issuer_name(
+ issuer
+ ).public_key(
+ key.public_key()
+ ).serial_number(
+ x509.random_serial_number()
+ ).not_valid_before(
+ datetime.datetime.utcnow()
+ ).not_valid_after(
+ datetime.datetime.utcnow() + datetime.timedelta(days=365)
+ ).add_extension(
+ x509.SubjectAlternativeName([
+ x509.DNSName("localhost"),
+ x509.DNSName("127.0.0.1"),
+ x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")),
+ ]),
+ critical=False,
+ ).sign(key, hashes.SHA256(), default_backend())
+
+ # Write key
+ with open(key_path, "wb") as f:
+ f.write(key.private_bytes(
+ encoding=serialization.Encoding.PEM,
+ format=serialization.PrivateFormat.TraditionalOpenSSL,
+ encryption_algorithm=serialization.NoEncryption()
+ ))
+
+ # Write certificate
+ with open(cert_path, "wb") as f:
+ f.write(cert.public_bytes(serialization.Encoding.PEM))
+
+ print(f"SSL certificate generated: {cert_path}")
+ return cert_path, key_path
+
+ except ImportError:
+ print("WARNING: cryptography package not installed. Cannot generate SSL certificate.")
+ print(" Install with: pip install cryptography")
+ print(" Or provide --ssl-cert and --ssl-key arguments")
+ return None, None
+
+
+def main():
+ global cc
+
+ parser = argparse.ArgumentParser(description="SAM3 Web Command Center")
+ parser.add_argument("--camera", "-c", type=int, default=0, help="Camera device ID")
+ parser.add_argument("--device", "-d", type=str, default=None, help="Device (cuda, mps, cpu)")
+ parser.add_argument("--prompt", type=str, default="object", help="Initial prompts (comma-separated)")
+ parser.add_argument("--threshold", type=float, default=0.3, help="Confidence threshold")
+ parser.add_argument("--checkpoint", type=str, default=None, help="Model checkpoint path")
+ parser.add_argument("--port", type=int, default=5000, help="Web server port")
+ parser.add_argument("--skip-frames", type=int, default=3, help="Process every N frames")
+ parser.add_argument("--no-tracking", action="store_true", help="Disable optical flow tracking")
+ parser.add_argument("--no-yolo", action="store_true", help="Disable YOLO models")
+ parser.add_argument("--api-key", type=str, default=None, help="Anthropic API key (or set ANTHROPIC_API_KEY env var)")
+ parser.add_argument("--no-https", action="store_true", help="Disable HTTPS (not recommended - microphone won't work)")
+ parser.add_argument("--ssl-cert", type=str, default=None, help="Path to SSL certificate file")
+ parser.add_argument("--ssl-key", type=str, default=None, help="Path to SSL private key file")
+
+ args = parser.parse_args()
+
+ # Set API key from argument if provided
+ global ANTHROPIC_API_KEY
+ if args.api_key:
+ ANTHROPIC_API_KEY = args.api_key
+ print("Using API key from command line argument")
+ elif ANTHROPIC_API_KEY:
+ print("Using API key from environment variable")
+ else:
+ print("WARNING: No Anthropic API key set. Claude features (analysis, voice search) will not work.")
+ print(" Set via: --api-key YOUR_KEY or ANTHROPIC_API_KEY=YOUR_KEY")
+
+ # Configure command center
+ cc.prompts = [p.strip() for p in args.prompt.split(",") if p.strip()]
+ cc.confidence_threshold = args.threshold
+ cc.skip_frames = args.skip_frames
+ cc.enable_tracking = not args.no_tracking
+
+ if args.device:
+ cc.device_str = args.device
+
+ # Create database session
+ cc.session_id = db.create_session(
+ device=args.device or "auto",
+ prompts=cc.prompts,
+ settings={
+ "threshold": args.threshold,
+ "skip_frames": args.skip_frames,
+ "tracking": not args.no_tracking,
+ "yolo": not args.no_yolo
+ }
+ )
+ cc.log(f"Database session started: {cc.session_id[:8]}...")
+
+ # Load model
+ load_model(args.checkpoint)
+
+ # Load depth estimation model for LIDAR-like obstacle detection
+ cc.log("Loading depth estimation model for advanced obstacle detection...")
+ load_depth_model()
+ if depth_model is not None:
+ cc.log("Depth estimation model loaded successfully", "SUCCESS")
+ else:
+ cc.log("Depth estimation unavailable - using other detection layers", "WARNING")
+
+ # Skip YOLO if requested
+ if args.no_yolo:
+ cc.yolo_available = False
+ cc.log("YOLO disabled via command line")
+
+ # Detect available cameras
+ cc.log("Detecting available cameras...")
+ cc.available_cameras = detect_available_cameras()
+ cc.log(f"Found {len(cc.available_cameras)} camera(s)", "SUCCESS")
+ for cam in cc.available_cameras:
+ cc.log(f" Camera {cam['id']}: {cam['description']}")
+
+ # Open camera
+ cc.log(f"Opening camera {args.camera}...")
+ cc.camera = cv2.VideoCapture(args.camera)
+ cc.current_camera_id = args.camera
+
+ if not cc.camera.isOpened():
+ cc.log(f"Failed to open camera {args.camera}", "ERROR")
+ return
+
+ width = int(cc.camera.get(cv2.CAP_PROP_FRAME_WIDTH))
+ height = int(cc.camera.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ cc.log(f"Camera opened: {width}x{height}", "SUCCESS")
+
+ cc.running = True
+
+ # Start analysis worker
+ analysis_thread = threading.Thread(target=analysis_worker, daemon=True)
+ analysis_thread.start()
+
+ print(f"\n{'='*50}")
+ print(f"SAM3 Web Command Center")
+ print(f"{'='*50}")
+
+ # Setup SSL (HTTPS is default, use --no-https to disable)
+ ssl_context = None
+ protocol = "http"
+
+ if not args.no_https:
+ if args.ssl_cert and args.ssl_key:
+ # Use provided certificates
+ if os.path.exists(args.ssl_cert) and os.path.exists(args.ssl_key):
+ ssl_context = (args.ssl_cert, args.ssl_key)
+ protocol = "https"
+ print(f"Using provided SSL certificates")
+ else:
+ print(f"ERROR: SSL certificate files not found")
+ print(f" Cert: {args.ssl_cert}")
+ print(f" Key: {args.ssl_key}")
+ return
+ else:
+ # Generate self-signed certificate
+ cert_path, key_path = generate_self_signed_cert()
+ if cert_path and key_path:
+ ssl_context = (cert_path, key_path)
+ protocol = "https"
+ print(f"Using auto-generated self-signed certificate")
+ print(f" NOTE: You may need to accept the security warning in your browser")
+ else:
+ print("WARNING: Could not setup HTTPS. Falling back to HTTP.")
+ print(" Microphone and navigation features may not work without HTTPS!")
+ else:
+ print("WARNING: HTTPS disabled. Microphone and navigation features may not work!")
+
+ print(f"Open {protocol}://localhost:{args.port} in your browser")
+ print(f"YOLO: {'Available' if cc.yolo_available else 'Not available'}")
+ print(f"CLIP: {'Available' if cc.clip_available else 'Not available'}")
+    if protocol == "https":
+        print("HTTPS: Enabled (microphone and navigation available)")
+    else:
+        print("HTTPS: Disabled (remove --no-https to enable full features)")
+ print(f"{'='*50}\n")
+
+ try:
+ if ssl_context:
+ app.run(host='0.0.0.0', port=args.port, threaded=True, debug=False, ssl_context=ssl_context)
+ else:
+ app.run(host='0.0.0.0', port=args.port, threaded=True, debug=False)
+ finally:
+ cc.running = False
+ if cc.camera:
+ cc.camera.release()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/web_command_center/templates/index.html b/examples/web_command_center/templates/index.html
new file mode 100644
index 00000000..cf931063
--- /dev/null
+++ b/examples/web_command_center/templates/index.html
@@ -0,0 +1,4006 @@
[Template body omitted: only text fragments of the 4,006-line index.html survived extraction.
 The page defines the SAM3 Command Center single-page UI: a live video feed with a
 "Real View Navigation" AR overlay (direction arrow, distance estimate, target status,
 scene context, and a path marker); and tabbed panels for Controls, Camera (camera
 selection, flip/mirror, camera tips), Voice Search (voice-to-AI search, TTS feedback,
 voice history), Reference Search (reference image upload, draw-to-search box/point tools,
 visual match threshold slider), and Features (optical flow and memory tracking, persistent
 object IDs, mask refinement: fill holes / smooth edges / non-overlapping masks, detection
 controls: boundary suppression / occlusion suppression / hotstart mode, and YOLOv12
 options: classification, pose estimation, skeleton and keypoint labels, label spoofing).]
diff --git a/sam3/model/decoder.py b/sam3/model/decoder.py
index c8b1657e..3b0ded8b 100644
--- a/sam3/model/decoder.py
+++ b/sam3/model/decoder.py
@@ -11,6 +11,7 @@
import torch
from sam3.sam.transformer import RoPEAttention
+from sam3.utils.device import get_device
from torch import nn, Tensor
from torchvision.ops.roi_align import RoIAlign
@@ -278,7 +279,7 @@ def __init__(
if resolution is not None and stride is not None:
feat_size = resolution // stride
coords_h, coords_w = self._get_coords(
- feat_size, feat_size, device="cuda"
+ feat_size, feat_size, device=get_device()
)
self.compilable_cord_cache = (coords_h, coords_w)
self.compilable_stored_size = (feat_size, feat_size)
diff --git a/sam3/model/edt.py b/sam3/model/edt.py
index 9448c1d3..65b0d4cf 100644
--- a/sam3/model/edt.py
+++ b/sam3/model/edt.py
@@ -1,10 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
-"""Triton kernel for euclidean distance transform (EDT)"""
+"""Euclidean distance transform (EDT) with optional Triton kernel acceleration for CUDA devices."""
import torch
-import triton
-import triton.language as tl
+
+# Try to import Triton (only available on CUDA)
+try:
+ import triton
+ import triton.language as tl
+
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
"""
Disclaimer: This implementation is not meant to be extremely efficient. A CUDA kernel would likely be more efficient.
@@ -50,74 +57,193 @@
"""
-@triton.jit
-def edt_kernel(inputs_ptr, outputs_ptr, v, z, height, width, horizontal: tl.constexpr):
- # This is a somewhat verbatim implementation of the efficient 1D EDT algorithm described above
- # It can be applied horizontally or vertically depending if we're doing the first or second stage.
- # It's parallelized across batch+row (or batch+col if horizontal=False)
- # TODO: perhaps the implementation can be revisited if/when local gather/scatter become available in triton
- batch_id = tl.program_id(axis=0)
- if horizontal:
- row_id = tl.program_id(axis=1)
- block_start = (batch_id * height * width) + row_id * width
- length = width
- stride = 1
- else:
- col_id = tl.program_id(axis=1)
- block_start = (batch_id * height * width) + col_id
- length = height
- stride = width
-
- # This will be the index of the right most parabola in the envelope ("the top of the stack")
- k = 0
- for q in range(1, length):
- # Read the function value at the current location. Note that we're doing a singular read, not very efficient
- cur_input = tl.load(inputs_ptr + block_start + (q * stride))
- # location of the parabola on top of the stack
- r = tl.load(v + block_start + (k * stride))
- # associated boundary
- z_k = tl.load(z + block_start + (k * stride))
- # value of the function at the parabola location
- previous_input = tl.load(inputs_ptr + block_start + (r * stride))
- # intersection between the two parabolas
- s = (cur_input - previous_input + q * q - r * r) / (q - r) / 2
-
- # we'll pop as many parabolas as required
- while s <= z_k and k - 1 >= 0:
- k = k - 1
+# ============================================================================
+# PyTorch-based implementations (for CPU, MPS, and fallback)
+# ============================================================================
+
+
+def edt_pytorch(data: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the Euclidean Distance Transform (EDT) of a batch of binary images using scipy.
+
+ This is a fallback implementation for non-CUDA devices. It processes each image
+ in the batch individually using scipy's distance_transform_edt.
+
+ Args:
+ data: A tensor of shape (B, H, W) representing a batch of binary images.
+
+ Returns:
+ A tensor of the same shape as data containing the EDT.
+ It should be equivalent to a batched version of cv2.distanceTransform(input, cv2.DIST_L2, 0)
+ """
+ from scipy.ndimage import distance_transform_edt
+
+ assert data.dim() == 3, "Input tensor must have shape (B, H, W)"
+
+    device = data.device
+    B, H, W = data.shape
+
+    # Convert to numpy for scipy processing
+    data_np = data.detach().cpu().numpy()
+
+    # Allocate a float32 output: the result holds real-valued distances, matching the
+    # Triton/CUDA path and cv2.distanceTransform (so we must NOT cast back to the input
+    # dtype, which is typically bool for binary masks)
+    output_np = data_np.astype("float32", copy=True)
+
+    # Process each image in the batch
+    for b in range(B):
+        # scipy's distance_transform_edt computes, for each non-zero pixel, the distance
+        # to the nearest zero pixel (and 0 for zero pixels), i.e. EDT == 0 wherever
+        # data == 0, as required
+        mask = data_np[b] != 0
+        output_np[b] = distance_transform_edt(mask)
+
+    # Convert back to a tensor on the original device, keeping float32 distances
+    output = torch.from_numpy(output_np).to(device=device)
+    return output
+
+
+# ============================================================================
+# Triton-based implementations (CUDA only)
+# ============================================================================
+
+if HAS_TRITON:
+
+ @triton.jit
+ def edt_kernel(
+ inputs_ptr, outputs_ptr, v, z, height, width, horizontal: tl.constexpr
+ ):
+ # This is a somewhat verbatim implementation of the efficient 1D EDT algorithm described above
+ # It can be applied horizontally or vertically depending if we're doing the first or second stage.
+ # It's parallelized across batch+row (or batch+col if horizontal=False)
+ # TODO: perhaps the implementation can be revisited if/when local gather/scatter become available in triton
+ batch_id = tl.program_id(axis=0)
+ if horizontal:
+ row_id = tl.program_id(axis=1)
+ block_start = (batch_id * height * width) + row_id * width
+ length = width
+ stride = 1
+ else:
+ col_id = tl.program_id(axis=1)
+ block_start = (batch_id * height * width) + col_id
+ length = height
+ stride = width
+
+ # This will be the index of the right most parabola in the envelope ("the top of the stack")
+ k = 0
+ for q in range(1, length):
+ # Read the function value at the current location. Note that we're doing a singular read, not very efficient
+ cur_input = tl.load(inputs_ptr + block_start + (q * stride))
+ # location of the parabola on top of the stack
r = tl.load(v + block_start + (k * stride))
+ # associated boundary
z_k = tl.load(z + block_start + (k * stride))
+ # value of the function at the parabola location
previous_input = tl.load(inputs_ptr + block_start + (r * stride))
+ # intersection between the two parabolas
s = (cur_input - previous_input + q * q - r * r) / (q - r) / 2
- # Store the new one
- k = k + 1
- tl.store(v + block_start + (k * stride), q)
- tl.store(z + block_start + (k * stride), s)
- if k + 1 < length:
- tl.store(z + block_start + ((k + 1) * stride), 1e9)
-
- # Last step, we read the envelope to find the min in every location
- k = 0
- for q in range(length):
- while (
- k + 1 < length
- and tl.load(
- z + block_start + ((k + 1) * stride), mask=(k + 1) < length, other=q
- )
- < q
- ):
- k += 1
- r = tl.load(v + block_start + (k * stride))
- d = q - r
- old_value = tl.load(inputs_ptr + block_start + (r * stride))
- tl.store(outputs_ptr + block_start + (q * stride), old_value + d * d)
-
-
-def edt_triton(data: torch.Tensor):
+ # we'll pop as many parabolas as required
+ while s <= z_k and k - 1 >= 0:
+ k = k - 1
+ r = tl.load(v + block_start + (k * stride))
+ z_k = tl.load(z + block_start + (k * stride))
+ previous_input = tl.load(inputs_ptr + block_start + (r * stride))
+ s = (cur_input - previous_input + q * q - r * r) / (q - r) / 2
+
+ # Store the new one
+ k = k + 1
+ tl.store(v + block_start + (k * stride), q)
+ tl.store(z + block_start + (k * stride), s)
+ if k + 1 < length:
+ tl.store(z + block_start + ((k + 1) * stride), 1e9)
+
+ # Last step, we read the envelope to find the min in every location
+ k = 0
+ for q in range(length):
+ while (
+ k + 1 < length
+ and tl.load(
+ z + block_start + ((k + 1) * stride), mask=(k + 1) < length, other=q
+ )
+ < q
+ ):
+ k += 1
+ r = tl.load(v + block_start + (k * stride))
+ d = q - r
+ old_value = tl.load(inputs_ptr + block_start + (r * stride))
+ tl.store(outputs_ptr + block_start + (q * stride), old_value + d * d)
+
+ def edt_triton_impl(data: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the Euclidean Distance Transform (EDT) of a batch of binary images using Triton.
+
+ Args:
+ data: A tensor of shape (B, H, W) representing a batch of binary images.
+
+ Returns:
+ A tensor of the same shape as data containing the EDT.
+ It should be equivalent to a batched version of cv2.distanceTransform(input, cv2.DIST_L2, 0)
+ """
+ assert data.dim() == 3
+ assert data.is_cuda
+ B, H, W = data.shape
+ data = data.contiguous()
+
+ # Allocate the "function" tensor. Implicitly the function is 0 if data[i,j]==0 else +infinity
+ output = torch.where(data, 1e18, 0.0)
+ assert output.is_contiguous()
+
+ # Scratch tensors for the parabola stacks
+ parabola_loc = torch.zeros(B, H, W, dtype=torch.uint32, device=data.device)
+ parabola_inter = torch.empty(B, H, W, dtype=torch.float, device=data.device)
+ parabola_inter[:, :, 0] = -1e18
+ parabola_inter[:, :, 1] = 1e18
+
+ # Grid size (number of blocks)
+ grid = (B, H)
+
+ # Launch initialization kernel
+ edt_kernel[grid](
+ output.clone(),
+ output,
+ parabola_loc,
+ parabola_inter,
+ H,
+ W,
+ horizontal=True,
+ )
+
+ # reset the parabola stacks
+ parabola_loc.zero_()
+ parabola_inter[:, :, 0] = -1e18
+ parabola_inter[:, :, 1] = 1e18
+
+ grid = (B, W)
+ edt_kernel[grid](
+ output.clone(),
+ output,
+ parabola_loc,
+ parabola_inter,
+ H,
+ W,
+ horizontal=False,
+ )
+ # don't forget to take sqrt at the end
+ return output.sqrt()
+
+
+# ============================================================================
+# Public API - automatically selects best implementation
+# ============================================================================
+
+
+def edt(data: torch.Tensor) -> torch.Tensor:
"""
Computes the Euclidean Distance Transform (EDT) of a batch of binary images.
+    Uses the Triton kernel on CUDA when available and falls back to a scipy-based implementation otherwise.
+
Args:
data: A tensor of shape (B, H, W) representing a batch of binary images.
@@ -125,49 +251,11 @@ def edt_triton(data: torch.Tensor):
A tensor of the same shape as data containing the EDT.
It should be equivalent to a batched version of cv2.distanceTransform(input, cv2.DIST_L2, 0)
"""
- assert data.dim() == 3
- assert data.is_cuda
- B, H, W = data.shape
- data = data.contiguous()
-
- # Allocate the "function" tensor. Implicitly the function is 0 if data[i,j]==0 else +infinity
- output = torch.where(data, 1e18, 0.0)
- assert output.is_contiguous()
-
- # Scratch tensors for the parabola stacks
- parabola_loc = torch.zeros(B, H, W, dtype=torch.uint32, device=data.device)
- parabola_inter = torch.empty(B, H, W, dtype=torch.float, device=data.device)
- parabola_inter[:, :, 0] = -1e18
- parabola_inter[:, :, 1] = 1e18
-
- # Grid size (number of blocks)
- grid = (B, H)
-
- # Launch initialization kernel
- edt_kernel[grid](
- output.clone(),
- output,
- parabola_loc,
- parabola_inter,
- H,
- W,
- horizontal=True,
- )
-
- # reset the parabola stacks
- parabola_loc.zero_()
- parabola_inter[:, :, 0] = -1e18
- parabola_inter[:, :, 1] = 1e18
-
- grid = (B, W)
- edt_kernel[grid](
- output.clone(),
- output,
- parabola_loc,
- parabola_inter,
- H,
- W,
- horizontal=False,
- )
- # don't forget to take sqrt at the end
- return output.sqrt()
+ if HAS_TRITON and data.is_cuda:
+ return edt_triton_impl(data)
+ else:
+ return edt_pytorch(data)
+
+
+# Legacy alias for backward compatibility
+edt_triton = edt
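+# Illustrative usage of the dispatching entry point (a sketch; the shapes and random
+# masks are example inputs, not taken from the SAM3 pipeline):
+#
+#   import torch
+#   from sam3.model.edt import edt
+#
+#   masks = torch.rand(2, 64, 64) > 0.5      # (B, H, W) boolean masks
+#   dist = edt(masks)                        # Triton kernel on CUDA, scipy fallback otherwise
+#   dist_mps = edt(masks.to("mps"))          # on Apple Silicon: scipy fallback, result on MPS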
diff --git a/sam3/model/geometry_encoders.py b/sam3/model/geometry_encoders.py
index bff29172..1a2aa349 100644
--- a/sam3/model/geometry_encoders.py
+++ b/sam3/model/geometry_encoders.py
@@ -4,10 +4,28 @@
import torch
import torch.nn as nn
+import torch.nn.functional as F
import torchvision
from typing_extensions import override
from .act_ckpt_utils import activation_ckpt_wrapper
+
+
+def _grid_sample_mps_safe(input, grid, **kwargs):
+ """
+ MPS-safe wrapper for grid_sample.
+ MPS has bugs with grid_sample on certain tensor configurations,
+ so we fall back to CPU for MPS devices.
+ """
+ if input.device.type == "mps":
+ # Move to CPU, perform operation, move back
+ input_cpu = input.cpu()
+ grid_cpu = grid.cpu()
+ result = F.grid_sample(input_cpu, grid_cpu, **kwargs)
+ return result.to(input.device)
+ return F.grid_sample(input, grid, **kwargs)
+
+
from .box_ops import box_cxcywh_to_xyxy
from .model_misc import get_clones
@@ -44,8 +62,13 @@ def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False
assert seq1_length == mask1.size(1)
assert seq2_length == mask2.size(1)
- torch._assert_async(is_right_padded(mask1))
- torch._assert_async(is_right_padded(mask2))
+ # _assert_async is not supported on MPS, use regular assert
+ if mask1.device.type == "mps" or mask2.device.type == "mps":
+ assert is_right_padded(mask1), "mask1 must be right padded"
+ assert is_right_padded(mask2), "mask2 must be right padded"
+ else:
+ torch._assert_async(is_right_padded(mask1))
+ torch._assert_async(is_right_padded(mask2))
actual_seq1_lengths = (~mask1).sum(dim=-1)
actual_seq2_lengths = (~mask2).sum(dim=-1)
@@ -613,7 +636,7 @@ def _encode_points(self, points, points_mask, points_labels, img_feats):
grid = points.transpose(0, 1).unsqueeze(2)
# re normalize to [-1, 1]
grid = (grid * 2) - 1
- sampled = torch.nn.functional.grid_sample(
+ sampled = _grid_sample_mps_safe(
img_feats, grid, align_corners=False
)
assert list(sampled.shape) == [bs, self.d_model, n_points, 1]
@@ -656,11 +679,16 @@ def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats):
# We need to denormalize, and convert to [x, y, x, y]
boxes_xyxy = box_cxcywh_to_xyxy(boxes)
scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype)
- scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True)
+ # pin_memory() only works with CUDA, not MPS
+ if boxes_xyxy.device.type == "cuda":
+ scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True)
+ else:
+ scale = scale.to(device=boxes_xyxy.device)
scale = scale.view(1, 1, 4)
boxes_xyxy = boxes_xyxy * scale
+ # Match boxes dtype to img_feats dtype for roi_align (needed for half precision)
sampled = torchvision.ops.roi_align(
- img_feats, boxes_xyxy.float().transpose(0, 1).unbind(0), self.roi_size
+ img_feats, boxes_xyxy.to(img_feats.dtype).transpose(0, 1).unbind(0), self.roi_size
)
assert list(sampled.shape) == [
bs * n_boxes,
diff --git a/sam3/model/io_utils.py b/sam3/model/io_utils.py
index 0a225842..1691911b 100644
--- a/sam3/model/io_utils.py
+++ b/sam3/model/io_utils.py
@@ -15,6 +15,7 @@
from PIL import Image
from sam3.logger import get_logger
+from sam3.utils.device import get_device
from tqdm import tqdm
logger = get_logger(__name__)
@@ -63,7 +64,7 @@ def load_resource_as_video_frames(
images.append(img)
images = torch.stack(images)
if not offload_video_to_cpu:
- images = images.cuda()
+ images = images.to(get_device())
return images, orig_height, orig_width
is_image = (
@@ -104,9 +105,10 @@ def load_image_as_single_frame_video(
img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None]
img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None]
if not offload_video_to_cpu:
- images = images.cuda()
- img_mean = img_mean.cuda()
- img_std = img_std.cuda()
+ device = get_device()
+ images = images.to(device)
+ img_mean = img_mean.to(device)
+ img_std = img_std.to(device)
# normalize by mean and std
images -= img_mean
images /= img_std
@@ -201,9 +203,10 @@ def load_video_frames_from_image_folder(
):
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
if not offload_video_to_cpu:
- images = images.cuda()
- img_mean = img_mean.cuda()
- img_std = img_std.cuda()
+ device = get_device()
+ images = images.to(device)
+ img_mean = img_mean.to(device)
+ img_std = img_std.to(device)
# normalize by mean and std
images -= img_mean
images /= img_std
@@ -307,9 +310,10 @@ def load_video_frames_from_video_file_using_cv2(
img_mean = torch.tensor(img_mean, dtype=torch.float16).view(1, 3, 1, 1)
img_std = torch.tensor(img_std, dtype=torch.float16).view(1, 3, 1, 1)
if not offload_video_to_cpu:
- video_tensor = video_tensor.cuda()
- img_mean = img_mean.cuda()
- img_std = img_std.cuda()
+ device = get_device()
+ video_tensor = video_tensor.to(device)
+ img_mean = img_mean.to(device)
+ img_std = img_std.to(device)
# normalize by mean and std
video_tensor -= img_mean
video_tensor /= img_std
@@ -323,7 +327,7 @@ def load_dummy_video(image_size, offload_video_to_cpu, num_frames=60):
video_height, video_width = 480, 640 # dummy original video sizes
images = torch.randn(num_frames, 3, image_size, image_size, dtype=torch.float16)
if not offload_video_to_cpu:
- images = images.cuda()
+ images = images.to(get_device())
return images, video_height, video_width
@@ -392,7 +396,7 @@ def __getitem__(self, index):
img -= self.img_mean
img /= self.img_std
if not self.offload_video_to_cpu:
- img = img.cuda()
+ img = img.to(get_device())
self.images[index] = img
return img
@@ -503,16 +507,33 @@ def __init__(
use_rand_seek_in_loading=False,
):
# Check and possibly infer the output device (and also get its GPU id when applicable)
- assert gpu_device is None or gpu_device.type == "cuda"
- gpu_id = (
- gpu_device.index
- if gpu_device is not None and gpu_device.index is not None
- else torch.cuda.current_device()
- )
+ # For MPS devices, we disable GPU acceleration since TorchCodec doesn't support it
+ default_device = get_device()
+ is_mps = default_device.type == "mps"
+
+ if gpu_device is not None:
+ assert gpu_device.type in ("cuda", "mps", "cpu"), f"Unsupported device type: {gpu_device.type}"
+
+ # Disable GPU acceleration for non-CUDA devices
+ if is_mps or (gpu_device is not None and gpu_device.type != "cuda"):
+ gpu_acceleration = False
+
+ gpu_id = 0
+ if torch.cuda.is_available():
+ gpu_id = (
+ gpu_device.index
+ if gpu_device is not None and gpu_device.type == "cuda" and gpu_device.index is not None
+ else torch.cuda.current_device()
+ )
+
if offload_video_to_cpu:
out_device = torch.device("cpu")
else:
- out_device = torch.device("cuda") if gpu_device is None else gpu_device
+ if gpu_device is not None:
+ out_device = gpu_device
+ else:
+ out_device = default_device
+
self.out_device = out_device
self.gpu_acceleration = gpu_acceleration
self.gpu_id = gpu_id
@@ -525,7 +546,7 @@ def __init__(
img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None]
self.img_std = img_std
- if gpu_acceleration:
+ if gpu_acceleration and torch.cuda.is_available():
self.img_mean = self.img_mean.to(f"cuda:{self.gpu_id}")
self.img_std = self.img_std.to(f"cuda:{self.gpu_id}")
decoder_option = {"device": f"cuda:{self.gpu_id}"}
diff --git a/sam3/model/position_encoding.py b/sam3/model/position_encoding.py
index eb3f4055..2efbb5d1 100644
--- a/sam3/model/position_encoding.py
+++ b/sam3/model/position_encoding.py
@@ -6,6 +6,8 @@
import torch
from torch import nn
+from sam3.utils.device import get_device
+
class PositionEmbeddingSine(nn.Module):
"""
@@ -44,7 +46,7 @@ def __init__(
(precompute_resolution // 32, precompute_resolution // 32),
]
for size in precompute_sizes:
- tensors = torch.zeros((1, 1) + size, device="cuda")
+ tensors = torch.zeros((1, 1) + size, device=get_device())
self.forward(tensors)
# further clone and detach it in the cache (just to be safe)
self.cache[size] = self.cache[size].clone().detach()
diff --git a/sam3/model/sam3_image.py b/sam3/model/sam3_image.py
index aafe520b..db961e2a 100644
--- a/sam3/model/sam3_image.py
+++ b/sam3/model/sam3_image.py
@@ -122,7 +122,11 @@ def _get_img_feats(self, backbone_out, img_ids):
# If this assert fails, it likely means we're requesting different img_ids (perhaps a different frame?)
# We currently don't expect this to happen. We could technically trigger a recompute here,
# but likely at the cost of a cpu<->gpu sync point, which would deteriorate perf
- torch._assert_async((img_ids >= 0).all())
+ # _assert_async is not supported on MPS
+ if img_ids.device.type == "mps":
+ assert (img_ids >= 0).all(), "img_ids must be non-negative"
+ else:
+ torch._assert_async((img_ids >= 0).all())
vis_feats = backbone_out["backbone_fpn"][-self.num_feature_levels :]
vis_pos_enc = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
diff --git a/sam3/model/sam3_image_processor.py b/sam3/model/sam3_image_processor.py
index 4d98fbfb..82f410c0 100644
--- a/sam3/model/sam3_image_processor.py
+++ b/sam3/model/sam3_image_processor.py
@@ -8,13 +8,16 @@
from sam3.model import box_ops
from sam3.model.data_misc import FindStage, interpolate
+from sam3.utils.device import get_device_str
from torchvision.transforms import v2
class Sam3Processor:
""" """
- def __init__(self, model, resolution=1008, device="cuda", confidence_threshold=0.5):
+ def __init__(self, model, resolution=1008, device=None, confidence_threshold=0.5):
+ if device is None:
+ device = get_device_str()
self.model = model
self.resolution = resolution
self.device = device
@@ -54,6 +57,11 @@ def set_image(self, image, state=None):
image = v2.functional.to_image(image).to(self.device)
image = self.transform(image).unsqueeze(0)
+ # Match model dtype (for half precision support)
+ model_dtype = next(self.model.parameters()).dtype
+ if image.dtype != model_dtype:
+ image = image.to(model_dtype)
+
state["original_height"] = height
state["original_width"] = width
state["backbone_out"] = self.model.backbone.forward_image(image)
@@ -93,6 +101,12 @@ def set_image_batch(self, images: List[np.ndarray], state=None):
for image in images
]
images = torch.stack(images, dim=0)
+
+ # Match model dtype (for half precision support)
+ model_dtype = next(self.model.parameters()).dtype
+ if images.dtype != model_dtype:
+ images = images.to(model_dtype)
+
state["backbone_out"] = self.model.backbone.forward_image(images)
inst_interactivity_en = self.model.inst_interactive_predictor is not None
if inst_interactivity_en and "sam2_backbone_out" in state["backbone_out"]:
diff --git a/sam3/model/sam3_tracker_base.py b/sam3/model/sam3_tracker_base.py
index 90fbd696..94952834 100644
--- a/sam3/model/sam3_tracker_base.py
+++ b/sam3/model/sam3_tracker_base.py
@@ -164,10 +164,12 @@ def _get_tpos_enc(self, rel_pos_list, device, max_abs_pos=None, dummy=False):
return torch.zeros(len(rel_pos_list), self.mem_dim, device=device)
t_diff_max = max_abs_pos - 1 if max_abs_pos is not None else 1
- pos_enc = (
- torch.tensor(rel_pos_list).pin_memory().to(device=device, non_blocking=True)
- / t_diff_max
- )
+    # pin_memory() only works with CUDA, not MPS
+    rel_pos_tensor = torch.tensor(rel_pos_list)
+    is_cuda = device.type == "cuda" if isinstance(device, torch.device) else device == "cuda"
+    if is_cuda:
+        pos_enc = rel_pos_tensor.pin_memory().to(device=device, non_blocking=True) / t_diff_max
+    else:
+        pos_enc = rel_pos_tensor.to(device=device) / t_diff_max
tpos_dim = self.hidden_dim
pos_enc = get_1d_sine_pe(pos_enc, dim=tpos_dim)
pos_enc = self.obj_ptr_tpos_proj(pos_enc)
@@ -653,15 +655,15 @@ def _prepare_memory_conditioned_features(
if prev is None:
continue # skip padding frames
# "maskmem_features" might have been offloaded to CPU in demo use cases,
- # so we load it back to GPU (it's a no-op if it's already on GPU).
- feats = prev["maskmem_features"].cuda(non_blocking=True)
+ # so we load it back to the model's device (it's a no-op if it's already there).
+ feats = prev["maskmem_features"].to(device, non_blocking=True)
seq_len = feats.shape[-2] * feats.shape[-1]
to_cat_prompt.append(feats.flatten(2).permute(2, 0, 1))
to_cat_prompt_mask.append(
torch.zeros(B, seq_len, device=device, dtype=bool)
)
# Spatial positional encoding (it might have been offloaded to CPU in eval)
- maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
+ maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
if (
diff --git a/sam3/model/sam3_tracking_predictor.py b/sam3/model/sam3_tracking_predictor.py
index b2440ef6..a5a27bb5 100644
--- a/sam3/model/sam3_tracking_predictor.py
+++ b/sam3/model/sam3_tracking_predictor.py
@@ -46,8 +46,16 @@ def __init__(
self.max_point_num_in_prompt_enc = max_point_num_in_prompt_enc
self.non_overlap_masks_for_output = non_overlap_masks_for_output
- self.bf16_context = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
- self.bf16_context.__enter__() # keep using for the entire model process
+ # Set up autocast context based on device type
+ # MPS doesn't support bfloat16, so we skip autocast on non-CUDA devices
+ device_type = getattr(self, 'device', torch.device('cpu'))
+ if hasattr(device_type, 'type'):
+ device_type = device_type.type
+ if device_type == "cuda":
+ self.bf16_context = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+ self.bf16_context.__enter__() # keep using for the entire model process
+ else:
+ self.bf16_context = None # No autocast for MPS/CPU
self.iter_use_prev_mask_pred = True
self.add_all_frames_to_correct_as_cond = True
@@ -78,7 +86,8 @@ def init_state(
if offload_state_to_cpu:
inference_state["storage_device"] = torch.device("cpu")
else:
- inference_state["storage_device"] = torch.device("cuda")
+ # Use the actual device (cuda, mps, or cpu) instead of hardcoded cuda
+ inference_state["storage_device"] = self.device
if video_path is not None:
images, video_height, video_width = load_video_frames(
@@ -300,7 +309,12 @@ def add_new_points_or_box(
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
if prev_out is not None and prev_out["pred_masks"] is not None:
- prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
+ device = inference_state["device"]
+ # Use device-agnostic transfer (cuda, mps, or cpu)
+ if device.type == "cuda":
+ prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
+ else:
+ prev_sam_mask_logits = prev_out["pred_masks"].to(device)
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
current_out, _ = self._run_single_frame_inference(
@@ -1021,7 +1035,8 @@ def _get_image_feature(self, inference_state, frame_idx, batch_size):
)
else:
# Cache miss -- we will run inference on a single image
- image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
+ device = inference_state["device"]
+ image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
backbone_out = self.forward_image(image)
# Cache the most recent frame's feature (for repeated interactions with
# a frame; we can use an LRU cache for more frames in the future).
diff --git a/sam3/model/sam3_video_inference.py b/sam3/model/sam3_video_inference.py
index 7fb87d01..8e1b71e1 100644
--- a/sam3/model/sam3_video_inference.py
+++ b/sam3/model/sam3_video_inference.py
@@ -477,9 +477,13 @@ def _postprocess_output(
# slice those valid entries from the original outputs
keep_idx = torch.nonzero(keep, as_tuple=True)[0]
- keep_idx_gpu = keep_idx.pin_memory().to(
- device=out_binary_masks.device, non_blocking=True
- )
+ # pin_memory() only works with CUDA, not MPS
+ if out_binary_masks.device.type == "cuda":
+ keep_idx_gpu = keep_idx.pin_memory().to(
+ device=out_binary_masks.device, non_blocking=True
+ )
+ else:
+ keep_idx_gpu = keep_idx.to(device=out_binary_masks.device)
out_obj_ids = torch.index_select(out_obj_ids, 0, keep_idx)
out_probs = torch.index_select(out_probs, 0, keep_idx)
diff --git a/sam3/model/sam3_video_predictor.py b/sam3/model/sam3_video_predictor.py
index c639e1d0..ccd7a009 100644
--- a/sam3/model/sam3_video_predictor.py
+++ b/sam3/model/sam3_video_predictor.py
@@ -16,6 +16,7 @@
import torch
from sam3.logger import get_logger
+from sam3.utils.device import get_device
logger = get_logger(__name__)
@@ -48,7 +49,7 @@ def __init__(
strict_state_dict_loading=strict_state_dict_loading,
apply_temporal_disambiguation=apply_temporal_disambiguation,
)
- .cuda()
+ .to(get_device())
.eval()
)
@@ -275,11 +276,17 @@ def _get_session_stats(self):
return session_stats_str
def _get_torch_and_gpu_properties(self):
- """Get a string for PyTorch and GPU properties (for logging and debugging)."""
- torch_and_gpu_str = (
- f"torch: {torch.__version__} with CUDA arch {torch.cuda.get_arch_list()}, "
- f"GPU device: {torch.cuda.get_device_properties(torch.cuda.current_device())}"
- )
+ """Get a string for PyTorch and device properties (for logging and debugging)."""
+ device = get_device()
+ if device.type == "cuda":
+ torch_and_gpu_str = (
+ f"torch: {torch.__version__} with CUDA arch {torch.cuda.get_arch_list()}, "
+ f"GPU device: {torch.cuda.get_device_properties(torch.cuda.current_device())}"
+ )
+ elif device.type == "mps":
+ torch_and_gpu_str = f"torch: {torch.__version__} with MPS (Apple Silicon)"
+ else:
+ torch_and_gpu_str = f"torch: {torch.__version__} on CPU"
return torch_and_gpu_str
def shutdown(self):
@@ -428,7 +435,8 @@ def _start_nccl_process_group(self):
device_id=self.device,
)
# warm-up the NCCL process group by running a dummy all-reduce
- tensor = torch.ones(1024, 1024).cuda()
+ # Note: NCCL backend requires CUDA tensors
+ tensor = torch.ones(1024, 1024, device=self.device)
torch.distributed.all_reduce(tensor)
logger.debug(f"started NCCL process group on {rank=} with {world_size=}")
diff --git a/sam3/model/vl_combiner.py b/sam3/model/vl_combiner.py
index 43bc7bd5..ae8bc405 100644
--- a/sam3/model/vl_combiner.py
+++ b/sam3/model/vl_combiner.py
@@ -10,6 +10,7 @@
from torch.nn.attention import sdpa_kernel, SDPBackend
+from sam3.utils.device import get_device_str
from .act_ckpt_utils import activation_ckpt_wrapper
from .necks import Sam3DualViTDetNeck
@@ -119,8 +120,10 @@ def _forward_image_no_act_ckpt(self, samples):
return output
def forward_text(
- self, captions, input_boxes=None, additional_text=None, device="cuda"
+ self, captions, input_boxes=None, additional_text=None, device=None
):
+ if device is None:
+ device = get_device_str()
return activation_ckpt_wrapper(self._forward_text_no_ack_ckpt)(
captions=captions,
input_boxes=input_boxes,
@@ -134,8 +137,10 @@ def _forward_text_no_ack_ckpt(
captions,
input_boxes=None,
additional_text=None,
- device="cuda",
+ device=None,
):
+ if device is None:
+ device = get_device_str()
output = {}
# Forward through text_encoder
diff --git a/sam3/model_builder.py b/sam3/model_builder.py
index 1a3bdecf..3d588ffb 100644
--- a/sam3/model_builder.py
+++ b/sam3/model_builder.py
@@ -44,17 +44,11 @@
from sam3.sam.transformer import RoPEAttention
-# Setup TensorFloat-32 for Ampere GPUs if available
-def _setup_tf32() -> None:
- """Enable TensorFloat-32 for Ampere GPUs if available."""
- if torch.cuda.is_available():
- device_props = torch.cuda.get_device_properties(0)
- if device_props.major >= 8:
- torch.backends.cuda.matmul.allow_tf32 = True
- torch.backends.cudnn.allow_tf32 = True
+# Import device utilities
+from sam3.utils.device import get_device_str, setup_device_optimizations
-
-_setup_tf32()
+# Setup device-specific optimizations (TF32 for Ampere GPUs, etc.)
+setup_device_optimizations()
def _create_position_encoding(precompute_resolution=None):
@@ -549,8 +543,7 @@ def _load_checkpoint(model, checkpoint_path):
def _setup_device_and_mode(model, device, eval_mode):
"""Setup model device and evaluation mode."""
- if device == "cuda":
- model = model.cuda()
+ model = model.to(device=device)
if eval_mode:
model.eval()
return model
@@ -558,7 +551,7 @@ def _setup_device_and_mode(model, device, eval_mode):
def build_sam3_image_model(
bpe_path=None,
- device="cuda" if torch.cuda.is_available() else "cpu",
+ device=None, # Will use get_device_str() if None
eval_mode=True,
checkpoint_path=None,
load_from_HF=True,
@@ -571,7 +564,7 @@ def build_sam3_image_model(
Args:
bpe_path: Path to the BPE tokenizer vocabulary
- device: Device to load the model on ('cuda' or 'cpu')
+ device: Device to load the model on ('cuda', 'mps', or 'cpu'). If None, auto-detects best available device.
eval_mode: Whether to set the model to evaluation mode
checkpoint_path: Optional path to model checkpoint
enable_segmentation: Whether to enable segmentation head
@@ -586,6 +579,10 @@ def build_sam3_image_model(
"sam3", "assets/bpe_simple_vocab_16e6.txt.gz"
)
+ # Set default device if not specified
+ if device is None:
+ device = get_device_str()
+
# Create visual components
compile_mode = "default" if compile else None
vision_encoder = _create_vision_backbone(
@@ -657,7 +654,7 @@ def build_sam3_video_model(
geo_encoder_use_img_cross_attn: bool = True,
strict_state_dict_loading: bool = True,
apply_temporal_disambiguation: bool = True,
- device="cuda" if torch.cuda.is_available() else "cpu",
+ device=None, # Will use get_device_str() if None
compile=False,
) -> Sam3VideoInferenceWithInstanceInteractivity:
"""
@@ -675,6 +672,10 @@ def build_sam3_video_model(
"sam3", "assets/bpe_simple_vocab_16e6.txt.gz"
)
+ # Set default device if not specified
+ if device is None:
+ device = get_device_str()
+
# Build Tracker module
tracker = build_tracker(apply_temporal_disambiguation=apply_temporal_disambiguation)
diff --git a/sam3/perflib/connected_components.py b/sam3/perflib/connected_components.py
index c96932a4..f212263b 100644
--- a/sam3/perflib/connected_components.py
+++ b/sam3/perflib/connected_components.py
@@ -54,6 +54,8 @@ def connected_components(input_tensor: torch.Tensor):
"""
Computes connected components labeling on a batch of 2D tensors, using the best available backend.
+ Supports CUDA (with optional Triton acceleration), MPS (Apple Silicon), and CPU.
+
Args:
input_tensor (torch.Tensor): A BxHxW integer tensor or Bx1xHxW. Non-zero values are considered foreground. Bool tensor also accepted
@@ -69,7 +71,10 @@ def connected_components(input_tensor: torch.Tensor):
input_tensor.dim() == 4 and input_tensor.shape[1] == 1
), "Input tensor must be (B, H, W) or (B, 1, H, W)."
- if input_tensor.is_cuda:
+ # Check device type
+ device_type = input_tensor.device.type
+
+ if device_type == "cuda":
if HAS_CC_TORCH:
return get_connected_components(input_tensor.to(torch.uint8))
else:
@@ -80,5 +85,6 @@ def connected_components(input_tensor: torch.Tensor):
return connected_components_triton(input_tensor)
- # CPU fallback
+ # For MPS (Apple Silicon) and CPU, use the CPU implementation
+ # MPS tensors are handled in connected_components_cpu via .cpu() conversion
return connected_components_cpu(input_tensor)
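The dispatch above keys on `input_tensor.device.type`, so MPS callers transparently fall back to the CPU implementation. A short usage sketch (the mask layout is arbitrary, chosen only to produce two components):

```python
# Hedged sketch: connected_components dispatches on the input's device type.
import torch
from sam3.perflib.connected_components import connected_components

has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
device = "mps" if has_mps else "cpu"

# Two disjoint squares in one mask -> two foreground components.
masks = torch.zeros(1, 1, 64, 64, dtype=torch.uint8, device=device)
masks[0, 0, 5:15, 5:15] = 1
masks[0, 0, 40:50, 40:50] = 1

labels, counts = connected_components(masks)
print(labels.shape, counts.shape)  # both (1, 1, 64, 64), on the input device
```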
diff --git a/sam3/perflib/nms.py b/sam3/perflib/nms.py
index b3efc599..f50cb800 100644
--- a/sam3/perflib/nms.py
+++ b/sam3/perflib/nms.py
@@ -55,12 +55,18 @@ def nms_masks(
def generic_nms(
ious: torch.Tensor, scores: torch.Tensor, iou_threshold=0.5
) -> torch.Tensor:
- """A generic version of `torchvision.ops.nms` that takes a pairwise IoU matrix."""
+ """A generic version of `torchvision.ops.nms` that takes a pairwise IoU matrix.
+
+ Supports CUDA (with optional Triton acceleration), MPS (Apple Silicon), and CPU.
+ """
assert ious.dim() == 2 and ious.size(0) == ious.size(1)
assert scores.dim() == 1 and scores.size(0) == ious.size(0)
- if ious.is_cuda:
+ # Check device type
+ device_type = ious.device.type
+
+ if device_type == "cuda":
if GENERIC_NMS_AVAILABLE:
return generic_nms_cuda(ious, scores, iou_threshold, use_iou_matrix=True)
else:
@@ -68,6 +74,8 @@ def generic_nms(
return nms_triton(ious, scores, iou_threshold)
+ # For MPS (Apple Silicon) and CPU, use the CPU implementation
+ # MPS tensors need to be moved to CPU for numpy operations
return generic_nms_cpu(ious, scores, iou_threshold)
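As with connected components, `generic_nms` now branches on `ious.device.type`, so MPS and CPU tensors share the CPU path while CUDA keeps its fast kernels. A minimal usage sketch mirroring the test added later in this patch:

```python
# Hedged sketch: generic_nms takes a precomputed pairwise IoU matrix plus scores.
import torch
from sam3.perflib.nms import generic_nms

n = 6
ious = torch.rand(n, n)
ious = (ious + ious.T) / 2   # make the IoU matrix symmetric
ious.fill_diagonal_(1.0)     # self-IoU is always 1
scores = torch.rand(n)

keep = generic_nms(ious, scores, iou_threshold=0.5)
print(keep)                  # indices of detections that survive suppression
```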
diff --git a/sam3/sam/transformer.py b/sam3/sam/transformer.py
index 3e96c283..5d4a4ce9 100644
--- a/sam3/sam/transformer.py
+++ b/sam3/sam/transformer.py
@@ -252,9 +252,11 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)
else:
- torch.backends.cuda.enable_flash_sdp(True)
- torch.backends.cuda.enable_math_sdp(True)
- torch.backends.cuda.enable_mem_efficient_sdp(True)
+ # Only configure CUDA backends when on CUDA device
+ if q.is_cuda:
+ torch.backends.cuda.enable_flash_sdp(True)
+ torch.backends.cuda.enable_math_sdp(True)
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
out = self._recombine_heads(out)
@@ -282,9 +284,9 @@ def __init__(
self.compute_cis = partial(
compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
)
- device = torch.device("cuda") if torch.cuda.is_available() else None
+ # Build freqs_cis on the default device; it is moved to the input tensor's device on the first forward pass
self.freqs_cis = self.compute_cis(
- end_x=feat_sizes[0], end_y=feat_sizes[1], device=device
+ end_x=feat_sizes[0], end_y=feat_sizes[1], device=None
)
if self.use_rope_real:
self.freqs_cis_real = self.freqs_cis.real
@@ -347,9 +349,11 @@ def forward(
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)
else:
- torch.backends.cuda.enable_flash_sdp(True)
- torch.backends.cuda.enable_math_sdp(True)
- torch.backends.cuda.enable_mem_efficient_sdp(True)
+ # Only configure CUDA backends when on CUDA device
+ if q.is_cuda:
+ torch.backends.cuda.enable_flash_sdp(True)
+ torch.backends.cuda.enable_math_sdp(True)
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
out = self._recombine_heads(out)
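The point of the two attention hunks is that the `torch.backends.cuda.enable_*_sdp` switches are CUDA-only globals, whereas `F.scaled_dot_product_attention` itself runs on any backend. A standalone sketch of the guarded pattern, independent of the SAM3 attention classes:

```python
# Hedged sketch of the guarded SDP-backend configuration used above.
import torch
import torch.nn.functional as F

def attend(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Only touch the CUDA backend toggles when the tensors actually live on CUDA;
    # on MPS/CPU, scaled_dot_product_attention selects a suitable kernel itself.
    if q.is_cuda:
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_math_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
    return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(2, 8, 64, 32)  # (batch, heads, seq_len, head_dim) on CPU
print(attend(q, k, v).shape)
```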
diff --git a/sam3/train/data/sam3_image_dataset.py b/sam3/train/data/sam3_image_dataset.py
index 97efb1d1..f8b0d634 100644
--- a/sam3/train/data/sam3_image_dataset.py
+++ b/sam3/train/data/sam3_image_dataset.py
@@ -15,7 +15,7 @@
import torch
import torch.utils.data
import torchvision
-from decord import cpu, VideoReader
+# decord is imported lazily when needed for video loading
from iopath.common.file_io import g_pathmgr
from PIL import Image as PILImage
@@ -202,6 +202,7 @@ def _load_images(
try:
if ".mp4" in path and path[-4:] == ".mp4":
# Going to load a video frame
+ from decord import cpu, VideoReader
video_path, frame = path.split("@")
video = VideoReader(video_path, ctx=cpu(0))
# Convert to PIL image
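Deferring the `decord` import to the video branch means image-only runs (common on MPS/CPU machines) no longer need the package installed at all. Sketched in isolation, using the same `"<video>.mp4@<frame>"` convention the loader splits on:

```python
# Hedged sketch of the lazy-import pattern above; the path format follows the
# "<video_path>@<frame_index>" split used by the dataset loader.
from PIL import Image

def load_frame(path: str) -> Image.Image:
    if ".mp4" in path and "@" in path:
        # Import decord only when a video frame is actually requested.
        from decord import cpu, VideoReader

        video_path, frame_idx = path.split("@")
        video = VideoReader(video_path, ctx=cpu(0))
        return Image.fromarray(video[int(frame_idx)].asnumpy())
    return Image.open(path).convert("RGB")
```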
diff --git a/sam3/train/loss/sigmoid_focal_loss.py b/sam3/train/loss/sigmoid_focal_loss.py
index 15e6db43..48f3b811 100644
--- a/sam3/train/loss/sigmoid_focal_loss.py
+++ b/sam3/train/loss/sigmoid_focal_loss.py
@@ -1,11 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
-"""Triton kernel for faster and memory efficient sigmoid focal loss"""
+"""Sigmoid focal loss with optional Triton kernel acceleration for CUDA devices."""
import torch
-import triton
-import triton.language as tl
-from torch._inductor.runtime.triton_helpers import libdevice
+import torch.nn.functional as F
+
+# Try to import Triton (only available on CUDA)
+try:
+ import triton
+ import triton.language as tl
+ from torch._inductor.runtime.triton_helpers import libdevice
+
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
"""
@@ -32,290 +40,410 @@
"""
-@triton.jit
-def _inner_focal_loss_fwd(inputs, targets, alpha, gamma):
- inv_targets = 1 - targets
- # Sigmoid
- sig = tl.sigmoid(inputs)
-
- # Binary cross entropy with logits
- # In practice, we want the following:
- # bce_loss = -targets * tl.log(sig) - (1 - targets) * tl.log(1 - sig)
- # However, the above is not numerically stable.
- # We're also not directly taking the sum here, so the usual log-sum-exp trick doesn't apply
- # The bce can be reformulated, after algebraic manipulation, to
- # bce_loss = log(1 + exp(-x)) + x * (1-y)
- # This is still not stable, because for large (-x) the exponential will blow up.
- # We'll use the following alternate formulation:
- # bce_loss = max(x, 0) - x * y + log(1 + exp(-abs(x)))
- # Let's show that it's equivalent:
- # Case x>=0: abs(x) = x , max(x, 0) = x
- # so we get x - x * y + log(1 + exp(-x)) which is equivalent
- # Case x<0: abs(x) = -x, max(x, 0) = 0
- # we have log(1 + exp(-abs(x))) = log(1 + exp(x)) = log(exp(x)(1 + exp(-x))) = x+log(1 + exp(-x))
- # plugging it in, we get
- # 0 - x * y + x + log(1 + exp(-x)), which is also equivalent
- # Note that this is stable because now the exponent are guaranteed to be below 0.
- max_val = tl.clamp(inputs, min=0, max=1e9)
- bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs)))
-
- # Modulating factor
- p_t = sig * targets + (1 - sig) * inv_targets
- mod_factor = libdevice.pow(1 - p_t, gamma)
-
- # Alpha factor
- alpha_t = alpha * targets + (1 - alpha) * inv_targets
-
- # Final loss calculation
- return alpha_t * mod_factor * bce_loss
-
-
-# Non-reduced version
-@triton.jit
-def sigmoid_focal_loss_fwd_kernel(
- inputs_ptr,
- targets_ptr,
- loss_ptr,
- alpha: float,
- gamma: float,
- n_elements: int,
- BLOCK_SIZE: tl.constexpr,
-):
- pid = tl.program_id(axis=0)
- block_start = pid * BLOCK_SIZE
- offset = block_start + tl.arange(0, BLOCK_SIZE)
- mask = offset < n_elements
-
- # Load data
- inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32)
- targets = tl.load(targets_ptr + offset, mask=mask)
-
- final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma)
-
- # Store result
- tl.store(loss_ptr + offset, final_loss, mask=mask)
-
-
-# version with reduction
-@triton.jit
-def sigmoid_focal_loss_fwd_kernel_reduce(
- inputs_ptr,
- targets_ptr,
- loss_ptr,
- alpha: float,
- gamma: float,
- n_elements: int,
- BLOCK_SIZE: tl.constexpr,
- REDUCE_SIZE: tl.constexpr,
-):
- pid = tl.program_id(axis=0)
- block_start = pid * BLOCK_SIZE
- reduce_loc = pid % REDUCE_SIZE
- offset = block_start + tl.arange(0, BLOCK_SIZE)
- mask = offset < n_elements
- # Load data
- inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32)
- targets = tl.load(targets_ptr + offset, mask=mask)
-
- final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma) * mask
-
- fl = tl.sum(final_loss)
-
- # Store result
- tl.atomic_add(loss_ptr + reduce_loc, fl)
-
-
-@triton.jit
-def _inner_focal_loss_bwd(inputs, targets, alpha, gamma):
- inv_targets = 1 - targets
-
- # Recompute forward
- max_val = tl.clamp(inputs, min=0, max=1e9)
- bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs)))
-
- # Sigmoid
- sig = tl.sigmoid(inputs)
- inv_sig = 1 - sig
-
- # Modulating factor
- p_t = sig * targets + inv_sig * inv_targets
- tmp = libdevice.pow(1 - p_t, gamma - 1)
- mod_factor = tmp * (1 - p_t)
-
- # Alpha factor
- alpha_t = alpha * targets + (1 - alpha) * inv_targets
-
- # Now computing the derivatives
- d_pt = (2 * targets - 1) * sig * inv_sig
- d_mod_factor = -gamma * d_pt * tmp
-
- d_bce_loss = sig - targets
-
- return alpha_t * (d_bce_loss * mod_factor + d_mod_factor * bce_loss)
-
-
-@triton.jit
-def sigmoid_focal_loss_bwd_kernel(
- inputs_ptr,
- targets_ptr,
- grad_inputs_ptr,
- grad_out_ptr,
- alpha: float,
- gamma: float,
- n_elements: int,
- BLOCK_SIZE: tl.constexpr,
-):
- pid = tl.program_id(axis=0)
- block_start = pid * BLOCK_SIZE
- offset = block_start + tl.arange(0, BLOCK_SIZE)
- mask = offset < n_elements
- input_ptrs = inputs_ptr + offset
- target_ptrs = targets_ptr + offset
- grad_input_ptrs = grad_inputs_ptr + offset
- grad_out_ptrs = grad_out_ptr + offset
- # Load data
- inputs = tl.load(input_ptrs, mask=mask).to(tl.float32)
- targets = tl.load(target_ptrs, mask=mask)
- grad_out = tl.load(grad_out_ptrs, mask=mask)
- d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma)
- tl.store(grad_input_ptrs, d_loss, mask=mask)
-
-
-@triton.jit
-def sigmoid_focal_loss_bwd_kernel_reduce(
- inputs_ptr,
- targets_ptr,
- grad_inputs_ptr,
- grad_out_ptr,
- alpha: float,
- gamma: float,
- n_elements: int,
- BLOCK_SIZE: tl.constexpr,
-):
- # The only difference is that the gradient is now a single scalar
- pid = tl.program_id(axis=0)
- block_start = pid * BLOCK_SIZE
- offset = block_start + tl.arange(0, BLOCK_SIZE)
- mask = offset < n_elements
- input_ptrs = inputs_ptr + offset
- target_ptrs = targets_ptr + offset
- grad_input_ptrs = grad_inputs_ptr + offset
- # Load data
- inputs = tl.load(input_ptrs, mask=mask).to(tl.float32)
- targets = tl.load(target_ptrs, mask=mask)
- grad_out = tl.load(grad_out_ptr)
- d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma)
- tl.store(grad_input_ptrs, d_loss, mask=mask)
-
-
-class SigmoidFocalLoss(torch.autograd.Function):
- BLOCK_SIZE = 256
-
- @staticmethod
- def forward(ctx, inputs, targets, alpha=0.25, gamma=2):
- n_elements = inputs.numel()
- assert targets.numel() == n_elements
- input_shape = inputs.shape
- inputs = inputs.view(-1).contiguous()
- targets = targets.view(-1).contiguous()
- loss = torch.empty(inputs.shape, dtype=torch.float32, device=inputs.device)
- grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
- sigmoid_focal_loss_fwd_kernel[grid](
- inputs, targets, loss, alpha, gamma, n_elements, SigmoidFocalLoss.BLOCK_SIZE
- )
- ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape))
- ctx.alpha = alpha
- ctx.gamma = gamma
- return loss.view(input_shape)
-
- @staticmethod
- def backward(ctx, grad_output):
- inputs, targets = ctx.saved_tensors
- alpha = ctx.alpha
- gamma = ctx.gamma
- n_elements = inputs.numel()
- input_shape = inputs.shape
- grad_inputs = torch.empty(
- inputs.shape, dtype=grad_output.dtype, device=grad_output.device
- )
- inputs_ptr = inputs.view(-1).contiguous()
- targets_ptr = targets.view(-1).contiguous()
- grad_output_ptr = grad_output.view(-1).contiguous()
- grad_inputs_ptr = grad_inputs
- assert grad_output.numel() == n_elements
- grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
- sigmoid_focal_loss_bwd_kernel[grid](
- inputs_ptr,
- targets_ptr,
- grad_inputs_ptr,
- grad_output_ptr,
- alpha,
- gamma,
- n_elements,
- SigmoidFocalLoss.BLOCK_SIZE,
- )
- return grad_inputs.view(input_shape), None, None, None
-
-
-triton_sigmoid_focal_loss = SigmoidFocalLoss.apply
-
-
-class SigmoidFocalLossReduced(torch.autograd.Function):
- BLOCK_SIZE = 256
- REDUCE_SIZE = 32
-
- @staticmethod
- def forward(ctx, inputs, targets, alpha=0.25, gamma=2):
- n_elements = inputs.numel()
- input_shape = inputs.shape
- inputs = inputs.view(-1).contiguous()
- targets = targets.view(-1).contiguous()
- loss = torch.zeros(
- SigmoidFocalLossReduced.REDUCE_SIZE,
- device=inputs.device,
- dtype=torch.float32,
- )
- grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
- sigmoid_focal_loss_fwd_kernel_reduce[grid](
- inputs,
- targets,
- loss,
- alpha,
- gamma,
- n_elements,
- SigmoidFocalLossReduced.BLOCK_SIZE,
- SigmoidFocalLossReduced.REDUCE_SIZE,
- )
- ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape))
- ctx.alpha = alpha
- ctx.gamma = gamma
- return loss.sum()
-
- @staticmethod
- def backward(ctx, grad_output):
- inputs, targets = ctx.saved_tensors
- alpha = ctx.alpha
- gamma = ctx.gamma
- n_elements = inputs.numel()
- input_shape = inputs.shape
- grad_inputs = torch.empty(
- inputs.shape, dtype=grad_output.dtype, device=grad_output.device
- )
- inputs_ptr = inputs.view(-1).contiguous()
- targets_ptr = targets.reshape(-1).contiguous()
- assert grad_output.numel() == 1
- grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
- sigmoid_focal_loss_bwd_kernel_reduce[grid](
- inputs_ptr,
- targets_ptr,
- grad_inputs,
- grad_output,
- alpha,
- gamma,
- n_elements,
- SigmoidFocalLossReduced.BLOCK_SIZE,
- )
- return grad_inputs.view(input_shape), None, None, None
-
-
-triton_sigmoid_focal_loss_reduce = SigmoidFocalLossReduced.apply
+# ============================================================================
+# PyTorch-based implementations (for CPU, MPS, and fallback)
+# ============================================================================
+
+
+def sigmoid_focal_loss_pytorch(
+ inputs: torch.Tensor,
+ targets: torch.Tensor,
+ alpha: float = 0.25,
+ gamma: float = 2.0,
+) -> torch.Tensor:
+ """
+ Pure PyTorch implementation of sigmoid focal loss (no reduction).
+
+ Args:
+ inputs: Tensor of any shape, containing logits
+ targets: Tensor of the same shape as inputs, containing float targets
+ alpha: Weighting factor in range (0,1) to balance positive vs negative examples
+ gamma: Exponent of the modulating factor (1 - p_t) ** gamma
+
+ Returns:
+ Tensor of the same shape as inputs, containing the focal loss for each element
+ """
+ # Compute sigmoid and BCE loss
+ prob = torch.sigmoid(inputs)
+ ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+
+ # Compute p_t and alpha_t
+ p_t = prob * targets + (1 - prob) * (1 - targets)
+ alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+
+ # Compute focal loss
+ focal_weight = (1 - p_t) ** gamma
+ loss = alpha_t * focal_weight * ce_loss
+
+ return loss
+
+
+def sigmoid_focal_loss_reduced_pytorch(
+ inputs: torch.Tensor,
+ targets: torch.Tensor,
+ alpha: float = 0.25,
+ gamma: float = 2.0,
+) -> torch.Tensor:
+ """
+ Pure PyTorch implementation of sigmoid focal loss with sum reduction.
+
+ Args:
+ inputs: Tensor of any shape, containing logits
+ targets: Tensor of the same shape as inputs, containing float targets
+ alpha: Weighting factor in range (0,1) to balance positive vs negative examples
+ gamma: Exponent of the modulating factor (1 - p_t) ** gamma
+
+ Returns:
+ Scalar tensor containing the sum of focal losses
+ """
+ return sigmoid_focal_loss_pytorch(inputs, targets, alpha, gamma).sum()
+
+
+# ============================================================================
+# Triton-based implementations (CUDA only)
+# ============================================================================
+
+if HAS_TRITON:
+
+ @triton.jit
+ def _inner_focal_loss_fwd(inputs, targets, alpha, gamma):
+ inv_targets = 1 - targets
+ # Sigmoid
+ sig = tl.sigmoid(inputs)
+
+ # Binary cross entropy with logits
+ # In practice, we want the following:
+ # bce_loss = -targets * tl.log(sig) - (1 - targets) * tl.log(1 - sig)
+ # However, the above is not numerically stable.
+ # We're also not directly taking the sum here, so the usual log-sum-exp trick doesn't apply
+ # The bce can be reformulated, after algebraic manipulation, to
+ # bce_loss = log(1 + exp(-x)) + x * (1-y)
+ # This is still not stable, because for large (-x) the exponential will blow up.
+ # We'll use the following alternate formulation:
+ # bce_loss = max(x, 0) - x * y + log(1 + exp(-abs(x)))
+ # Let's show that it's equivalent:
+ # Case x>=0: abs(x) = x , max(x, 0) = x
+ # so we get x - x * y + log(1 + exp(-x)) which is equivalent
+ # Case x<0: abs(x) = -x, max(x, 0) = 0
+ # we have log(1 + exp(-abs(x))) = log(1 + exp(x)) = log(exp(x)(1 + exp(-x))) = x+log(1 + exp(-x))
+ # plugging it in, we get
+ # 0 - x * y + x + log(1 + exp(-x)), which is also equivalent
+ # Note that this is stable because now the exponent is guaranteed to be at most 0.
+ max_val = tl.clamp(inputs, min=0, max=1e9)
+ bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs)))
+
+ # Modulating factor
+ p_t = sig * targets + (1 - sig) * inv_targets
+ mod_factor = libdevice.pow(1 - p_t, gamma)
+
+ # Alpha factor
+ alpha_t = alpha * targets + (1 - alpha) * inv_targets
+
+ # Final loss calculation
+ return alpha_t * mod_factor * bce_loss
+
+ # Non-reduced version
+ @triton.jit
+ def sigmoid_focal_loss_fwd_kernel(
+ inputs_ptr,
+ targets_ptr,
+ loss_ptr,
+ alpha: float,
+ gamma: float,
+ n_elements: int,
+ BLOCK_SIZE: tl.constexpr,
+ ):
+ pid = tl.program_id(axis=0)
+ block_start = pid * BLOCK_SIZE
+ offset = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offset < n_elements
+
+ # Load data
+ inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32)
+ targets = tl.load(targets_ptr + offset, mask=mask)
+
+ final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma)
+
+ # Store result
+ tl.store(loss_ptr + offset, final_loss, mask=mask)
+
+ # version with reduction
+ @triton.jit
+ def sigmoid_focal_loss_fwd_kernel_reduce(
+ inputs_ptr,
+ targets_ptr,
+ loss_ptr,
+ alpha: float,
+ gamma: float,
+ n_elements: int,
+ BLOCK_SIZE: tl.constexpr,
+ REDUCE_SIZE: tl.constexpr,
+ ):
+ pid = tl.program_id(axis=0)
+ block_start = pid * BLOCK_SIZE
+ reduce_loc = pid % REDUCE_SIZE
+ offset = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offset < n_elements
+ # Load data
+ inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32)
+ targets = tl.load(targets_ptr + offset, mask=mask)
+
+ final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma) * mask
+
+ fl = tl.sum(final_loss)
+
+ # Store result
+ tl.atomic_add(loss_ptr + reduce_loc, fl)
+
+ @triton.jit
+ def _inner_focal_loss_bwd(inputs, targets, alpha, gamma):
+ inv_targets = 1 - targets
+
+ # Recompute forward
+ max_val = tl.clamp(inputs, min=0, max=1e9)
+ bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs)))
+
+ # Sigmoid
+ sig = tl.sigmoid(inputs)
+ inv_sig = 1 - sig
+
+ # Modulating factor
+ p_t = sig * targets + inv_sig * inv_targets
+ tmp = libdevice.pow(1 - p_t, gamma - 1)
+ mod_factor = tmp * (1 - p_t)
+
+ # Alpha factor
+ alpha_t = alpha * targets + (1 - alpha) * inv_targets
+
+ # Now computing the derivatives
+ d_pt = (2 * targets - 1) * sig * inv_sig
+ d_mod_factor = -gamma * d_pt * tmp
+
+ d_bce_loss = sig - targets
+
+ return alpha_t * (d_bce_loss * mod_factor + d_mod_factor * bce_loss)
+
+ @triton.jit
+ def sigmoid_focal_loss_bwd_kernel(
+ inputs_ptr,
+ targets_ptr,
+ grad_inputs_ptr,
+ grad_out_ptr,
+ alpha: float,
+ gamma: float,
+ n_elements: int,
+ BLOCK_SIZE: tl.constexpr,
+ ):
+ pid = tl.program_id(axis=0)
+ block_start = pid * BLOCK_SIZE
+ offset = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offset < n_elements
+ input_ptrs = inputs_ptr + offset
+ target_ptrs = targets_ptr + offset
+ grad_input_ptrs = grad_inputs_ptr + offset
+ grad_out_ptrs = grad_out_ptr + offset
+ # Load data
+ inputs = tl.load(input_ptrs, mask=mask).to(tl.float32)
+ targets = tl.load(target_ptrs, mask=mask)
+ grad_out = tl.load(grad_out_ptrs, mask=mask)
+ d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma)
+ tl.store(grad_input_ptrs, d_loss, mask=mask)
+
+ @triton.jit
+ def sigmoid_focal_loss_bwd_kernel_reduce(
+ inputs_ptr,
+ targets_ptr,
+ grad_inputs_ptr,
+ grad_out_ptr,
+ alpha: float,
+ gamma: float,
+ n_elements: int,
+ BLOCK_SIZE: tl.constexpr,
+ ):
+ # The only difference is that the gradient is now a single scalar
+ pid = tl.program_id(axis=0)
+ block_start = pid * BLOCK_SIZE
+ offset = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offset < n_elements
+ input_ptrs = inputs_ptr + offset
+ target_ptrs = targets_ptr + offset
+ grad_input_ptrs = grad_inputs_ptr + offset
+ # Load data
+ inputs = tl.load(input_ptrs, mask=mask).to(tl.float32)
+ targets = tl.load(target_ptrs, mask=mask)
+ grad_out = tl.load(grad_out_ptr)
+ d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma)
+ tl.store(grad_input_ptrs, d_loss, mask=mask)
+
+ class SigmoidFocalLossTriton(torch.autograd.Function):
+ BLOCK_SIZE = 256
+
+ @staticmethod
+ def forward(ctx, inputs, targets, alpha=0.25, gamma=2):
+ n_elements = inputs.numel()
+ assert targets.numel() == n_elements
+ input_shape = inputs.shape
+ inputs = inputs.view(-1).contiguous()
+ targets = targets.view(-1).contiguous()
+ loss = torch.empty(inputs.shape, dtype=torch.float32, device=inputs.device)
+ grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+ sigmoid_focal_loss_fwd_kernel[grid](
+ inputs,
+ targets,
+ loss,
+ alpha,
+ gamma,
+ n_elements,
+ SigmoidFocalLossTriton.BLOCK_SIZE,
+ )
+ ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape))
+ ctx.alpha = alpha
+ ctx.gamma = gamma
+ return loss.view(input_shape)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ inputs, targets = ctx.saved_tensors
+ alpha = ctx.alpha
+ gamma = ctx.gamma
+ n_elements = inputs.numel()
+ input_shape = inputs.shape
+ grad_inputs = torch.empty(
+ inputs.shape, dtype=grad_output.dtype, device=grad_output.device
+ )
+ inputs_ptr = inputs.view(-1).contiguous()
+ targets_ptr = targets.view(-1).contiguous()
+ grad_output_ptr = grad_output.view(-1).contiguous()
+ grad_inputs_ptr = grad_inputs
+ assert grad_output.numel() == n_elements
+ grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+ sigmoid_focal_loss_bwd_kernel[grid](
+ inputs_ptr,
+ targets_ptr,
+ grad_inputs_ptr,
+ grad_output_ptr,
+ alpha,
+ gamma,
+ n_elements,
+ SigmoidFocalLossTriton.BLOCK_SIZE,
+ )
+ return grad_inputs.view(input_shape), None, None, None
+
+ class SigmoidFocalLossReducedTriton(torch.autograd.Function):
+ BLOCK_SIZE = 256
+ REDUCE_SIZE = 32
+
+ @staticmethod
+ def forward(ctx, inputs, targets, alpha=0.25, gamma=2):
+ n_elements = inputs.numel()
+ input_shape = inputs.shape
+ inputs = inputs.view(-1).contiguous()
+ targets = targets.view(-1).contiguous()
+ loss = torch.zeros(
+ SigmoidFocalLossReducedTriton.REDUCE_SIZE,
+ device=inputs.device,
+ dtype=torch.float32,
+ )
+ grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+ sigmoid_focal_loss_fwd_kernel_reduce[grid](
+ inputs,
+ targets,
+ loss,
+ alpha,
+ gamma,
+ n_elements,
+ SigmoidFocalLossReducedTriton.BLOCK_SIZE,
+ SigmoidFocalLossReducedTriton.REDUCE_SIZE,
+ )
+ ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape))
+ ctx.alpha = alpha
+ ctx.gamma = gamma
+ return loss.sum()
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ inputs, targets = ctx.saved_tensors
+ alpha = ctx.alpha
+ gamma = ctx.gamma
+ n_elements = inputs.numel()
+ input_shape = inputs.shape
+ grad_inputs = torch.empty(
+ inputs.shape, dtype=grad_output.dtype, device=grad_output.device
+ )
+ inputs_ptr = inputs.view(-1).contiguous()
+ targets_ptr = targets.reshape(-1).contiguous()
+ assert grad_output.numel() == 1
+ grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+ sigmoid_focal_loss_bwd_kernel_reduce[grid](
+ inputs_ptr,
+ targets_ptr,
+ grad_inputs,
+ grad_output,
+ alpha,
+ gamma,
+ n_elements,
+ SigmoidFocalLossReducedTriton.BLOCK_SIZE,
+ )
+ return grad_inputs.view(input_shape), None, None, None
+
+
+# ============================================================================
+# Public API - automatically selects best implementation
+# ============================================================================
+
+
+def sigmoid_focal_loss(
+ inputs: torch.Tensor,
+ targets: torch.Tensor,
+ alpha: float = 0.25,
+ gamma: float = 2.0,
+) -> torch.Tensor:
+ """
+ Sigmoid focal loss without reduction.
+
+ Uses Triton kernel on CUDA when available, falls back to PyTorch otherwise.
+
+ Args:
+ inputs: Tensor of any shape, containing logits
+ targets: Tensor of the same shape as inputs, containing float targets
+ alpha: Weighting factor in range (0,1) to balance positive vs negative examples
+ gamma: Exponent of the modulating factor (1 - p_t) ** gamma
+
+ Returns:
+ Tensor of the same shape as inputs, containing the focal loss for each element
+ """
+ if HAS_TRITON and inputs.is_cuda:
+ return SigmoidFocalLossTriton.apply(inputs, targets, alpha, gamma)
+ else:
+ return sigmoid_focal_loss_pytorch(inputs, targets, alpha, gamma)
+
+
+def sigmoid_focal_loss_reduce(
+ inputs: torch.Tensor,
+ targets: torch.Tensor,
+ alpha: float = 0.25,
+ gamma: float = 2.0,
+) -> torch.Tensor:
+ """
+ Sigmoid focal loss with sum reduction.
+
+ Uses Triton kernel on CUDA when available, falls back to PyTorch otherwise.
+
+ Args:
+ inputs: Tensor of any shape, containing logits
+ targets: Tensor of the same shape as inputs, containing float targets
+ alpha: Weighting factor in range (0,1) to balance positive vs negative examples
+ gamma: Exponent of the modulating factor (1 - p_t) ** gamma
+
+ Returns:
+ Scalar tensor containing the sum of focal losses
+ """
+ if HAS_TRITON and inputs.is_cuda:
+ return SigmoidFocalLossReducedTriton.apply(inputs, targets, alpha, gamma)
+ else:
+ return sigmoid_focal_loss_reduced_pytorch(inputs, targets, alpha, gamma)
+
+
+# Legacy aliases for backward compatibility
+triton_sigmoid_focal_loss = sigmoid_focal_loss
+triton_sigmoid_focal_loss_reduce = sigmoid_focal_loss_reduce
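The new wrappers keep a single public API: Triton kernels are used only for CUDA tensors, everything else takes the pure PyTorch path, and the old `triton_*` names keep working. Usage sketch:

```python
# Hedged sketch: one focal-loss API across CUDA (Triton), MPS, and CPU.
import torch
from sam3.train.loss.sigmoid_focal_loss import (
    sigmoid_focal_loss,
    sigmoid_focal_loss_reduce,
)

inputs = torch.randn(4, 3, requires_grad=True)  # logits
targets = torch.rand(4, 3)                      # targets in [0, 1]

per_element = sigmoid_focal_loss(inputs, targets, alpha=0.25, gamma=2.0)
total = sigmoid_focal_loss_reduce(inputs, targets)

total.backward()
print(per_element.shape, total.item(), inputs.grad.shape)
```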
diff --git a/sam3/train/utils/distributed.py b/sam3/train/utils/distributed.py
index 3c87a911..de41d724 100644
--- a/sam3/train/utils/distributed.py
+++ b/sam3/train/utils/distributed.py
@@ -190,6 +190,9 @@ def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, s
For some backends, such as NCCL, communication only works if the
tensor is on the GPU. This helper function converts to the correct
device and returns the tensor + original device.
+
+ Note: the NCCL backend only works with CUDA tensors. For MPS or CPU
+ distributed training, use a different backend such as 'gloo'.
"""
orig_device = "cpu" if not tensor.is_cuda else "gpu"
if (
@@ -197,6 +200,7 @@ def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, s
and torch.distributed.get_backend() == torch.distributed.Backend.NCCL
and not tensor.is_cuda
):
+ # NCCL requires CUDA tensors
tensor = tensor.cuda()
return (tensor, orig_device)
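Because NCCL can only move CUDA tensors, distributed runs on MPS or CPU should initialize the process group with Gloo. A hedged sketch of that backend choice (not part of this patch; assumes the usual MASTER_ADDR/MASTER_PORT environment variables are set):

```python
# Hedged sketch: choose a torch.distributed backend that matches the hardware.
import torch
import torch.distributed as dist

def init_distributed(rank: int, world_size: int) -> None:
    # NCCL requires CUDA tensors; Gloo covers CPU (and is the usual choice when
    # CUDA is unavailable, e.g. on Apple Silicon machines).
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
```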
diff --git a/sam3/utils/__init__.py b/sam3/utils/__init__.py
new file mode 100644
index 00000000..0136676b
--- /dev/null
+++ b/sam3/utils/__init__.py
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+from sam3.utils.device import (
+ get_device,
+ get_device_str,
+ is_cuda_available,
+ is_gpu_available,
+ is_mps_available,
+ move_model_to_device,
+ setup_device_optimizations,
+ tensor_is_on_cuda,
+ tensor_is_on_gpu,
+ tensor_is_on_mps,
+ to_device,
+)
+
+__all__ = [
+ "get_device",
+ "get_device_str",
+ "is_cuda_available",
+ "is_mps_available",
+ "is_gpu_available",
+ "to_device",
+ "setup_device_optimizations",
+ "tensor_is_on_gpu",
+ "tensor_is_on_cuda",
+ "tensor_is_on_mps",
+ "move_model_to_device",
+]
diff --git a/sam3/utils/device.py b/sam3/utils/device.py
new file mode 100644
index 00000000..fb413394
--- /dev/null
+++ b/sam3/utils/device.py
@@ -0,0 +1,171 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+"""
+Device utilities for supporting CUDA, MPS (Apple Silicon), and CPU backends.
+"""
+
+import logging
+from functools import lru_cache
+from typing import Optional, Union
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+@lru_cache(maxsize=1)
+def get_device() -> torch.device:
+ """
+ Get the best available device for computation.
+
+ Priority: CUDA > MPS > CPU
+
+ Returns:
+ torch.device: The best available device
+ """
+ if torch.cuda.is_available():
+ return torch.device("cuda")
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+ return torch.device("mps")
+ else:
+ return torch.device("cpu")
+
+
+def get_device_str() -> str:
+ """
+ Get the best available device as a string.
+
+ Returns:
+ str: Device string ("cuda", "mps", or "cpu")
+ """
+ return str(get_device())
+
+
+def is_cuda_available() -> bool:
+ """Check if CUDA is available."""
+ return torch.cuda.is_available()
+
+
+def is_mps_available() -> bool:
+ """Check if MPS (Apple Silicon GPU) is available."""
+ return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+
+
+def is_gpu_available() -> bool:
+ """Check if any GPU (CUDA or MPS) is available."""
+ return is_cuda_available() or is_mps_available()
+
+
+def to_device(
+ tensor: torch.Tensor,
+ device: Optional[Union[str, torch.device]] = None,
+ non_blocking: bool = False,
+) -> torch.Tensor:
+ """
+ Move tensor to the specified device, or to the best available device if not specified.
+
+ Args:
+ tensor: The tensor to move
+ device: Target device. If None, uses get_device()
+ non_blocking: Whether to perform the transfer asynchronously
+
+ Returns:
+ torch.Tensor: Tensor on the target device
+ """
+ if device is None:
+ device = get_device()
+ return tensor.to(device=device, non_blocking=non_blocking)
+
+
+def setup_device_optimizations() -> None:
+ """
+ Setup device-specific optimizations.
+
+ - For CUDA Ampere+ GPUs: Enable TensorFloat-32
+ - For MPS: Lift the per-process memory cap to reduce memory pressure
+ - For CPU: Currently no special optimizations
+ """
+ if torch.cuda.is_available():
+ try:
+ device_props = torch.cuda.get_device_properties(0)
+ if device_props.major >= 8:
+ # Enable TF32 for Ampere GPUs (compute capability >= 8.0)
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ logger.debug("Enabled TensorFloat-32 for Ampere GPU")
+ except Exception as e:
+ logger.debug(f"Could not set up CUDA optimizations: {e}")
+ elif is_mps_available():
+ # MPS optimizations for Apple Silicon
+ try:
+ # Lift the per-process memory cap (0.0 = no limit) so the MPS allocator
+ # can use more unified memory and reduce memory pressure
+ torch.mps.set_per_process_memory_fraction(0.0) # No limit
+ logger.debug("Using MPS (Apple Silicon GPU) with optimizations")
+ except Exception as e:
+ logger.debug(f"MPS optimization setup: {e}")
+ else:
+ logger.debug("Using CPU")
+
+
+def mps_synchronize() -> None:
+ """
+ Synchronize MPS operations.
+
+ Call this when you need to ensure all MPS operations are complete,
+ such as before timing or when switching between GPU and CPU operations.
+ """
+ if is_mps_available():
+ torch.mps.synchronize()
+
+
+def empty_cache() -> None:
+ """
+ Empty the GPU cache to free memory.
+
+ Works for both CUDA and MPS backends.
+ """
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ elif is_mps_available():
+ torch.mps.empty_cache()
+
+
+def get_device_for_tensor(tensor: torch.Tensor) -> torch.device:
+ """Get the device of a tensor."""
+ return tensor.device
+
+
+def tensor_is_on_gpu(tensor: torch.Tensor) -> bool:
+ """Check if tensor is on a GPU (CUDA or MPS)."""
+ device_type = tensor.device.type
+ return device_type in ("cuda", "mps")
+
+
+def tensor_is_on_cuda(tensor: torch.Tensor) -> bool:
+ """Check if tensor is specifically on CUDA."""
+ return tensor.device.type == "cuda"
+
+
+def tensor_is_on_mps(tensor: torch.Tensor) -> bool:
+ """Check if tensor is specifically on MPS."""
+ return tensor.device.type == "mps"
+
+
+def move_model_to_device(
+ model: torch.nn.Module,
+ device: Optional[Union[str, torch.device]] = None,
+) -> torch.nn.Module:
+ """
+ Move a model to the specified device.
+
+ Args:
+ model: The model to move
+ device: Target device. If None, uses get_device()
+
+ Returns:
+ The model on the target device
+ """
+ if device is None:
+ device = get_device()
+ return model.to(device)
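These helpers are the single entry point the rest of the patch leans on. A short end-to-end usage sketch (the `nn.Linear` is a stand-in model, purely for illustration):

```python
# Hedged sketch: typical downstream use of sam3.utils.device.
import torch
from sam3.utils.device import (
    empty_cache,
    get_device,
    move_model_to_device,
    setup_device_optimizations,
    to_device,
)

setup_device_optimizations()   # TF32 on Ampere CUDA, memory tweaks on MPS, no-op on CPU

device = get_device()                                   # cuda > mps > cpu
model = move_model_to_device(torch.nn.Linear(16, 4))    # toy model on that device
batch = to_device(torch.randn(8, 16))

with torch.no_grad():
    out = model(batch)
print(device, out.device)

empty_cache()                  # releases cached memory on CUDA or MPS, no-op on CPU
```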
diff --git a/tests/test_device_support.py b/tests/test_device_support.py
new file mode 100644
index 00000000..f0246de5
--- /dev/null
+++ b/tests/test_device_support.py
@@ -0,0 +1,324 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+"""
+Tests for CPU and MPS (Apple Silicon) device support.
+
+Run with: pytest tests/test_device_support.py -v
+"""
+
+import pytest
+import torch
+
+
+class TestDeviceUtilities:
+ """Test the device utility module."""
+
+ def test_device_module_imports(self):
+ """Test that device utilities can be imported."""
+ from sam3.utils.device import (
+ get_device,
+ get_device_str,
+ is_cuda_available,
+ is_gpu_available,
+ is_mps_available,
+ setup_device_optimizations,
+ tensor_is_on_cuda,
+ tensor_is_on_gpu,
+ tensor_is_on_mps,
+ to_device,
+ )
+
+ # All functions should be callable
+ assert callable(get_device)
+ assert callable(get_device_str)
+ assert callable(is_cuda_available)
+ assert callable(is_mps_available)
+ assert callable(is_gpu_available)
+ assert callable(to_device)
+ assert callable(setup_device_optimizations)
+
+ def test_get_device_returns_valid_device(self):
+ """Test that get_device returns a valid torch.device."""
+ from sam3.utils.device import get_device
+
+ device = get_device()
+ assert isinstance(device, torch.device)
+ assert device.type in ("cuda", "mps", "cpu")
+
+ def test_get_device_str_returns_string(self):
+ """Test that get_device_str returns a string."""
+ from sam3.utils.device import get_device_str
+
+ device_str = get_device_str()
+ assert isinstance(device_str, str)
+ assert device_str in ("cuda", "mps", "cpu")
+
+ def test_device_detection_consistency(self):
+ """Test that device detection functions are consistent."""
+ from sam3.utils.device import (
+ get_device,
+ is_cuda_available,
+ is_gpu_available,
+ is_mps_available,
+ )
+
+ device = get_device()
+
+ # If CUDA is available, device should be CUDA
+ if is_cuda_available():
+ assert device.type == "cuda"
+ assert is_gpu_available()
+ # If MPS is available and CUDA is not, device should be MPS
+ elif is_mps_available():
+ assert device.type == "mps"
+ assert is_gpu_available()
+ # Otherwise, device should be CPU
+ else:
+ assert device.type == "cpu"
+
+ def test_to_device_moves_tensor(self):
+ """Test that to_device correctly moves tensors."""
+ from sam3.utils.device import get_device, to_device
+
+ tensor = torch.randn(3, 3)
+ moved_tensor = to_device(tensor)
+
+ expected_device = get_device()
+ assert moved_tensor.device.type == expected_device.type
+
+ def test_tensor_device_checks(self):
+ """Test tensor device check functions."""
+ from sam3.utils.device import (
+ tensor_is_on_cuda,
+ tensor_is_on_gpu,
+ tensor_is_on_mps,
+ )
+
+ cpu_tensor = torch.randn(3, 3, device="cpu")
+ assert not tensor_is_on_cuda(cpu_tensor)
+ assert not tensor_is_on_mps(cpu_tensor)
+ assert not tensor_is_on_gpu(cpu_tensor)
+
+
+class TestCPUSupport:
+ """Test that operations work on CPU."""
+
+ def test_sigmoid_focal_loss_cpu(self):
+ """Test sigmoid focal loss works on CPU."""
+ from sam3.train.loss.sigmoid_focal_loss import (
+ sigmoid_focal_loss,
+ sigmoid_focal_loss_reduce,
+ )
+
+ inputs = torch.randn(10, 5, device="cpu", requires_grad=True)
+ targets = torch.rand(10, 5, device="cpu")
+
+ # Test unreduced version
+ loss = sigmoid_focal_loss(inputs, targets)
+ assert loss.device.type == "cpu"
+ assert loss.shape == inputs.shape
+
+ # Test reduced version
+ loss_reduced = sigmoid_focal_loss_reduce(inputs, targets)
+ assert loss_reduced.device.type == "cpu"
+ assert loss_reduced.dim() == 0 # scalar
+
+ # Test backward pass
+ loss_reduced.backward()
+ assert inputs.grad is not None
+ assert inputs.grad.shape == inputs.shape
+
+ def test_edt_cpu(self):
+ """Test EDT (Euclidean Distance Transform) works on CPU."""
+ from sam3.model.edt import edt
+
+ # Create a batch of binary masks
+ data = torch.zeros(2, 64, 64, device="cpu")
+ data[:, 20:40, 20:40] = 1 # Square in the middle
+
+ result = edt(data)
+ assert result.device.type == "cpu"
+ assert result.shape == data.shape
+ # Background (zero) pixels should have zero distance
+ assert (result[data == 0] == 0).all()
+
+ def test_nms_cpu(self):
+ """Test NMS works on CPU."""
+ from sam3.perflib.nms import generic_nms
+
+ n = 10
+ # Create a symmetric IoU matrix
+ ious = torch.rand(n, n, device="cpu")
+ ious = (ious + ious.T) / 2 # Make symmetric
+ ious.fill_diagonal_(1.0) # Diagonal should be 1
+
+ scores = torch.rand(n, device="cpu")
+
+ kept = generic_nms(ious, scores, iou_threshold=0.5)
+ assert kept.device.type == "cpu"
+ assert kept.dim() == 1
+ assert len(kept) <= n
+
+ def test_connected_components_cpu(self):
+ """Test connected components works on CPU."""
+ from sam3.perflib.connected_components import connected_components
+
+ # Create a batch of binary masks with distinct components
+ data = torch.zeros(2, 1, 64, 64, device="cpu", dtype=torch.uint8)
+ data[0, 0, 10:20, 10:20] = 1 # Component 1
+ data[0, 0, 40:50, 40:50] = 1 # Component 2
+ data[1, 0, 5:15, 5:15] = 1 # Component in second batch
+
+ labels, counts = connected_components(data)
+ assert labels.device.type == "cpu"
+ assert counts.device.type == "cpu"
+ assert labels.shape == data.shape
+ assert counts.shape == data.shape
+
+
+@pytest.mark.skipif(
+ not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()),
+ reason="MPS not available",
+)
+class TestMPSSupport:
+ """Test that operations work on MPS (Apple Silicon)."""
+
+ def test_sigmoid_focal_loss_mps(self):
+ """Test sigmoid focal loss works on MPS."""
+ from sam3.train.loss.sigmoid_focal_loss import (
+ sigmoid_focal_loss,
+ sigmoid_focal_loss_reduce,
+ )
+
+ inputs = torch.randn(10, 5, device="mps", requires_grad=True)
+ targets = torch.rand(10, 5, device="mps")
+
+ # Test unreduced version
+ loss = sigmoid_focal_loss(inputs, targets)
+ assert loss.device.type == "mps"
+ assert loss.shape == inputs.shape
+
+ # Test reduced version
+ loss_reduced = sigmoid_focal_loss_reduce(inputs, targets)
+ assert loss_reduced.device.type == "mps"
+
+ def test_edt_mps(self):
+ """Test EDT works on MPS (falls back to CPU internally)."""
+ from sam3.model.edt import edt
+
+ # Create a batch of binary masks on MPS
+ data = torch.zeros(2, 64, 64, device="mps")
+ data[:, 20:40, 20:40] = 1
+
+ result = edt(data)
+ # Result should be on MPS (moved back after CPU computation)
+ assert result.device.type == "mps"
+ assert result.shape == data.shape
+
+ def test_nms_mps(self):
+ """Test NMS works on MPS (falls back to CPU internally)."""
+ from sam3.perflib.nms import generic_nms
+
+ n = 10
+ ious = torch.rand(n, n, device="mps")
+ ious = (ious + ious.T) / 2
+ ious.fill_diagonal_(1.0)
+ scores = torch.rand(n, device="mps")
+
+ kept = generic_nms(ious, scores, iou_threshold=0.5)
+ # Result should be on MPS
+ assert kept.device.type == "mps"
+
+ def test_connected_components_mps(self):
+ """Test connected components works on MPS."""
+ from sam3.perflib.connected_components import connected_components
+
+ data = torch.zeros(2, 1, 64, 64, device="mps", dtype=torch.uint8)
+ data[0, 0, 10:20, 10:20] = 1
+ data[0, 0, 40:50, 40:50] = 1
+
+ labels, counts = connected_components(data)
+ # Results should be on MPS
+ assert labels.device.type == "mps"
+ assert counts.device.type == "mps"
+
+ def test_device_detection_mps(self):
+ """Test that MPS is detected when available."""
+ from sam3.utils.device import get_device, is_gpu_available, is_mps_available
+
+ assert is_mps_available()
+ assert is_gpu_available()
+ # If CUDA is not available, MPS should be the default
+ if not torch.cuda.is_available():
+ assert get_device().type == "mps"
+
+
+class TestModelBuilderDeviceSupport:
+ """Test model builder device handling."""
+
+ def test_device_parameter_accepted(self):
+ """Test that build functions accept device parameter."""
+ from sam3.model_builder import build_sam3_image_model, build_sam3_video_model
+ import inspect
+
+ # Check that device parameter exists
+ image_sig = inspect.signature(build_sam3_image_model)
+ video_sig = inspect.signature(build_sam3_video_model)
+
+ assert "device" in image_sig.parameters
+ assert "device" in video_sig.parameters
+
+ # Check defaults are None (auto-detect)
+ assert image_sig.parameters["device"].default is None
+ assert video_sig.parameters["device"].default is None
+
+
+class TestTransformerDeviceSupport:
+ """Test transformer module device handling."""
+
+ def test_rope_attention_cpu(self):
+ """Test RoPEAttention works on CPU."""
+ from sam3.sam.transformer import RoPEAttention
+
+ attention = RoPEAttention(
+ embedding_dim=256,
+ num_heads=8,
+ downsample_rate=1,
+ feat_sizes=(8, 8),
+ )
+ attention = attention.to("cpu")
+
+ # Create dummy inputs
+ batch_size = 2
+ seq_len = 64
+ q = torch.randn(batch_size, seq_len, 256, device="cpu")
+ k = torch.randn(batch_size, seq_len, 256, device="cpu")
+ v = torch.randn(batch_size, seq_len, 256, device="cpu")
+
+ output = attention(q, k, v)
+ assert output.device.type == "cpu"
+ assert output.shape == (batch_size, seq_len, 256)
+
+ def test_attention_cpu(self):
+ """Test base Attention works on CPU."""
+ from sam3.sam.transformer import Attention
+
+ attention = Attention(
+ embedding_dim=256,
+ num_heads=8,
+ )
+ attention = attention.to("cpu")
+
+ batch_size = 2
+ seq_len = 64
+ q = torch.randn(batch_size, seq_len, 256, device="cpu")
+ k = torch.randn(batch_size, seq_len, 256, device="cpu")
+ v = torch.randn(batch_size, seq_len, 256, device="cpu")
+
+ output = attention(q, k, v)
+ assert output.device.type == "cpu"
+ assert output.shape == (batch_size, seq_len, 256)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])