diff --git a/README.md b/README.md
index 669242df..d3493320 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# SAM 3: Segment Anything with Concepts
+# SAM 3: Segment Anything with Concepts (with MPS/CPU support for Apple Metal)
 
 Meta Superintelligence Labs
 
diff --git a/examples/live_camera_segmentation.py b/examples/live_camera_segmentation.py
new file mode 100644
index 00000000..283149fc
--- /dev/null
+++ b/examples/live_camera_segmentation.py
@@ -0,0 +1,970 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
+
+"""
+Live Camera Segmentation with SAM3
+
+This script captures video from a device camera and runs real-time segmentation
+using SAM3. It supports text-based detection or interactive point/box prompts.
+
+Usage:
+    # Detect objects using a text prompt
+    python live_camera_segmentation.py --prompt "person"
+
+    # Detect multiple object types using comma-separated prompts
+    python live_camera_segmentation.py --prompt "person, car, dog, cat"
+
+    # Use a specific camera device
+    python live_camera_segmentation.py --camera 0 --prompt "cat"
+
+    # Specify the device (cuda, mps, or cpu)
+    python live_camera_segmentation.py --device mps --prompt "dog"
+
+    # Interactive mode - click and drag to add box prompts
+    python live_camera_segmentation.py --interactive
+
+    # Skip frames with tracking (masks follow objects between full inference frames)
+    python live_camera_segmentation.py --prompt "person" --skip-frames 5 --track
+
+Controls:
+    - 'q' or ESC: Quit
+    - 'r': Reset/clear all segments
+    - 's': Save current frame
+    - 'p': Pause/resume
+    - Left click + drag: Draw box prompt (in interactive mode)
+    - 't': Enter new text prompt
+"""
+
+import argparse
+import time
+from collections import deque
+from typing import Optional, Tuple
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+
+from sam3.utils.device import get_device, get_device_str, setup_device_optimizations
+
+
+class LiveCameraSegmenter:
+    """Real-time camera segmentation using SAM3."""
+
+    # Color palette for different object masks (BGR format for OpenCV)
+    COLORS = [
+        (255, 0, 0),    # Blue
+        (0, 255, 0),    # Green
+        (0, 0, 255),    # Red
+        (255, 255, 0),  # Cyan
+        (255, 0, 255),  # Magenta
+        (0, 255, 255),  # Yellow
+        (128, 0, 255),  # Purple
+        (255, 128, 0),  # Orange
+        (0, 128, 255),  # Light blue
+        (128, 255, 0),  # Lime
+    ]
+
+    def __init__(
+        self,
+        camera_id: int = 0,
+        device: Optional[str] = None,
+        text_prompt: str = "object",
+        confidence_threshold: float = 0.3,
+        checkpoint_path: Optional[str] = None,
+        interactive: bool = False,
+        process_every_n_frames: int = 1,
+        use_half_precision: bool = False,
+        enable_tracking: bool = False,
+    ):
+        """
+        Initialize the live camera segmenter.
+ + Args: + camera_id: Camera device ID (default 0 for primary camera) + device: Device to run on ('cuda', 'mps', 'cpu', or None for auto) + text_prompt: Text description of objects to detect + confidence_threshold: Confidence threshold for detections + checkpoint_path: Optional path to model checkpoint + interactive: Enable interactive box-based prompting + process_every_n_frames: Only process every N frames (higher = faster but less smooth) + use_half_precision: Use float16 for faster inference (may reduce accuracy) + enable_tracking: Enable mask tracking between skipped frames + """ + self.camera_id = camera_id + self.device_str = device if device else get_device_str() + self.device = torch.device(self.device_str) + self.text_prompt = text_prompt + self.confidence_threshold = confidence_threshold + self.interactive = interactive + self.process_every_n_frames = process_every_n_frames + self.use_half_precision = use_half_precision + self.enable_tracking = enable_tracking + self.frame_count = 0 + + # State + self.paused = False + self.state = None + self.fps_history = deque(maxlen=30) + + # Tracking state + self.tracker = None + self.tracker_state = None + self.last_masks = None + self.last_boxes = None + self.last_scores = None # Store confidence scores + self.last_labels = None # Store per-object labels for multi-prompt mode + self.video_height = None + self.video_width = None + + # For interactive box drawing + self.drawing = False + self.box_start = None + self.box_end = None + + print(f"Initializing SAM3 on device: {self.device}") + self._load_model(checkpoint_path) + + def _load_model(self, checkpoint_path: Optional[str] = None): + """Load the SAM3 model and processor.""" + from sam3.model_builder import build_sam3_image_model + from sam3.model.sam3_image_processor import Sam3Processor + + # Setup device-specific optimizations (MPS memory, CUDA TF32, etc.) 
+ setup_device_optimizations() + + print("Loading SAM3 model...") + model = build_sam3_image_model( + device=self.device_str, + checkpoint_path=checkpoint_path, + load_from_HF=checkpoint_path is None, + eval_mode=True, + enable_segmentation=True, + ) + + # Convert to half precision for faster inference (CUDA only - MPS doesn't support it) + if self.use_half_precision: + if self.device_str == "mps": + print("Warning: Half precision not supported on MPS due to Metal limitations, using float32") + self.use_half_precision = False + else: + print("Converting model to half precision (float16)...") + model = model.half() + + self.processor = Sam3Processor( + model=model, + resolution=1008, # Fixed resolution due to precomputed positional encodings + device=self.device_str, + confidence_threshold=self.confidence_threshold, + ) + print("Model loaded successfully!") + + # For tracking between keyframes, we use optical flow instead of the full SAM3 tracker + # This provides lightweight motion-based tracking without device compatibility issues + if self.enable_tracking: + print("Tracking mode enabled - using optical flow for inter-frame tracking") + self.prev_gray = None # Store previous frame for optical flow + + def _load_tracker(self, checkpoint_path: Optional[str] = None): + """Load the SAM3 tracker for mask propagation between frames.""" + from sam3.model_builder import build_tracker + + print("Loading SAM3 tracker for inter-frame tracking...") + + # Build tracker with backbone for processing new frames + self.tracker = build_tracker( + apply_temporal_disambiguation=True, + with_backbone=True, + ) + self.tracker = self.tracker.to(self.device) + self.tracker.eval() + + # Try to load tracker weights from the same source as the main model + # The tracker shares weights with the main SAM3 model + import os + tracker_ckpt_path = None + + # Use provided checkpoint path first + if checkpoint_path and os.path.exists(checkpoint_path): + tracker_ckpt_path = checkpoint_path + else: + # Check common locations for the checkpoint + # Get the directory where this script is located + script_dir = os.path.dirname(os.path.abspath(__file__)) + possible_paths = [ + os.path.join(script_dir, "sam3.pt"), # Same folder as script (examples/) + "sam3.pt", + "./sam3.pt", + "../sam3.pt", + "examples/sam3.pt", + os.path.expanduser("~/.cache/huggingface/hub/models--facebook--sam3/sam3.pt"), + ] + + for path in possible_paths: + if os.path.exists(path): + tracker_ckpt_path = path + break + + if tracker_ckpt_path is None: + print("Warning: Could not find sam3.pt checkpoint for tracker.") + print("Please ensure sam3.pt is in the current directory or provide --checkpoint path.") + print("Tracking will be disabled.") + self.tracker = None + return + + print(f"Loading tracker weights from: {tracker_ckpt_path}") + tracker_state_dict = torch.load(tracker_ckpt_path, map_location=self.device, weights_only=False) + + # Filter and load tracker-compatible weights + tracker_keys = set(k for k in self.tracker.state_dict().keys()) + filtered_state_dict = {k: v for k, v in tracker_state_dict.items() if k in tracker_keys} + self.tracker.load_state_dict(filtered_state_dict, strict=False) + + print("Tracker loaded successfully!") + + def _init_tracker_state(self, height: int, width: int): + """Initialize tracking state for a video stream.""" + self.video_height = height + self.video_width = width + # Reset masks and optical flow state + self.last_masks = None + self.last_boxes = None + self.last_scores = None + self.last_labels = None + 
self.prev_gray = None + + def _track_frame(self, frame: np.ndarray, frame_idx: int) -> Optional[torch.Tensor]: + """ + Use optical flow to track masks to a new frame. + + This provides lightweight motion-based tracking between keyframes + without needing the full SAM3 tracker model. + + Returns the tracked masks or None if tracking isn't available. + """ + if self.last_masks is None or len(self.last_masks) == 0: + return None + + if self.prev_gray is None: + return self.last_masks + + try: + # Convert current frame to grayscale + curr_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + + # Calculate dense optical flow using Farneback method + flow = cv2.calcOpticalFlowFarneback( + self.prev_gray, curr_gray, + None, + pyr_scale=0.5, + levels=3, + winsize=15, + iterations=3, + poly_n=5, + poly_sigma=1.2, + flags=0 + ) + + # Create coordinate grids for remapping + h, w = curr_gray.shape + flow_map_x = np.arange(w).reshape(1, -1).repeat(h, axis=0).astype(np.float32) + flow_map_y = np.arange(h).reshape(-1, 1).repeat(w, axis=1).astype(np.float32) + + # Add flow to get new positions + flow_map_x += flow[:, :, 0] + flow_map_y += flow[:, :, 1] + + # Warp each mask using the flow + tracked_masks = [] + for mask in self.last_masks: + # Convert mask to numpy for warping + if isinstance(mask, torch.Tensor): + mask_np = mask.cpu().numpy().squeeze() + else: + mask_np = mask.squeeze() + + # Ensure mask is the right size + if mask_np.shape != (h, w): + mask_np = cv2.resize(mask_np.astype(np.float32), (w, h)) + + # Warp mask using optical flow + warped_mask = cv2.remap( + mask_np.astype(np.float32), + flow_map_x, flow_map_y, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=0 + ) + + # Threshold to get binary mask + warped_mask = (warped_mask > 0.5).astype(np.float32) + + # Convert back to tensor + tracked_masks.append( + torch.from_numpy(warped_mask).unsqueeze(0).to(self.device) + ) + + # Update prev_gray for next iteration + self.prev_gray = curr_gray + + if tracked_masks: + return torch.stack(tracked_masks) + + except Exception as e: + print(f"Optical flow tracking error: {e}") + + return self.last_masks + + def _add_mask_to_tracker(self, masks: torch.Tensor, frame: np.ndarray, frame_idx: int): + """Store frame for optical flow tracking.""" + # Store grayscale frame for optical flow computation + self.prev_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + # Masks are already stored in self.last_masks by the caller + + def _process_frame(self, frame: np.ndarray) -> dict: + """Process a frame through SAM3.""" + # Convert BGR to RGB PIL Image + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_image = Image.fromarray(frame_rgb) + + # Set the image + self.state = self.processor.set_image(pil_image, self.state) + + # Run text-based detection + if not self.interactive: + # Support multiple prompts separated by commas + prompts = [p.strip() for p in self.text_prompt.split(',')] + + if len(prompts) == 1: + # Single prompt - use normal detection + self.state = self.processor.set_text_prompt(prompts[0], self.state) + else: + # Multiple prompts - run detection for each and combine results + all_masks = [] + all_boxes = [] + all_scores = [] + all_labels = [] + + for prompt in prompts: + # Reset geometric prompt for each detection + if "geometric_prompt" in self.state: + del self.state["geometric_prompt"] + + self.state = self.processor.set_text_prompt(prompt, self.state) + + masks = self.state.get("masks") + boxes = self.state.get("boxes") + scores = self.state.get("scores") + + if 
masks is not None and masks.numel() > 0: + for i in range(len(masks)): + all_masks.append(masks[i:i+1]) + if boxes is not None and i < len(boxes): + all_boxes.append(boxes[i:i+1]) + if scores is not None and i < len(scores): + all_scores.append(scores[i:i+1]) + all_labels.append(prompt) + + # Combine all detections + if all_masks: + self.state["masks"] = torch.cat(all_masks, dim=0) + self.state["boxes"] = torch.cat(all_boxes, dim=0) if all_boxes else None + self.state["scores"] = torch.cat(all_scores, dim=0) if all_scores else None + self.state["labels"] = all_labels # Store labels for each detection + else: + self.state["masks"] = None + self.state["boxes"] = None + self.state["scores"] = None + self.state["labels"] = [] + + return self.state + + def _add_box_prompt(self, box: Tuple[int, int, int, int], frame_size: Tuple[int, int]): + """Add a box prompt in interactive mode.""" + if self.state is None: + return + + h, w = frame_size + x1, y1, x2, y2 = box + + # Convert to center format and normalize to [0, 1] + cx = (x1 + x2) / 2 / w + cy = (y1 + y2) / 2 / h + bw = abs(x2 - x1) / w + bh = abs(y2 - y1) / h + + normalized_box = [cx, cy, bw, bh] + self.state = self.processor.add_geometric_prompt( + box=normalized_box, + label=True, # Positive box + state=self.state, + ) + + def _overlay_masks( + self, + frame: np.ndarray, + masks: torch.Tensor, + boxes: torch.Tensor = None, + scores: torch.Tensor = None, + labels: list = None, + alpha: float = 0.5, + ) -> np.ndarray: + """Overlay segmentation masks on the frame with labels and confidence scores.""" + if masks is None or masks.numel() == 0: + return frame + + overlay = frame.copy() + h, w = frame.shape[:2] + + # masks shape: [N, 1, H, W] + masks_np = masks.squeeze(1).cpu().numpy() + + # Get scores if available + scores_np = None + if scores is not None: + scores_np = scores.cpu().numpy() + + # Get boxes if available + boxes_np = None + if boxes is not None: + boxes_np = boxes.cpu().numpy() + + for i, mask in enumerate(masks_np): + # Resize mask to frame size if needed + if mask.shape != (h, w): + mask = cv2.resize(mask.astype(np.float32), (w, h)) > 0.5 + + # Get color for this mask + color = self.COLORS[i % len(self.COLORS)] + + # Create colored overlay + mask_region = mask.astype(bool) + overlay[mask_region] = ( + overlay[mask_region] * (1 - alpha) + + np.array(color) * alpha + ).astype(np.uint8) + + # Draw contour + contours, _ = cv2.findContours( + mask.astype(np.uint8), + cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_SIMPLE + ) + cv2.drawContours(overlay, contours, -1, color, 2) + + # Draw label with confidence score + # Find the top-center of the mask for label placement + if len(contours) > 0: + # Get bounding rect of largest contour + largest_contour = max(contours, key=cv2.contourArea) + x, y, cw, ch = cv2.boundingRect(largest_contour) + + # Get confidence score + conf = scores_np[i] if scores_np is not None and i < len(scores_np) else 0.0 + + # Get label - use per-object label if available, otherwise use prompt + if labels is not None and i < len(labels): + obj_label = labels[i] + else: + obj_label = self.text_prompt.split(',')[0].strip() # Use first prompt as fallback + + label = f"{obj_label}" + conf_text = f"{conf:.0%}" + + # Draw label background + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 0.5 + thickness = 1 + + # Get text sizes + (label_w, label_h), _ = cv2.getTextSize(label, font, font_scale, thickness) + (conf_w, conf_h), _ = cv2.getTextSize(conf_text, font, font_scale, thickness) + + # Position at top of bounding box + label_x 
= x + cw // 2 - label_w // 2 + label_y = max(y - 5, label_h + 5) + + # Draw label background + cv2.rectangle(overlay, + (label_x - 2, label_y - label_h - 2), + (label_x + label_w + 2, label_y + 2), + color, -1) + + # Draw label text + cv2.putText(overlay, label, + (label_x, label_y), + font, font_scale, (255, 255, 255), thickness) + + # Draw confidence below label + conf_x = x + cw // 2 - conf_w // 2 + conf_y = label_y + conf_h + 8 + + cv2.rectangle(overlay, + (conf_x - 2, conf_y - conf_h - 2), + (conf_x + conf_w + 2, conf_y + 2), + (0, 0, 0), -1) + cv2.putText(overlay, conf_text, + (conf_x, conf_y), + font, font_scale, (0, 255, 0), thickness) + + return overlay + + def _draw_boxes(self, frame: np.ndarray, boxes: torch.Tensor, scores: torch.Tensor = None) -> np.ndarray: + """Draw bounding boxes on the frame with labels.""" + if boxes is None or boxes.numel() == 0: + return frame + + boxes_np = boxes.cpu().numpy() + scores_np = scores.cpu().numpy() if scores is not None else None + + for i, box in enumerate(boxes_np): + x1, y1, x2, y2 = box.astype(int) + color = self.COLORS[i % len(self.COLORS)] + cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2) + + return frame + + def _draw_object_panel(self, frame: np.ndarray, masks: torch.Tensor, + boxes: torch.Tensor, scores: torch.Tensor, + labels: list = None) -> np.ndarray: + """Draw an info panel on the right side showing detected objects.""" + h, w = frame.shape[:2] + + # Panel dimensions + panel_width = 200 + panel_x = w - panel_width - 10 + + # Count objects + num_objects = len(masks) if masks is not None else 0 + + # Calculate panel height based on number of objects + header_height = 40 + object_height = 50 + panel_height = header_height + max(num_objects, 1) * object_height + 20 + + # Draw semi-transparent panel background + overlay = frame.copy() + cv2.rectangle(overlay, + (panel_x, 10), + (w - 10, min(10 + panel_height, h - 10)), + (0, 0, 0), -1) + frame = cv2.addWeighted(overlay, 0.7, frame, 0.3, 0) + + # Draw panel header + font = cv2.FONT_HERSHEY_SIMPLEX + cv2.putText(frame, "DETECTED OBJECTS", + (panel_x + 10, 35), + font, 0.5, (255, 255, 255), 1) + cv2.line(frame, (panel_x + 5, 45), (w - 15, 45), (100, 100, 100), 1) + + if num_objects == 0: + cv2.putText(frame, "No objects found", + (panel_x + 10, 75), + font, 0.4, (150, 150, 150), 1) + return frame + + # Draw each object + masks_np = masks.squeeze(1).cpu().numpy() if masks is not None else [] + scores_np = scores.cpu().numpy() if scores is not None else [] + boxes_np = boxes.cpu().numpy() if boxes is not None else [] + + for i in range(num_objects): + y_offset = header_height + 15 + i * object_height + + if 10 + y_offset + 40 > h - 10: + # Panel would exceed frame height + cv2.putText(frame, f"... +{num_objects - i} more", + (panel_x + 10, 10 + y_offset), + font, 0.4, (150, 150, 150), 1) + break + + color = self.COLORS[i % len(self.COLORS)] + + # Color indicator + cv2.rectangle(frame, + (panel_x + 10, 10 + y_offset), + (panel_x + 25, 10 + y_offset + 15), + color, -1) + + # Object label - use per-object label if available + if labels is not None and i < len(labels): + obj_label = labels[i] + else: + obj_label = self.text_prompt.split(',')[0].strip() + + # Truncate label if too long + if len(obj_label) > 15: + obj_label = obj_label[:12] + "..." 
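+            # Render the (possibly truncated) label next to the color swatch in the panel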
+ + cv2.putText(frame, obj_label, + (panel_x + 35, 10 + y_offset + 12), + font, 0.4, (255, 255, 255), 1) + + # Confidence score + if i < len(scores_np): + conf = scores_np[i] + conf_color = (0, 255, 0) if conf > 0.7 else (0, 255, 255) if conf > 0.4 else (0, 0, 255) + cv2.putText(frame, f"Conf: {conf:.0%}", + (panel_x + 35, 10 + y_offset + 28), + font, 0.35, conf_color, 1) + + # Bounding box size + if i < len(boxes_np): + box = boxes_np[i] + bw = int(box[2] - box[0]) + bh = int(box[3] - box[1]) + cv2.putText(frame, f"Size: {bw}x{bh}", + (panel_x + 100, 10 + y_offset + 28), + font, 0.35, (150, 150, 150), 1) + + return frame + + def _draw_info(self, frame: np.ndarray, fps: float, num_objects: int) -> np.ndarray: + """Draw information overlay on the frame.""" + h, w = frame.shape[:2] + + # Semi-transparent background for text + overlay = frame.copy() + info_height = 165 if self.enable_tracking else 140 + cv2.rectangle(overlay, (10, 10), (350, info_height), (0, 0, 0), -1) + frame = cv2.addWeighted(overlay, 0.3, frame, 0.7, 0) + + # Draw text + font = cv2.FONT_HERSHEY_SIMPLEX + cv2.putText(frame, f"FPS: {fps:.1f}", (20, 35), font, 0.6, (255, 255, 255), 2) + cv2.putText(frame, f"Objects: {num_objects}", (20, 60), font, 0.6, (255, 255, 255), 2) + cv2.putText(frame, f"Device: {self.device_str}", (20, 85), font, 0.6, (255, 255, 255), 2) + + mode = "Interactive" if self.interactive else f"Prompt: {self.text_prompt}" + cv2.putText(frame, f"Mode: {mode}", (20, 110), font, 0.6, (255, 255, 255), 2) + cv2.putText(frame, f"Threshold: {self.confidence_threshold:.2f}", (20, 135), font, 0.6, (255, 255, 255), 2) + + if self.enable_tracking: + skip_info = f"Skip: {self.process_every_n_frames} (tracking ON)" + cv2.putText(frame, skip_info, (20, 160), font, 0.6, (0, 255, 0), 2) + + # Draw controls hint at bottom + hint = "Q: Quit | R: Reset | S: Save | P: Pause | T: New prompt" + cv2.putText(frame, hint, (10, h - 10), font, 0.4, (200, 200, 200), 1) + + return frame + + def _draw_current_box(self, frame: np.ndarray) -> np.ndarray: + """Draw the box currently being drawn.""" + if self.drawing and self.box_start and self.box_end: + cv2.rectangle( + frame, + self.box_start, + self.box_end, + (0, 255, 0), + 2 + ) + return frame + + def _mouse_callback(self, event, x, y, flags, param): + """Handle mouse events for interactive mode.""" + if not self.interactive: + return + + if event == cv2.EVENT_LBUTTONDOWN: + self.drawing = True + self.box_start = (x, y) + self.box_end = (x, y) + + elif event == cv2.EVENT_MOUSEMOVE: + if self.drawing: + self.box_end = (x, y) + + elif event == cv2.EVENT_LBUTTONUP: + if self.drawing: + self.drawing = False + self.box_end = (x, y) + + # Add the box prompt if it's a valid box + x1, y1 = self.box_start + x2, y2 = self.box_end + if abs(x2 - x1) > 5 and abs(y2 - y1) > 5: + frame_size = param # Passed as param + self._add_box_prompt((x1, y1, x2, y2), frame_size) + + self.box_start = None + self.box_end = None + + def run(self): + """Run the live camera segmentation loop.""" + # Open camera + print(f"Opening camera {self.camera_id}...") + cap = cv2.VideoCapture(self.camera_id) + + if not cap.isOpened(): + print(f"Error: Could not open camera {self.camera_id}") + return + + # Get camera properties + frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + print(f"Camera resolution: {frame_width}x{frame_height}") + + # Initialize tracker state if tracking is enabled + if self.enable_tracking: + print("Initializing tracker state...") + 
self._init_tracker_state(frame_height, frame_width) + + # Create window + window_name = "SAM3 Live Segmentation" + cv2.namedWindow(window_name, cv2.WINDOW_NORMAL) + cv2.setMouseCallback(window_name, self._mouse_callback, (frame_height, frame_width)) + + print("\nStarting live segmentation...") + print("Controls:") + print(" Q/ESC: Quit") + print(" R: Reset segments") + print(" S: Save frame") + print(" P: Pause/resume") + print(" T: Enter new text prompt") + if self.interactive: + print(" Left click + drag: Draw box prompt") + + frame_count = 0 + try: + while True: + start_time = time.time() + + # Capture frame + ret, frame = cap.read() + if not ret: + print("Failed to grab frame") + break + + display_frame = frame.copy() + self.frame_count += 1 + + if not self.paused: + is_keyframe = self.frame_count % self.process_every_n_frames == 0 + + if is_keyframe: + # Full inference frame - run text detection + self._process_frame(frame) + + # Store masks, boxes, scores, and labels for tracking + if self.state is not None: + self.last_masks = self.state.get("masks") + self.last_boxes = self.state.get("boxes") + self.last_scores = self.state.get("scores") + self.last_labels = self.state.get("labels") + + # Add masks to tracker for memory-based propagation + if self.enable_tracking and self.last_masks is not None: + self._add_mask_to_tracker(self.last_masks, frame, self.frame_count) + + elif self.enable_tracking and self.last_masks is not None: + # Intermediate frame - use tracker to propagate masks + tracked_masks = self._track_frame(frame, self.frame_count) + if tracked_masks is not None: + self.last_masks = tracked_masks + # Update state with tracked masks + if self.state is not None: + self.state["masks"] = tracked_masks + # else: Just reuse last masks (no tracking) + + # Overlay results - use last_masks if tracking is enabled + masks_to_display = None + boxes_to_display = None + scores_to_display = None + labels_to_display = None + + if self.enable_tracking: + masks_to_display = self.last_masks + boxes_to_display = self.last_boxes + scores_to_display = self.last_scores + labels_to_display = self.last_labels + elif self.state is not None: + masks_to_display = self.state.get("masks") + boxes_to_display = self.state.get("boxes") + scores_to_display = self.state.get("scores") + labels_to_display = self.state.get("labels") + + if masks_to_display is not None: + display_frame = self._overlay_masks( + display_frame, masks_to_display, + boxes=boxes_to_display, scores=scores_to_display, + labels=labels_to_display + ) + if boxes_to_display is not None: + display_frame = self._draw_boxes(display_frame, boxes_to_display, scores_to_display) + + # Draw object info panel on the right + display_frame = self._draw_object_panel( + display_frame, masks_to_display, boxes_to_display, scores_to_display, + labels=labels_to_display + ) + + # Draw current box being drawn + if self.interactive: + display_frame = self._draw_current_box(display_frame) + + # Calculate FPS + elapsed = time.time() - start_time + fps = 1.0 / elapsed if elapsed > 0 else 0 + self.fps_history.append(fps) + avg_fps = sum(self.fps_history) / len(self.fps_history) + + # Draw info overlay + num_objects = 0 + if masks_to_display is not None: + num_objects = len(masks_to_display) + display_frame = self._draw_info(display_frame, avg_fps, num_objects) + + # Show frame + cv2.imshow(window_name, display_frame) + + # Handle keyboard input + key = cv2.waitKey(1) & 0xFF + + if key == ord('q') or key == 27: # Q or ESC + print("Quitting...") + break + + elif 
key == ord('r'): # Reset + print("Resetting segments...") + if self.state is not None: + self.processor.reset_all_prompts(self.state) + self.state = None + self.last_masks = None + self.last_boxes = None + self.last_scores = None + self.last_labels = None + # Reset tracker state + if self.enable_tracking: + self._init_tracker_state(frame_height, frame_width) + + elif key == ord('s'): # Save + filename = f"sam3_capture_{frame_count}.png" + cv2.imwrite(filename, display_frame) + print(f"Saved frame to {filename}") + + elif key == ord('p'): # Pause + self.paused = not self.paused + print("Paused" if self.paused else "Resumed") + + elif key == ord('t'): # New text prompt + self.paused = True + new_prompt = input("Enter new text prompt: ").strip() + if new_prompt: + self.text_prompt = new_prompt + if self.state is not None: + self.processor.reset_all_prompts(self.state) + self.state = None + self.last_masks = None + self.last_boxes = None + self.last_scores = None + self.last_labels = None + # Reset tracker for new prompt + if self.enable_tracking: + self._init_tracker_state(frame_height, frame_width) + print(f"Text prompt set to: {self.text_prompt}") + self.paused = False + + frame_count += 1 + + except KeyboardInterrupt: + print("\nInterrupted by user") + + finally: + cap.release() + cv2.destroyAllWindows() + print("Cleanup complete") + + +def main(): + parser = argparse.ArgumentParser( + description="Live Camera Segmentation with SAM3", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "--camera", "-c", + type=int, + default=0, + help="Camera device ID (default: 0)", + ) + parser.add_argument( + "--device", "-d", + type=str, + default=None, + choices=["cuda", "mps", "cpu"], + help="Device to run on (default: auto-detect)", + ) + parser.add_argument( + "--prompt", + type=str, + default="object", + help="Text prompt for detection (default: 'object')", + ) + parser.add_argument( + "--threshold", + type=float, + default=0.3, + help="Detection confidence threshold (default: 0.3)", + ) + parser.add_argument( + "--checkpoint", + type=str, + default=None, + help="Path to model checkpoint (default: download from HuggingFace)", + ) + parser.add_argument( + "--interactive", "-i", + action="store_true", + help="Enable interactive box-based prompting", + ) + parser.add_argument( + "--skip-frames", + type=int, + default=1, + help="Process every N frames (higher = faster, default: 1)", + ) + parser.add_argument( + "--half", + action="store_true", + help="Use half precision (float16) for faster inference", + ) + parser.add_argument( + "--track", + action="store_true", + help="Enable mask tracking between skipped frames (smoother results when using --skip-frames)", + ) + + args = parser.parse_args() + + # Print device info + device = args.device or get_device_str() + print(f"SAM3 Live Camera Segmentation") + print(f"=" * 40) + print(f"Device: {device}") + print(f"Camera: {args.camera}") + print(f"Text prompt: {args.prompt}") + print(f"Threshold: {args.threshold}") + print(f"Interactive: {args.interactive}") + print(f"Skip frames: {args.skip_frames}") + print(f"Half precision: {args.half}") + print(f"Tracking: {args.track}") + print(f"=" * 40) + + # Create and run segmenter + segmenter = LiveCameraSegmenter( + camera_id=args.camera, + device=args.device, + text_prompt=args.prompt, + confidence_threshold=args.threshold, + checkpoint_path=args.checkpoint, + interactive=args.interactive, + process_every_n_frames=args.skip_frames, + 
use_half_precision=args.half, + enable_tracking=args.track, + ) + segmenter.run() + + +if __name__ == "__main__": + main() diff --git a/examples/web_command_center/app.py b/examples/web_command_center/app.py new file mode 100644 index 00000000..a2f17424 --- /dev/null +++ b/examples/web_command_center/app.py @@ -0,0 +1,5374 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +""" +SAM3 Web Command Center + +A Flask-based web interface for real-time object detection and tracking +using SAM3. Features include: +- Live camera feed with segmentation overlay +- Multi-prompt detection configuration +- Object count limits with show/hide functionality +- Claude Vision API integration for detailed object analysis +- Video tracking with memory (SAM3 tracker) +- Multi-object tracking with persistent IDs +- Mask refinement (fill holes, non-overlap) +- Advanced detection controls (boundary/occlusion suppression, hotstart) +- YOLO integration for classification and pose estimation +- Command center style interface with verbose logging + +Usage: + python app.py --prompt "person, car" --camera 0 + +Then open http://localhost:5000 in your browser. +""" + +import argparse +import base64 +import io +import ipaddress +import json +import os +import sqlite3 +import ssl +import sys +import threading +import time +import uuid +from collections import deque +from datetime import datetime +from typing import Optional, Dict, List, Any, Tuple + +import cv2 +import numpy as np +import torch +from PIL import Image +from flask import Flask, Response, render_template, request, jsonify +from scipy import ndimage + +# Load environment variables from .env file if present +try: + from dotenv import load_dotenv + # Look for .env in the web_command_center directory + env_path = os.path.join(os.path.dirname(__file__), '.env') + if os.path.exists(env_path): + load_dotenv(env_path) + print(f"Loaded environment from {env_path}") + else: + # Also check current working directory + load_dotenv() +except ImportError: + pass # python-dotenv not installed, rely on system environment + +# Add parent directory to path for sam3 imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) + +from sam3.utils.device import get_device, get_device_str, setup_device_optimizations, empty_cache + +app = Flask(__name__) + +# Global API key storage (can be set via CLI arg or environment) +ANTHROPIC_API_KEY = os.environ.get('ANTHROPIC_API_KEY') + + +# ===== SAM3 to COCO Label Mapping ===== +# Maps open-vocabulary SAM3 labels to COCO class indices for YOLO +SAM3_TO_COCO = { + # Person variations -> COCO class 0 + "person": 0, "human": 0, "man": 0, "woman": 0, "child": 0, "kid": 0, + "boy": 0, "girl": 0, "people": 0, "pedestrian": 0, "worker": 0, + "player": 0, "athlete": 0, "runner": 0, "cyclist": 0, + + # Vehicles + "bicycle": 1, "bike": 1, "cycle": 1, + "car": 2, "automobile": 2, "vehicle": 2, "sedan": 2, "suv": 2, + "motorcycle": 3, "motorbike": 3, "scooter": 3, + "airplane": 4, "plane": 4, "aircraft": 4, "jet": 4, + "bus": 5, "coach": 5, + "train": 6, "locomotive": 6, "railway": 6, + "truck": 7, "lorry": 7, "pickup": 7, + "boat": 8, "ship": 8, "vessel": 8, "yacht": 8, + + # Traffic + "traffic light": 9, "stoplight": 9, + "fire hydrant": 10, "hydrant": 10, + "stop sign": 11, + "parking meter": 12, + + # Animals + "bird": 14, "sparrow": 14, "pigeon": 14, "crow": 14, + "cat": 15, "kitten": 15, "feline": 15, "kitty": 15, + "dog": 16, "puppy": 16, "canine": 16, "hound": 16, + 
"horse": 17, "pony": 17, "stallion": 17, "mare": 17, + "sheep": 18, "lamb": 18, + "cow": 19, "cattle": 19, "bull": 19, + "elephant": 20, + "bear": 21, "grizzly": 21, + "zebra": 22, + "giraffe": 23, + + # Accessories + "backpack": 24, "bag": 24, "rucksack": 24, + "umbrella": 25, "parasol": 25, + "handbag": 26, "purse": 26, + "tie": 27, "necktie": 27, + "suitcase": 28, "luggage": 28, + + # Sports + "frisbee": 29, + "skis": 30, "ski": 30, + "snowboard": 31, + "sports ball": 32, "ball": 32, "football": 32, "soccer ball": 32, + "kite": 33, + "baseball bat": 34, "bat": 34, + "baseball glove": 35, "glove": 35, + "skateboard": 36, + "surfboard": 37, + "tennis racket": 38, "racket": 38, + + # Kitchen + "bottle": 39, "water bottle": 39, + "wine glass": 40, "glass": 40, + "cup": 41, "mug": 41, "coffee cup": 41, + "fork": 42, + "knife": 43, + "spoon": 44, + "bowl": 45, + + # Food + "banana": 46, + "apple": 47, + "sandwich": 48, + "orange": 49, + "broccoli": 50, + "carrot": 51, + "hot dog": 52, "hotdog": 52, + "pizza": 53, + "donut": 54, "doughnut": 54, + "cake": 55, + + # Furniture + "chair": 56, "seat": 56, + "couch": 57, "sofa": 57, + "potted plant": 58, "plant": 58, "houseplant": 58, + "bed": 59, + "dining table": 60, "table": 60, "desk": 60, + "toilet": 61, + + # Electronics + "tv": 62, "television": 62, "monitor": 62, "screen": 62, + "laptop": 63, "notebook": 63, "computer": 63, + "mouse": 64, "computer mouse": 64, + "remote": 65, "remote control": 65, + "keyboard": 66, + "cell phone": 67, "phone": 67, "mobile": 67, "smartphone": 67, + + # Appliances + "microwave": 68, + "oven": 69, "stove": 69, + "toaster": 70, + "sink": 71, + "refrigerator": 72, "fridge": 72, + + # Other + "book": 73, + "clock": 74, "watch": 74, + "vase": 75, + "scissors": 76, + "teddy bear": 77, "stuffed animal": 77, + "hair drier": 78, "hairdryer": 78, + "toothbrush": 79, +} + +# COCO class names (80 classes) +COCO_CLASSES = [ + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', + 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', + 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', + 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', + 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', + 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', + 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', + 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', + 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush' +] + +# Pose keypoint names (COCO format - 17 keypoints) +POSE_KEYPOINTS = [ + 'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear', + 'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow', + 'left_wrist', 'right_wrist', 'left_hip', 'right_hip', + 'left_knee', 'right_knee', 'left_ankle', 'right_ankle' +] + +# Skeleton connections for drawing +SKELETON_CONNECTIONS = [ + (0, 1), (0, 2), (1, 3), (2, 4), # Face + (5, 6), (5, 7), (7, 9), (6, 8), (8, 10), # Arms + (5, 11), (6, 12), (11, 12), # Torso + (11, 13), (13, 15), (12, 14), (14, 16) # Legs +] + +# Keypoint colors (BGR) +KEYPOINT_COLORS = { + 'face': (255, 200, 100), # Light blue for face points + 'left_arm': (0, 255, 0), # Green for left side + 'right_arm': (0, 0, 255), # Red 
for right side + 'left_leg': (0, 200, 0), # Darker green + 'right_leg': (0, 0, 200), # Darker red + 'torso': (255, 255, 0), # Cyan +} + + +def get_keypoint_color(idx: int) -> Tuple[int, int, int]: + """Get color for a keypoint based on its index.""" + if idx <= 4: + return KEYPOINT_COLORS['face'] + elif idx in [5, 7, 9]: + return KEYPOINT_COLORS['left_arm'] + elif idx in [6, 8, 10]: + return KEYPOINT_COLORS['right_arm'] + elif idx in [11, 13, 15]: + return KEYPOINT_COLORS['left_leg'] + elif idx in [12, 14, 16]: + return KEYPOINT_COLORS['right_leg'] + else: + return KEYPOINT_COLORS['torso'] + + +# ===== DATABASE ===== +class Database: + """SQLite database for storing all command center data.""" + + def __init__(self, db_path: str = None): + if db_path is None: + db_path = os.path.join(os.path.dirname(__file__), 'command_center.db') + self.db_path = db_path + self.lock = threading.Lock() + self._init_db() + + def _get_connection(self) -> sqlite3.Connection: + """Get a thread-local database connection.""" + conn = sqlite3.connect(self.db_path, check_same_thread=False) + conn.row_factory = sqlite3.Row + return conn + + def _init_db(self): + """Initialize database tables.""" + with self._get_connection() as conn: + cursor = conn.cursor() + + # Sessions table - tracks each app run + cursor.execute(''' + CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + started_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + ended_at TIMESTAMP, + device TEXT, + prompts TEXT, + settings TEXT + ) + ''') + + # Detections table - all detected objects + cursor.execute(''' + CREATE TABLE IF NOT EXISTS detections ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT, + detection_id INTEGER, + persistent_id INTEGER, + label TEXT, + confidence REAL, + box TEXT, + mask_area INTEGER, + timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + frame_number INTEGER, + yolo_class TEXT, + yolo_confidence REAL, + FOREIGN KEY (session_id) REFERENCES sessions(id) + ) + ''') + + # Analysis results from Claude + cursor.execute(''' + CREATE TABLE IF NOT EXISTS analysis_results ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT, + detection_id INTEGER, + label TEXT, + analysis TEXT, + timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + image_data TEXT, + FOREIGN KEY (session_id) REFERENCES sessions(id) + ) + ''') + + # Location memory - where objects are typically found + cursor.execute(''' + CREATE TABLE IF NOT EXISTS location_memory ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + label TEXT NOT NULL, + context TEXT, + position TEXT, + frequency INTEGER DEFAULT 1, + first_seen TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_seen TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(label, context) + ) + ''') + + # Navigation sessions + cursor.execute(''' + CREATE TABLE IF NOT EXISTS navigation_sessions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT, + target_label TEXT, + target_id INTEGER, + started_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + ended_at TIMESTAMP, + reached BOOLEAN DEFAULT FALSE, + path_history TEXT, + scene_context TEXT, + FOREIGN KEY (session_id) REFERENCES sessions(id) + ) + ''') + + # Obstacles detected during navigation + cursor.execute(''' + CREATE TABLE IF NOT EXISTS obstacles ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + navigation_id INTEGER, + label TEXT, + obstacle_type TEXT, + box TEXT, + distance TEXT, + alert_sent BOOLEAN DEFAULT FALSE, + timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (navigation_id) REFERENCES navigation_sessions(id) + ) + ''') + + # Voice queries and results + 
cursor.execute(''' + CREATE TABLE IF NOT EXISTS voice_queries ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT, + query TEXT, + parsed_prompts TEXT, + was_search BOOLEAN, + was_describe BOOLEAN, + timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (session_id) REFERENCES sessions(id) + ) + ''') + + # General event log + cursor.execute(''' + CREATE TABLE IF NOT EXISTS event_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id TEXT, + event_type TEXT, + level TEXT DEFAULT 'INFO', + message TEXT, + data TEXT, + timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (session_id) REFERENCES sessions(id) + ) + ''') + + # Create indexes for common queries + cursor.execute('CREATE INDEX IF NOT EXISTS idx_detections_session ON detections(session_id)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_detections_label ON detections(label)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_location_label ON location_memory(label)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_obstacles_nav ON obstacles(navigation_id)') + cursor.execute('CREATE INDEX IF NOT EXISTS idx_events_session ON event_log(session_id)') + + conn.commit() + print(f"Database initialized: {self.db_path}") + + # ===== SESSION METHODS ===== + + def create_session(self, device: str, prompts: List[str], settings: Dict) -> str: + """Create a new session and return its ID.""" + session_id = str(uuid.uuid4()) + with self.lock: + with self._get_connection() as conn: + conn.execute( + 'INSERT INTO sessions (id, device, prompts, settings) VALUES (?, ?, ?, ?)', + (session_id, device, json.dumps(prompts), json.dumps(settings)) + ) + conn.commit() + return session_id + + def end_session(self, session_id: str): + """Mark a session as ended.""" + with self.lock: + with self._get_connection() as conn: + conn.execute( + 'UPDATE sessions SET ended_at = CURRENT_TIMESTAMP WHERE id = ?', + (session_id,) + ) + conn.commit() + + # ===== DETECTION METHODS ===== + + def save_detection(self, session_id: str, detection: Dict, frame_number: int): + """Save a detection to the database.""" + with self.lock: + with self._get_connection() as conn: + conn.execute(''' + INSERT INTO detections + (session_id, detection_id, persistent_id, label, confidence, box, mask_area, frame_number, yolo_class, yolo_confidence) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ''', ( + session_id, + detection.get('id'), + detection.get('persistent_id'), + detection.get('label'), + detection.get('confidence'), + json.dumps(detection.get('box')), + detection.get('mask_area'), + frame_number, + detection.get('yolo_class'), + detection.get('yolo_confidence') + )) + conn.commit() + + def save_detections_batch(self, session_id: str, detections: List[Dict], frame_number: int): + """Save multiple detections in a batch.""" + if not detections: + return + with self.lock: + with self._get_connection() as conn: + conn.executemany(''' + INSERT INTO detections + (session_id, detection_id, persistent_id, label, confidence, box, mask_area, frame_number, yolo_class, yolo_confidence) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ ''', [( + session_id, + d.get('id'), + d.get('persistent_id'), + d.get('label'), + d.get('confidence'), + json.dumps(d.get('box')), + d.get('mask_area'), + frame_number, + d.get('yolo_class'), + d.get('yolo_confidence') + ) for d in detections]) + conn.commit() + + def get_detection_history(self, session_id: str = None, label: str = None, limit: int = 100) -> List[Dict]: + """Get detection history with optional filters.""" + query = 'SELECT * FROM detections WHERE 1=1' + params = [] + + if session_id: + query += ' AND session_id = ?' + params.append(session_id) + if label: + query += ' AND label LIKE ?' + params.append(f'%{label}%') + + query += ' ORDER BY timestamp DESC LIMIT ?' + params.append(limit) + + with self._get_connection() as conn: + rows = conn.execute(query, params).fetchall() + return [dict(row) for row in rows] + + # ===== ANALYSIS METHODS ===== + + def save_analysis(self, session_id: str, detection_id: int, label: str, analysis: str, image_data: str = None): + """Save Claude analysis result.""" + with self.lock: + with self._get_connection() as conn: + conn.execute(''' + INSERT INTO analysis_results (session_id, detection_id, label, analysis, image_data) + VALUES (?, ?, ?, ?, ?) + ''', (session_id, detection_id, label, analysis, image_data)) + conn.commit() + + def get_analysis_history(self, session_id: str = None, limit: int = 50) -> List[Dict]: + """Get analysis history.""" + query = 'SELECT * FROM analysis_results' + params = [] + + if session_id: + query += ' WHERE session_id = ?' + params.append(session_id) + + query += ' ORDER BY timestamp DESC LIMIT ?' + params.append(limit) + + with self._get_connection() as conn: + rows = conn.execute(query, params).fetchall() + return [dict(row) for row in rows] + + # ===== LOCATION MEMORY METHODS ===== + + def remember_location(self, label: str, context: str, position: Dict = None): + """Remember where an object was found.""" + label_key = label.lower().strip() + context_key = context.lower().strip() if context else "" + + with self.lock: + with self._get_connection() as conn: + # Try to update existing entry + cursor = conn.execute(''' + UPDATE location_memory + SET frequency = frequency + 1, + last_seen = CURRENT_TIMESTAMP, + position = ? + WHERE label = ? AND context = ? + ''', (json.dumps(position) if position else None, label_key, context_key)) + + if cursor.rowcount == 0: + # Insert new entry + conn.execute(''' + INSERT INTO location_memory (label, context, position, frequency) + VALUES (?, ?, ?, 1) + ''', (label_key, context_key, json.dumps(position) if position else None)) + + conn.commit() + + def recall_location(self, label: str) -> Optional[Dict]: + """Recall where an object was typically found.""" + label_key = label.lower().strip() + + with self._get_connection() as conn: + row = conn.execute(''' + SELECT * FROM location_memory + WHERE label = ? 
+ ORDER BY frequency DESC, last_seen DESC + LIMIT 1 + ''', (label_key,)).fetchone() + + if row: + result = dict(row) + if result.get('position'): + result['position'] = json.loads(result['position']) + return result + return None + + def get_all_location_memories(self) -> List[Dict]: + """Get all location memories.""" + with self._get_connection() as conn: + rows = conn.execute(''' + SELECT label, context, frequency, last_seen + FROM location_memory + ORDER BY frequency DESC, last_seen DESC + ''').fetchall() + return [dict(row) for row in rows] + + def clear_location_memory(self, label: str = None): + """Clear location memory for a label or all.""" + with self.lock: + with self._get_connection() as conn: + if label: + conn.execute('DELETE FROM location_memory WHERE label = ?', (label.lower().strip(),)) + else: + conn.execute('DELETE FROM location_memory') + conn.commit() + + # ===== NAVIGATION METHODS ===== + + def start_navigation_session(self, session_id: str, target_label: str, target_id: int = None) -> int: + """Start a new navigation session and return its ID.""" + with self.lock: + with self._get_connection() as conn: + cursor = conn.execute(''' + INSERT INTO navigation_sessions (session_id, target_label, target_id) + VALUES (?, ?, ?) + ''', (session_id, target_label, target_id)) + conn.commit() + return cursor.lastrowid + + def end_navigation_session(self, nav_id: int, reached: bool, path_history: List = None, scene_context: Dict = None): + """End a navigation session.""" + with self.lock: + with self._get_connection() as conn: + conn.execute(''' + UPDATE navigation_sessions + SET ended_at = CURRENT_TIMESTAMP, + reached = ?, + path_history = ?, + scene_context = ? + WHERE id = ? + ''', (reached, json.dumps(path_history), json.dumps(scene_context), nav_id)) + conn.commit() + + def save_obstacle(self, nav_id: int, label: str, obstacle_type: str, box: List, distance: str, alert_sent: bool = False): + """Save an obstacle detected during navigation.""" + with self.lock: + with self._get_connection() as conn: + conn.execute(''' + INSERT INTO obstacles (navigation_id, label, obstacle_type, box, distance, alert_sent) + VALUES (?, ?, ?, ?, ?, ?) + ''', (nav_id, label, obstacle_type, json.dumps(box), distance, alert_sent)) + conn.commit() + + def get_navigation_history(self, session_id: str = None, limit: int = 20) -> List[Dict]: + """Get navigation history.""" + query = 'SELECT * FROM navigation_sessions' + params = [] + + if session_id: + query += ' WHERE session_id = ?' + params.append(session_id) + + query += ' ORDER BY started_at DESC LIMIT ?' + params.append(limit) + + with self._get_connection() as conn: + rows = conn.execute(query, params).fetchall() + return [dict(row) for row in rows] + + # ===== VOICE QUERY METHODS ===== + + def save_voice_query(self, session_id: str, query: str, parsed_prompts: List[str], + was_search: bool = True, was_describe: bool = False): + """Save a voice query.""" + with self.lock: + with self._get_connection() as conn: + conn.execute(''' + INSERT INTO voice_queries (session_id, query, parsed_prompts, was_search, was_describe) + VALUES (?, ?, ?, ?, ?) 
+ ''', (session_id, query, json.dumps(parsed_prompts), was_search, was_describe)) + conn.commit() + + # ===== EVENT LOG METHODS ===== + + def log_event(self, session_id: str, event_type: str, message: str, level: str = 'INFO', data: Dict = None): + """Log an event to the database.""" + with self.lock: + with self._get_connection() as conn: + conn.execute(''' + INSERT INTO event_log (session_id, event_type, level, message, data) + VALUES (?, ?, ?, ?, ?) + ''', (session_id, event_type, level, message, json.dumps(data) if data else None)) + conn.commit() + + def get_event_log(self, session_id: str = None, event_type: str = None, limit: int = 100) -> List[Dict]: + """Get event log with optional filters.""" + query = 'SELECT * FROM event_log WHERE 1=1' + params = [] + + if session_id: + query += ' AND session_id = ?' + params.append(session_id) + if event_type: + query += ' AND event_type = ?' + params.append(event_type) + + query += ' ORDER BY timestamp DESC LIMIT ?' + params.append(limit) + + with self._get_connection() as conn: + rows = conn.execute(query, params).fetchall() + return [dict(row) for row in rows] + + # ===== STATISTICS METHODS ===== + + def get_session_stats(self, session_id: str) -> Dict: + """Get statistics for a session.""" + with self._get_connection() as conn: + stats = {} + + # Detection count + row = conn.execute( + 'SELECT COUNT(*) as count FROM detections WHERE session_id = ?', + (session_id,) + ).fetchone() + stats['total_detections'] = row['count'] if row else 0 + + # Unique labels + rows = conn.execute( + 'SELECT DISTINCT label FROM detections WHERE session_id = ?', + (session_id,) + ).fetchall() + stats['unique_labels'] = [row['label'] for row in rows] + stats['unique_label_count'] = len(stats['unique_labels']) + + # Analysis count + row = conn.execute( + 'SELECT COUNT(*) as count FROM analysis_results WHERE session_id = ?', + (session_id,) + ).fetchone() + stats['total_analyses'] = row['count'] if row else 0 + + # Navigation count + row = conn.execute( + 'SELECT COUNT(*) as count, SUM(CASE WHEN reached THEN 1 ELSE 0 END) as reached FROM navigation_sessions WHERE session_id = ?', + (session_id,) + ).fetchone() + stats['navigation_sessions'] = row['count'] if row else 0 + stats['successful_navigations'] = row['reached'] if row and row['reached'] else 0 + + return stats + + def migrate_from_json(self, location_memory_file: str): + """Migrate existing JSON location memory to SQLite.""" + if not os.path.exists(location_memory_file): + return + + try: + with open(location_memory_file, 'r') as f: + old_memory = json.load(f) + + for label, entries in old_memory.items(): + for entry in entries: + self.remember_location( + label=label, + context=entry.get('context', ''), + position=entry.get('position') + ) + # Update frequency if specified + if entry.get('frequency', 1) > 1: + with self._get_connection() as conn: + conn.execute(''' + UPDATE location_memory + SET frequency = ? + WHERE label = ? AND context = ? 
+ ''', (entry['frequency'], label.lower(), entry.get('context', '').lower())) + conn.commit() + + print(f"Migrated {len(old_memory)} items from JSON to SQLite") + + # Optionally rename old file + backup_path = location_memory_file + '.bak' + os.rename(location_memory_file, backup_path) + print(f"Old JSON file backed up to {backup_path}") + + except Exception as e: + print(f"Error migrating from JSON: {e}") + + +# Global database instance +db = Database() + + +# ===== SMART OBSTACLE DETECTION ===== +# Uses Claude AI to understand context and identify actual obstacles in the path + +def analyze_obstacles_with_claude(image_data: str, target_label: str, target_box: List = None) -> List[Dict]: + """ + Use Claude to intelligently identify obstacles in the user's path. + + This is smarter than a static list because Claude: + 1. Understands what the user is looking for (won't mark it as obstacle) + 2. Understands spatial relationships (what's actually in the path) + 3. Understands environmental context (room type, indoor/outdoor) + 4. Can identify hazards specific to the situation + """ + if not ANTHROPIC_API_KEY: + return [] + + try: + from anthropic import Anthropic + client = Anthropic(api_key=ANTHROPIC_API_KEY) + + # Build context about the target + target_context = f"The user is navigating to find: {target_label}" + if target_box: + # Describe where the target is in the frame + frame_center_x = 320 # Assuming 640 width + target_center_x = (target_box[0] + target_box[2]) / 2 + if target_center_x < frame_center_x - 100: + target_position = "on the left side of the view" + elif target_center_x > frame_center_x + 100: + target_position = "on the right side of the view" + else: + target_position = "ahead in the center of the view" + target_context += f". The {target_label} is currently visible {target_position}." + + prompt = f"""You are helping a visually impaired person navigate to an object. Analyze this image for obstacles. + +{target_context} + +IMPORTANT RULES: +1. The {target_label} is NOT an obstacle - it's the destination +2. Only identify objects that could physically block the path to the {target_label} +3. Focus on objects between the camera/user and the target +4. Consider floor-level hazards (cables, steps, rugs, wet surfaces) +5. Consider objects at body height that could be walked into +6. Ignore objects that are clearly not in the walking path + +For each obstacle you identify, provide: +- name: What the obstacle is (be specific, e.g., "wooden chair" not just "furniture") +- severity: "high" (could cause injury/fall), "medium" (could cause collision), or "low" (minor obstruction) +- position: Where in the frame (left, center, right, floor, ahead) +- distance: How close it appears (very_close, close, medium, far) +- reason: Brief explanation of why it's an obstacle + +Respond in JSON format: +{{ + "environment": "brief description of the space (e.g., living room, hallway, outdoor path)", + "path_clear": true/false, + "obstacles": [ + {{ + "name": "obstacle name", + "severity": "high/medium/low", + "position": "left/center/right/floor", + "distance": "very_close/close/medium/far", + "reason": "why this is in the way" + }} + ], + "safe_direction": "suggestion for safest path if obstacles present" +}} + +If the path appears clear, return an empty obstacles array. 
+Only include obstacles that are genuinely in the way - don't over-report.""" + + response = client.messages.create( + model="claude-sonnet-4-20250514", + max_tokens=1000, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": image_data + } + }, + { + "type": "text", + "text": prompt + } + ] + } + ] + ) + + # Parse Claude's response + response_text = response.content[0].text + + # Extract JSON from response + import re + json_match = re.search(r'\{[\s\S]*\}', response_text) + if json_match: + result = json.loads(json_match.group()) + + obstacles = [] + for obs in result.get("obstacles", []): + obstacles.append({ + "label": obs.get("name", "unknown obstacle"), + "type": obs.get("severity", "medium"), + "position": obs.get("position", "ahead"), + "distance": obs.get("distance", "medium"), + "reason": obs.get("reason", ""), + "from_claude": True + }) + + # Store environment info + if result.get("environment"): + cc.navigation_context = cc.navigation_context or {} + cc.navigation_context["environment"] = result.get("environment") + cc.navigation_context["path_clear"] = result.get("path_clear", True) + cc.navigation_context["safe_direction"] = result.get("safe_direction") + + return obstacles + + return [] + + except Exception as e: + cc.log(f"Claude obstacle analysis failed: {e}", "ERROR") + return [] + + +# Global state +class CommandCenter: + """Global state manager for the command center.""" + + def __init__(self): + self.lock = threading.Lock() + self.running = False + self.paused = False + + # Detection settings + self.prompts = ["object"] + self.confidence_threshold = 0.3 + self.max_objects_per_prompt = {} # prompt -> max count (None = unlimited) + self.show_all_matches = {} # prompt -> bool (show all even if over limit) + + # Current detection state + self.current_detections = [] # List of detection dicts + self.frame_count = 0 + self.fps = 0.0 + self.device_str = "cpu" + + # Verbose log + self.log_entries = deque(maxlen=100) + + # Claude analysis results + self.analysis_queue = [] # Objects waiting for analysis + self.analysis_results = deque(maxlen=20) # Recent analysis results + self.analyzing = False + + # Frame for streaming + self.current_frame = None # Frame with overlays (for display) + self.current_raw_frame = None # Raw frame without overlays (for analysis) + self.current_frame_jpeg = None + + # Camera and model + self.camera = None + self.processor = None + self.state = None + self.video_predictor = None # SAM3 video predictor for memory tracking + + # Basic tracking state (optical flow) + self.enable_tracking = True + self.skip_frames = 3 + self.last_masks = None + self.last_boxes = None + self.last_scores = None + self.last_labels = None + self.prev_gray = None + + # ===== FEATURE TOGGLES ===== + + # Video Tracking with Memory (SAM3 tracker) + self.enable_memory_tracking = False + self.memory_bank = {} # object_id -> list of mask features + self.memory_max_frames = 10 # Max frames to keep in memory per object + + # Multi-Object Tracking with Persistent IDs + self.enable_persistent_ids = False + self.object_registry = {} # object_id -> {label, first_seen, last_seen, color, ...} + self.next_object_id = 1 + self.iou_threshold = 0.3 # IoU threshold for matching objects + + # Multi-Object Video Tracking + self.tracked_objects = {} # object_id -> tracking state + self.object_colors = {} # object_id -> color + + # Mask Refinement Options + self.enable_fill_holes = False + 
self.fill_hole_area = 100 # Max hole area to fill (pixels) + self.enable_non_overlap = False # Prevent mask overlaps + self.enable_smooth_edges = False + self.smooth_kernel_size = 5 + + # Advanced Detection Controls + self.enable_boundary_suppression = False + self.boundary_margin = 10 # Pixels from edge to suppress + self.enable_occlusion_suppression = False + self.occlusion_threshold = 0.5 # Overlap ratio to suppress + self.enable_hotstart = False + self.hotstart_frames = 5 # Frames before confirming new detection + self.pending_detections = {} # id -> {frames_seen, detection_data} + + # ===== YOLO FEATURES ===== + self.yolo_classify_model = None + self.yolo_pose_model = None + self.yolo_available = False + + # YOLO Classification + self.enable_yolo_classify = False + self.yolo_classify_threshold = 0.3 + self.yolo_classify_every_n = 1 # Run classification every N keyframes + + # YOLO Pose Estimation + self.enable_yolo_pose = False + self.yolo_pose_threshold = 0.5 + self.show_keypoint_labels = False + self.show_skeleton = True + self.keypoint_radius = 4 + self.skeleton_thickness = 2 + + # Label spoofing (use SAM3->COCO mapping) + self.enable_label_spoofing = True + + # Store pose results + self.last_poses = {} # object_id -> keypoints + + # ===== VOICE SEARCH ===== + self.voice_enabled = True + self.last_voice_query = "" + self.last_parsed_prompts = [] + self.tts_enabled = True + self.tts_voice = "default" + self.voice_feedback_messages = deque(maxlen=10) + + # ===== CAMERA SETTINGS ===== + self.current_camera_id = 0 + self.available_cameras = [] # List of {id, name, description} + self.flip_horizontal = False + self.flip_vertical = False + + # ===== REFERENCE IMAGE SEARCH ===== + self.clip_model = None + self.clip_processor = None + self.clip_available = False + self.reference_image = None # PIL Image + self.reference_embedding = None # CLIP embedding + self.reference_description = None # Text description from Claude + self.visual_match_threshold = 0.75 # Similarity threshold for CLIP matching + self.visual_match_enabled = False # Whether to use CLIP matching + + # ===== GEOMETRIC PROMPTS (Draw to Search) ===== + self.pending_box_prompt = None # (x1, y1, x2, y2) for box prompt + self.pending_point_prompt = None # (x, y) for point prompt + self.draw_mode = None # 'box' or 'point' + + # ===== SESSION TRACKING ===== + self.session_id = None # Current session ID for database + + # ===== NAVIGATION SYSTEM (Accessibility) ===== + self.navigation_active = False + self.navigation_target = None # Target object label + self.navigation_target_id = None # Target detection ID + self.navigation_db_id = None # Navigation session ID in database + self.navigation_start_time = None + self.navigation_last_seen = None # Last position of target + self.navigation_guidance_queue = deque(maxlen=10) # Pending guidance messages + self.navigation_last_guidance = None # Last spoken guidance + self.navigation_last_guidance_time = 0 + self.navigation_guidance_interval = 1.5 # Seconds between guidance + self.navigation_reached = False # Whether target was reached + self.navigation_context = None # Scene context from Claude + + # Navigation spatial tracking + self.navigation_target_history = [] # History of target positions + self.navigation_frame_center = (320, 240) # Frame center (updated dynamically) + self.navigation_proximity_threshold = 0.25 # Object covers 25% of frame = reachable + self.navigation_close_threshold = 0.15 # Getting close + self.navigation_direction_deadzone = 0.1 # Center deadzone + + # ===== 
OBSTACLE DETECTION ===== + self.obstacle_detection_active = False # Run obstacle detection during navigation + self.current_obstacles = [] # Currently detected obstacles + self.obstacle_alert_cooldown = {} # obstacle_label -> last_alert_time + self.obstacle_alert_interval = 3.0 # Seconds between repeated alerts for same obstacle + self.obstacle_masks = None # Masks for obstacles to render + self.obstacle_boxes = None # Boxes for obstacles + + # ===== LOCATION MEMORY (Now uses SQLite) ===== + self.location_memory_file = os.path.join(os.path.dirname(__file__), '.location_memory.json') + self._migrate_location_memory() + + def _migrate_location_memory(self): + """Migrate old JSON location memory to SQLite if it exists.""" + if os.path.exists(self.location_memory_file): + db.migrate_from_json(self.location_memory_file) + + def remember_location(self, label: str, context: str, position: Dict = None): + """Remember where an object was found (uses SQLite).""" + db.remember_location(label, context, position) + self.log(f"Remembered: {label} found in {context}") + + def recall_location(self, label: str) -> Optional[Dict]: + """Recall where an object was last found (uses SQLite).""" + return db.recall_location(label) + + def get_all_location_memories(self) -> List[Dict]: + """Get all location memories from database.""" + return db.get_all_location_memories() + + def clear_location_memory(self, label: str = None): + """Clear location memory (uses SQLite).""" + db.clear_location_memory(label) + self.log(f"Cleared location memory" + (f" for {label}" if label else "")) + + def _old_recall_location(self, label: str) -> Optional[Dict]: + """Old recall method - kept for reference.""" + label_key = label.lower().strip() + + if label_key not in self.location_memory: + return None + + entries = self.location_memory[label_key] + if not entries: + return None + + # Return most frequent location, or most recent + sorted_entries = sorted(entries, key=lambda x: (x.get("frequency", 1), x.get("timestamp", "")), reverse=True) + return sorted_entries[0] + + def add_navigation_guidance(self, message: str, priority: int = 1): + """Add a guidance message to the queue.""" + with self.lock: + self.navigation_guidance_queue.append({ + "message": message, + "priority": priority, + "timestamp": time.time() + }) + + def get_pending_guidance(self) -> Optional[str]: + """Get the next pending guidance message.""" + with self.lock: + if self.navigation_guidance_queue: + # Get highest priority message + sorted_queue = sorted(self.navigation_guidance_queue, key=lambda x: -x["priority"]) + msg = sorted_queue[0] + self.navigation_guidance_queue.remove(msg) + return msg["message"] + return None + + def add_voice_feedback(self, message: str, msg_type: str = "info"): + """Add a voice feedback message.""" + with self.lock: + self.voice_feedback_messages.append({ + "message": message, + "type": msg_type, + "timestamp": datetime.now().strftime("%H:%M:%S") + }) + + def log(self, message: str, level: str = "INFO"): + """Add a log entry.""" + timestamp = datetime.now().strftime("%H:%M:%S.%f")[:-3] + entry = { + "timestamp": timestamp, + "level": level, + "message": message + } + with self.lock: + self.log_entries.append(entry) + + def get_logs(self, limit: int = 50) -> List[Dict]: + """Get recent log entries.""" + with self.lock: + return list(self.log_entries)[-limit:] + + def add_detection(self, detection: Dict): + """Add a detection to the current list.""" + with self.lock: + self.current_detections.append(detection) + + def 
clear_detections(self): + """Clear all current detections.""" + with self.lock: + self.current_detections = [] + + def get_filtered_detections(self) -> Tuple[List[Dict], Dict]: + """Get detections filtered by max count settings.""" + with self.lock: + detections = self.current_detections.copy() + + # Group by prompt + by_prompt = {} + for det in detections: + prompt = det.get("label", "unknown") + if prompt not in by_prompt: + by_prompt[prompt] = [] + by_prompt[prompt].append(det) + + # Apply filters + filtered = [] + hidden_counts = {} + + for prompt, dets in by_prompt.items(): + max_count = self.max_objects_per_prompt.get(prompt) + show_all = self.show_all_matches.get(prompt, False) + + if max_count is not None and not show_all: + dets_sorted = sorted(dets, key=lambda d: d.get("confidence", 0), reverse=True) + filtered.extend(dets_sorted[:max_count]) + hidden = len(dets_sorted) - max_count + if hidden > 0: + hidden_counts[prompt] = hidden + else: + filtered.extend(dets) + + return filtered, hidden_counts + + def queue_analysis(self, detection_id: int, image_data: str): + """Queue an object for Claude analysis.""" + with self.lock: + self.analysis_queue.append({ + "id": detection_id, + "image_data": image_data, + "timestamp": datetime.now().isoformat() + }) + + def add_analysis_result(self, detection_id: int, result: str): + """Add a Claude analysis result.""" + with self.lock: + self.analysis_results.append({ + "id": detection_id, + "result": result, + "timestamp": datetime.now().strftime("%H:%M:%S") + }) + + def get_feature_status(self) -> Dict: + """Get status of all feature toggles.""" + return { + "tracking": self.enable_tracking, + "memory_tracking": self.enable_memory_tracking, + "persistent_ids": self.enable_persistent_ids, + "fill_holes": self.enable_fill_holes, + "non_overlap": self.enable_non_overlap, + "smooth_edges": self.enable_smooth_edges, + "boundary_suppression": self.enable_boundary_suppression, + "occlusion_suppression": self.enable_occlusion_suppression, + "hotstart": self.enable_hotstart, + "yolo_classify": self.enable_yolo_classify, + "yolo_pose": self.enable_yolo_pose, + "show_keypoint_labels": self.show_keypoint_labels, + "show_skeleton": self.show_skeleton, + "label_spoofing": self.enable_label_spoofing, + } + + +# Global command center instance +cc = CommandCenter() + + +# Color palette (BGR for OpenCV) +COLORS = [ + (255, 0, 0), # Blue + (0, 255, 0), # Green + (0, 0, 255), # Red + (255, 255, 0), # Cyan + (255, 0, 255), # Magenta + (0, 255, 255), # Yellow + (128, 0, 255), # Purple + (255, 128, 0), # Orange + (128, 255, 0), # Lime + (0, 128, 255), # Sky blue +] + + +def load_yolo_models(): + """Load YOLO models for classification and pose estimation.""" + global cc + + try: + from ultralytics import YOLO + + cc.log("Loading YOLO models...") + + # Model priority: YOLO12 -> YOLO11 -> YOLOv8 + # YOLO12 is newest (Feb 2025) but pretrained weights may not be available for all tasks + cls_models = ['yolo12n-cls.pt', 'yolo11n-cls.pt', 'yolov8n-cls.pt'] + pose_models = ['yolo12n-pose.pt', 'yolo11n-pose.pt', 'yolov8n-pose.pt'] + + # Load classification model + cc.yolo_classify_model = None + for model_name in cls_models: + try: + cc.yolo_classify_model = YOLO(model_name) + cc.log(f"YOLO classification model loaded ({model_name})", "SUCCESS") + break + except Exception as e: + cc.log(f"Could not load {model_name}: {e}", "WARN") + continue + + if cc.yolo_classify_model is None: + cc.log("No classification model available", "WARN") + + # Load pose estimation model + 
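+        # Note: this mirrors the classification fallback above (YOLO12 -> YOLO11 -> YOLOv8),
+        # using whichever pose weights ultralytics can load or download first.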
cc.yolo_pose_model = None + for model_name in pose_models: + try: + cc.yolo_pose_model = YOLO(model_name) + cc.log(f"YOLO pose model loaded ({model_name})", "SUCCESS") + break + except Exception as e: + cc.log(f"Could not load {model_name}: {e}", "WARN") + continue + + if cc.yolo_pose_model is None: + cc.log("No pose model available", "WARN") + + cc.yolo_available = cc.yolo_classify_model is not None or cc.yolo_pose_model is not None + + if cc.yolo_available: + cc.log("YOLO models ready", "SUCCESS") + else: + cc.log("No YOLO models available", "WARN") + + except ImportError: + cc.log("ultralytics not installed. YOLO features disabled. Install with: pip install ultralytics", "WARN") + cc.yolo_available = False + + +def load_clip_model(): + """Load CLIP model for visual similarity matching.""" + global cc + + try: + from transformers import CLIPProcessor, CLIPModel + + cc.log("Loading CLIP model for visual matching...") + + # Use a smaller/faster CLIP model + model_name = "openai/clip-vit-base-patch32" + + cc.clip_processor = CLIPProcessor.from_pretrained(model_name) + cc.clip_model = CLIPModel.from_pretrained(model_name) + + # Move to appropriate device + device = get_device() + cc.clip_model = cc.clip_model.to(device) + cc.clip_model.eval() + + cc.clip_available = True + cc.log("CLIP model loaded successfully", "SUCCESS") + + except ImportError: + cc.log("transformers not installed. Visual matching disabled. Install with: pip install transformers", "WARN") + cc.clip_available = False + except Exception as e: + cc.log(f"Failed to load CLIP model: {e}", "ERROR") + cc.clip_available = False + + +def get_clip_embedding(image: Image.Image) -> Optional[torch.Tensor]: + """Get CLIP embedding for an image.""" + global cc + + if not cc.clip_available or cc.clip_model is None: + return None + + try: + device = get_device() + inputs = cc.clip_processor(images=image, return_tensors="pt") + inputs = {k: v.to(device) for k, v in inputs.items()} + + with torch.no_grad(): + embedding = cc.clip_model.get_image_features(**inputs) + # Normalize + embedding = embedding / embedding.norm(dim=-1, keepdim=True) + + return embedding + + except Exception as e: + cc.log(f"Failed to get CLIP embedding: {e}", "ERROR") + return None + + +def compute_clip_similarity(embedding1: torch.Tensor, embedding2: torch.Tensor) -> float: + """Compute cosine similarity between two CLIP embeddings.""" + if embedding1 is None or embedding2 is None: + return 0.0 + + with torch.no_grad(): + similarity = torch.nn.functional.cosine_similarity(embedding1, embedding2) + return float(similarity.item()) + + +def describe_image_with_claude(image_data: str) -> Optional[str]: + """Use Claude to generate a detailed description of an image for search.""" + global ANTHROPIC_API_KEY + + if not ANTHROPIC_API_KEY: + return None + + try: + import anthropic + + client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY) + + message = client.messages.create( + model="claude-sonnet-4-20250514", + max_tokens=200, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": image_data, + }, + }, + { + "type": "text", + "text": "Describe this object concisely for visual detection. Focus on: type of object, color, distinctive features, shape. Return ONLY the description phrase (e.g., 'red baseball cap with white Nike logo', 'black leather handbag with gold clasp'). No other text." 
+ } + ], + } + ], + ) + + return message.content[0].text.strip() + + except Exception as e: + cc.log(f"Failed to describe image with Claude: {e}", "ERROR") + return None + + +# ===== NAVIGATION SYSTEM FUNCTIONS ===== + +def analyze_scene_context(image_data: str) -> Optional[Dict]: + """ + Use Claude to analyze the scene for navigation context. + Returns location type, obstacles, and spatial awareness info. + """ + global ANTHROPIC_API_KEY + + if not ANTHROPIC_API_KEY: + return None + + try: + import anthropic + + client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY) + + message = client.messages.create( + model="claude-sonnet-4-20250514", + max_tokens=300, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": image_data, + }, + }, + { + "type": "text", + "text": """Analyze this scene for navigation assistance. Return JSON only: +{ + "location": "room type (kitchen, living room, bedroom, bathroom, office, hallway, outdoor, etc.)", + "obstacles": ["list of obstacles or hazards visible"], + "surfaces": ["tables, counters, shelves visible"], + "lighting": "bright/dim/dark", + "space": "open/cluttered/narrow", + "landmarks": ["notable items that help orient"] +}""" + } + ], + } + ], + ) + + response_text = message.content[0].text.strip() + + # Parse JSON + if "```json" in response_text: + response_text = response_text.split("```json")[1].split("```")[0].strip() + elif "```" in response_text: + response_text = response_text.split("```")[1].split("```")[0].strip() + + return json.loads(response_text) + + except Exception as e: + cc.log(f"Scene analysis failed: {e}", "WARN") + return None + + +def compute_navigation_guidance(target_box: List[float], frame_shape: Tuple[int, int]) -> Dict: + """ + Compute navigation guidance based on target position in frame. 
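+
+    Worked example (illustrative, assuming a 640x480 frame and a target box of
+    [400, 180, 560, 300]): the box center is (480, 240), so norm_x = 0.5 and
+    norm_y = 0.0, which gives direction "right" with no secondary direction;
+    the box covers about 6% of the frame, so distance is "medium" and the
+    arrow angle is 90 degrees.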
+ + Returns: + direction: 'left', 'right', 'center', 'up', 'down' + distance: 'far', 'medium', 'close', 'reachable' + guidance_text: Human-readable guidance + arrow_angle: Angle for AR arrow (degrees) + confidence: How confident we are in the guidance + """ + global cc + + if not target_box: + return { + "direction": "unknown", + "distance": "unknown", + "guidance_text": "Looking for target...", + "arrow_angle": 0, + "confidence": 0 + } + + h, w = frame_shape[:2] + x1, y1, x2, y2 = target_box + + # Object center + obj_center_x = (x1 + x2) / 2 + obj_center_y = (y1 + y2) / 2 + + # Frame center + frame_center_x = w / 2 + frame_center_y = h / 2 + + # Normalized position (-1 to 1, 0 = center) + norm_x = (obj_center_x - frame_center_x) / (w / 2) + norm_y = (obj_center_y - frame_center_y) / (h / 2) + + # Object size relative to frame + obj_width = x2 - x1 + obj_height = y2 - y1 + obj_area_ratio = (obj_width * obj_height) / (w * h) + + # Determine direction + deadzone = cc.navigation_direction_deadzone + + if abs(norm_x) < deadzone and abs(norm_y) < deadzone: + direction = "center" + h_dir = "" + elif abs(norm_x) > abs(norm_y): + direction = "right" if norm_x > 0 else "left" + h_dir = direction + else: + direction = "down" if norm_y > 0 else "up" + h_dir = "" + + # Secondary direction + if direction in ["center"]: + secondary = "" + elif direction in ["left", "right"]: + if norm_y < -deadzone: + secondary = " and up" + elif norm_y > deadzone: + secondary = " and down" + else: + secondary = "" + else: + if norm_x < -deadzone: + secondary = " and left" + elif norm_x > deadzone: + secondary = " and right" + else: + secondary = "" + + # Determine distance based on object size + if obj_area_ratio >= cc.navigation_proximity_threshold: + distance = "reachable" + elif obj_area_ratio >= cc.navigation_close_threshold: + distance = "close" + elif obj_area_ratio >= 0.05: + distance = "medium" + else: + distance = "far" + + # Calculate arrow angle (0 = up, 90 = right, etc.) + import math + arrow_angle = math.degrees(math.atan2(norm_x, -norm_y)) + + # Generate guidance text + if distance == "reachable": + if direction == "center": + guidance_text = "Object is directly in front of you, within reach!" + else: + guidance_text = f"Object is within reach, slightly to the {direction}{secondary}" + elif distance == "close": + if direction == "center": + guidance_text = "Almost there! Object is straight ahead, getting close" + else: + guidance_text = f"Getting close! 
Turn {direction}{secondary}" + elif distance == "medium": + if direction == "center": + guidance_text = "Keep moving forward, object ahead" + else: + guidance_text = f"Object is to the {direction}{secondary}, move that way" + else: # far + if direction == "center": + guidance_text = "Object detected ahead, continue forward" + else: + guidance_text = f"Object is far to the {direction}{secondary}" + + return { + "direction": direction, + "secondary": secondary.strip(), + "distance": distance, + "guidance_text": guidance_text, + "arrow_angle": arrow_angle, + "norm_x": norm_x, + "norm_y": norm_y, + "obj_area_ratio": obj_area_ratio, + "confidence": min(1.0, obj_area_ratio * 10 + 0.5) # Higher for larger objects + } + + +def get_navigation_status() -> Dict: + """Get current navigation status and guidance.""" + global cc + + if not cc.navigation_active: + return { + "active": False, + "target": None, + "guidance": None + } + + # Find target in current detections + target_detection = None + for det in cc.current_detections: + if det.get("label", "").lower() == cc.navigation_target.lower(): + target_detection = det + break + if cc.navigation_target_id is not None and det.get("id") == cc.navigation_target_id: + target_detection = det + break + + if target_detection: + cc.navigation_last_seen = target_detection + box = target_detection.get("box") + + if cc.current_raw_frame is not None: + frame_shape = cc.current_raw_frame.shape + else: + frame_shape = (480, 640) + + guidance = compute_navigation_guidance(box, frame_shape) + + # Check if reached + if guidance["distance"] == "reachable" and not cc.navigation_reached: + cc.navigation_reached = True + guidance["reached"] = True + guidance["guidance_text"] = f"You've reached the {cc.navigation_target}! It's right in front of you." + + return { + "active": True, + "target": cc.navigation_target, + "target_visible": True, + "target_bbox": box, # For AR path rendering + "guidance": guidance, + "reached": cc.navigation_reached, + "context": cc.navigation_context, + "duration": time.time() - cc.navigation_start_time if cc.navigation_start_time else 0 + } + else: + # Target not currently visible + last_guidance = None + if cc.navigation_last_seen: + box = cc.navigation_last_seen.get("box") + if box: + frame_shape = (480, 640) + if cc.current_raw_frame is not None: + frame_shape = cc.current_raw_frame.shape + last_guidance = compute_navigation_guidance(box, frame_shape) + last_guidance["guidance_text"] = f"Lost sight of {cc.navigation_target}. Last seen to the {last_guidance['direction']}" + + return { + "active": True, + "target": cc.navigation_target, + "target_visible": False, + "guidance": last_guidance or { + "direction": "unknown", + "distance": "unknown", + "guidance_text": f"Looking for {cc.navigation_target}... 
Turn slowly to scan the area", + "arrow_angle": 0 + }, + "reached": False, + "context": cc.navigation_context, + "searching": True + } + + +def get_coco_class_for_label(sam3_label: str) -> Optional[int]: + """Get COCO class ID for a SAM3 label using the mapping.""" + label_lower = sam3_label.lower().strip() + + # Direct lookup + if label_lower in SAM3_TO_COCO: + return SAM3_TO_COCO[label_lower] + + # Try partial match + for key, coco_id in SAM3_TO_COCO.items(): + if key in label_lower or label_lower in key: + return coco_id + + return None + + +def is_person_label(label: str) -> bool: + """Check if a label refers to a person.""" + coco_id = get_coco_class_for_label(label) + return coco_id == 0 + + +def classify_region(frame: np.ndarray, box: List[float], sam3_label: str) -> Dict: + """ + Run YOLO classification on a detected region. + + Returns dict with: + - yolo_class: Top predicted class name + - yolo_confidence: Confidence score + - top5_classes: List of top 5 (class, confidence) tuples + - matches_sam3: Whether YOLO agrees with SAM3 label + """ + if cc.yolo_classify_model is None: + return None + + try: + # Crop region from frame + x1, y1, x2, y2 = [int(v) for v in box] + h, w = frame.shape[:2] + + # Add padding + pad = 10 + x1 = max(0, x1 - pad) + y1 = max(0, y1 - pad) + x2 = min(w, x2 + pad) + y2 = min(h, y2 + pad) + + crop = frame[y1:y2, x1:x2] + + if crop.size == 0: + return None + + # Run classification + results = cc.yolo_classify_model(crop, verbose=False) + + if len(results) == 0: + return None + + probs = results[0].probs + + if probs is None: + return None + + # Get top 5 predictions + top5_indices = probs.top5 + top5_confs = probs.top5conf.cpu().numpy() + + # Get class names from model + names = cc.yolo_classify_model.names + + top5_classes = [(names[idx], float(conf)) for idx, conf in zip(top5_indices, top5_confs)] + + top_class = top5_classes[0][0] if top5_classes else "unknown" + top_conf = top5_classes[0][1] if top5_classes else 0.0 + + # Check if YOLO agrees with SAM3 + sam3_coco_id = get_coco_class_for_label(sam3_label) + matches = False + + if sam3_coco_id is not None and sam3_coco_id < len(COCO_CLASSES): + sam3_coco_name = COCO_CLASSES[sam3_coco_id] + # Check if any top-5 class matches + for cls_name, conf in top5_classes: + if cls_name.lower() == sam3_coco_name.lower() or sam3_coco_name.lower() in cls_name.lower(): + matches = True + break + + return { + "yolo_class": top_class, + "yolo_confidence": top_conf, + "top5_classes": top5_classes, + "matches_sam3": matches + } + + except Exception as e: + cc.log(f"YOLO classification error: {e}", "ERROR") + return None + + +def estimate_pose(frame: np.ndarray, box: List[float]) -> Dict: + """ + Run YOLO pose estimation on a person region. 
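+
+    The crop is padded (roughly 20% horizontally and 10% vertically) before
+    inference, and keypoint coordinates are shifted back into full-frame
+    pixels so they can be drawn directly on the original frame.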
+ + Returns dict with: + - keypoints: List of 17 (x, y, confidence) tuples + - keypoint_names: List of keypoint names + - confidence: Overall pose confidence + """ + if cc.yolo_pose_model is None: + return None + + try: + # Crop region from frame (with extra padding for pose) + x1, y1, x2, y2 = [int(v) for v in box] + h, w = frame.shape[:2] + + # Add generous padding for pose estimation + box_w = x2 - x1 + box_h = y2 - y1 + pad_x = int(box_w * 0.2) + pad_y = int(box_h * 0.1) + + x1 = max(0, x1 - pad_x) + y1 = max(0, y1 - pad_y) + x2 = min(w, x2 + pad_x) + y2 = min(h, y2 + pad_y) + + crop = frame[y1:y2, x1:x2] + + if crop.size == 0: + return None + + # Run pose estimation + results = cc.yolo_pose_model(crop, verbose=False) + + if len(results) == 0 or results[0].keypoints is None: + return None + + keypoints_data = results[0].keypoints + + if keypoints_data.xy is None or len(keypoints_data.xy) == 0: + return None + + # Get first person's keypoints (we're analyzing one box at a time) + kpts = keypoints_data.xy[0].cpu().numpy() # Shape: (17, 2) + confs = keypoints_data.conf[0].cpu().numpy() if keypoints_data.conf is not None else np.ones(17) + + # Convert coordinates back to full frame + keypoints = [] + for i, (pt, conf) in enumerate(zip(kpts, confs)): + # Add offset back to get coordinates in original frame + frame_x = pt[0] + x1 + frame_y = pt[1] + y1 + keypoints.append((float(frame_x), float(frame_y), float(conf))) + + # Calculate overall confidence + valid_confs = [c for x, y, c in keypoints if c > 0.1] + overall_conf = np.mean(valid_confs) if valid_confs else 0.0 + + return { + "keypoints": keypoints, + "keypoint_names": POSE_KEYPOINTS, + "confidence": float(overall_conf), + "box_offset": (x1, y1) # For reference + } + + except Exception as e: + cc.log(f"YOLO pose estimation error: {e}", "ERROR") + return None + + +def draw_pose_overlay(frame: np.ndarray, pose_data: Dict, object_id: int = None) -> np.ndarray: + """Draw pose keypoints and skeleton on frame.""" + if pose_data is None or "keypoints" not in pose_data: + return frame + + overlay = frame.copy() + keypoints = pose_data["keypoints"] + + # Draw skeleton connections first (so points are on top) + if cc.show_skeleton: + for start_idx, end_idx in SKELETON_CONNECTIONS: + if start_idx < len(keypoints) and end_idx < len(keypoints): + x1, y1, c1 = keypoints[start_idx] + x2, y2, c2 = keypoints[end_idx] + + # Only draw if both points have sufficient confidence + if c1 > cc.yolo_pose_threshold and c2 > cc.yolo_pose_threshold: + pt1 = (int(x1), int(y1)) + pt2 = (int(x2), int(y2)) + + # Get color based on connection type + color = get_keypoint_color(start_idx) + cv2.line(overlay, pt1, pt2, color, cc.skeleton_thickness) + + # Draw keypoints + for i, (x, y, conf) in enumerate(keypoints): + if conf > cc.yolo_pose_threshold: + pt = (int(x), int(y)) + color = get_keypoint_color(i) + + # Draw filled circle + cv2.circle(overlay, pt, cc.keypoint_radius, color, -1) + # Draw outline + cv2.circle(overlay, pt, cc.keypoint_radius, (255, 255, 255), 1) + + # Draw label if enabled + if cc.show_keypoint_labels and i < len(POSE_KEYPOINTS): + label = POSE_KEYPOINTS[i].replace('_', ' ') + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 0.35 + (tw, th), _ = cv2.getTextSize(label, font, font_scale, 1) + + # Position label above point + label_x = int(x - tw/2) + label_y = int(y - cc.keypoint_radius - 3) + + # Background + cv2.rectangle(overlay, + (label_x - 1, label_y - th - 1), + (label_x + tw + 1, label_y + 1), + (0, 0, 0), -1) + + # Text + cv2.putText(overlay, 
label, (label_x, label_y), + font, font_scale, (255, 255, 255), 1) + + return overlay + + +def load_model(checkpoint_path: Optional[str] = None): + """Load the SAM3 model.""" + from sam3.model_builder import build_sam3_image_model + from sam3.model.sam3_image_processor import Sam3Processor + + cc.log("Loading SAM3 model...") + cc.device_str = get_device_str() + + # Setup device-specific optimizations (MPS memory, CUDA TF32, etc.) + setup_device_optimizations() + cc.log(f"Device optimizations enabled for {cc.device_str}") + + model = build_sam3_image_model( + device=cc.device_str, + checkpoint_path=checkpoint_path, + load_from_HF=checkpoint_path is None, + eval_mode=True, + enable_segmentation=True, + ) + + cc.processor = Sam3Processor( + model=model, + resolution=1008, + device=cc.device_str, + confidence_threshold=cc.confidence_threshold, + ) + + cc.log(f"Model loaded on {cc.device_str}", "SUCCESS") + + # Load YOLO models + load_yolo_models() + + # Load CLIP model for visual matching (optional) + load_clip_model() + + +# ===== CAMERA FUNCTIONS ===== + +def detect_available_cameras(max_cameras: int = 10) -> List[Dict]: + """ + Detect available cameras on the system. + + Returns list of dicts with: + - id: Camera index + - name: Camera name/description + - resolution: (width, height) if detectable + """ + cameras = [] + + for i in range(max_cameras): + cap = cv2.VideoCapture(i) + if cap.isOpened(): + # Get camera properties + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = cap.get(cv2.CAP_PROP_FPS) + + # Try to get backend name + backend = cap.getBackendName() + + # Create descriptive name + if i == 0: + name = "Default Camera" + else: + name = f"Camera {i}" + + # Add platform-specific hints + import platform + if platform.system() == "Darwin": # macOS + if i == 0: + name = "FaceTime HD Camera (Built-in)" + elif i == 1: + name = "External Camera" + elif platform.system() == "Linux": + # Try to read device name from v4l2 + try: + import subprocess + result = subprocess.run( + ['v4l2-ctl', '--device', f'/dev/video{i}', '--info'], + capture_output=True, text=True, timeout=1 + ) + for line in result.stdout.split('\n'): + if 'Card type' in line: + name = line.split(':')[1].strip() + break + except Exception: + pass + + cameras.append({ + "id": i, + "name": name, + "resolution": f"{width}x{height}", + "fps": fps, + "backend": backend, + "description": f"{name} ({width}x{height} @ {fps:.0f}fps)" + }) + + cap.release() + + return cameras + + +def switch_camera(camera_id: int) -> bool: + """Switch to a different camera and reset detection state.""" + global cc + + cc.log(f"Switching to camera {camera_id}...") + + # Release current camera + if cc.camera is not None: + cc.camera.release() + cc.camera = None + + # Open new camera + new_camera = cv2.VideoCapture(camera_id) + + if not new_camera.isOpened(): + cc.log(f"Failed to open camera {camera_id}", "ERROR") + # Try to reopen previous camera + cc.camera = cv2.VideoCapture(cc.current_camera_id) + return False + + cc.camera = new_camera + cc.current_camera_id = camera_id + + # Get camera info + width = int(cc.camera.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cc.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + # Reset detection state + reset_detection_state() + + cc.log(f"Switched to camera {camera_id} ({width}x{height})", "SUCCESS") + return True + + +def reset_detection_state(): + """Reset all detection state for a fresh start.""" + global cc + + cc.state = None + cc.last_masks = None + 
cc.last_boxes = None + cc.last_scores = None + cc.last_labels = None + cc.tracked_objects = {} + cc.memory_bank = {} + cc.object_colors = {} + cc.next_object_id = 1 + cc.pending_detections = {} + cc.last_poses = {} + cc.prev_gray = None + cc.current_detections = [] + cc.frame_count = 0 + + +# ===== MASK REFINEMENT FUNCTIONS ===== + +def fill_holes_in_mask(mask: np.ndarray, max_hole_area: int = 100) -> np.ndarray: + """Fill small holes in a binary mask.""" + mask_bool = mask.astype(bool) + # Find holes (inverted connected components) + inverted = ~mask_bool + labeled, num_features = ndimage.label(inverted) + + # Fill small holes + for i in range(1, num_features + 1): + hole = labeled == i + if hole.sum() <= max_hole_area: + mask_bool[hole] = True + + return mask_bool.astype(np.float32) + + +def smooth_mask_edges(mask: np.ndarray, kernel_size: int = 5) -> np.ndarray: + """Smooth mask edges using morphological operations.""" + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) + # Close then open to smooth + smoothed = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel) + smoothed = cv2.morphologyEx(smoothed, cv2.MORPH_OPEN, kernel) + return smoothed.astype(np.float32) + + +def remove_mask_overlaps(masks: List[np.ndarray], scores: List[float]) -> List[np.ndarray]: + """Remove overlapping regions, keeping higher confidence masks.""" + if len(masks) <= 1: + return masks + + # Sort by score (highest first) + sorted_indices = np.argsort(scores)[::-1] + result_masks = [None] * len(masks) + occupied = np.zeros_like(masks[0], dtype=bool) + + for idx in sorted_indices: + mask = masks[idx].astype(bool) + # Remove already occupied regions + mask = mask & ~occupied + result_masks[idx] = mask.astype(np.float32) + occupied |= mask + + return result_masks + + +# ===== DETECTION CONTROL FUNCTIONS ===== + +def is_near_boundary(box: List[float], frame_shape: Tuple[int, int], margin: int = 10) -> bool: + """Check if a bounding box is near the frame boundary.""" + h, w = frame_shape[:2] + x1, y1, x2, y2 = box + return x1 < margin or y1 < margin or x2 > w - margin or y2 > h - margin + + +def calculate_iou(mask1: np.ndarray, mask2: np.ndarray) -> float: + """Calculate Intersection over Union between two masks.""" + intersection = np.logical_and(mask1, mask2).sum() + union = np.logical_or(mask1, mask2).sum() + return intersection / union if union > 0 else 0 + + +def match_detection_to_object(mask: np.ndarray, existing_masks: Dict[int, np.ndarray], + threshold: float = 0.3) -> Optional[int]: + """Match a detection to an existing tracked object by IoU.""" + best_match = None + best_iou = threshold + + for obj_id, existing_mask in existing_masks.items(): + iou = calculate_iou(mask, existing_mask) + if iou > best_iou: + best_iou = iou + best_match = obj_id + + return best_match + + +def get_bounding_box_from_mask(mask: np.ndarray) -> Optional[List[float]]: + """Extract bounding box from a binary mask.""" + if mask is None or mask.sum() == 0: + return None + + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + + if not rows.any() or not cols.any(): + return None + + y_min, y_max = np.where(rows)[0][[0, -1]] + x_min, x_max = np.where(cols)[0][[0, -1]] + + return [float(x_min), float(y_min), float(x_max), float(y_max)] + + +def is_mask_valid(mask: np.ndarray, frame_shape: Tuple[int, int], min_area: int = 50, + boundary_margin: int = 5) -> bool: + """ + Check if a tracked mask is still valid. 
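+
+    This is the check update_detections_from_tracked_masks() uses to decide
+    whether a tracked object has left the frame and its detection should be
+    dropped.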
+ Returns False if: + - Mask is too small (object left frame) + - Mask is mostly outside the frame boundaries + """ + if mask is None: + return False + + mask_area = mask.sum() + if mask_area < min_area: + return False + + h, w = frame_shape[:2] + + # Check if mask is mostly within frame bounds + if mask.shape != (h, w): + return False + + # Get bounding box + box = get_bounding_box_from_mask(mask) + if box is None: + return False + + x1, y1, x2, y2 = box + + # Check if box is mostly outside frame + if x2 < boundary_margin or x1 > w - boundary_margin: + return False + if y2 < boundary_margin or y1 > h - boundary_margin: + return False + + return True + + +def update_detections_from_tracked_masks(tracked_masks: torch.Tensor, frame_shape: Tuple[int, int]): + """ + Update current_detections based on tracked masks. + Removes detections for masks that are no longer valid (left frame). + Updates bounding boxes for masks that moved. + """ + global cc + + if tracked_masks is None or len(cc.current_detections) == 0: + return + + h, w = frame_shape[:2] + masks_np = tracked_masks.squeeze(1).cpu().numpy() + + updated_detections = [] + valid_mask_indices = [] + + for i, det in enumerate(cc.current_detections): + if i >= len(masks_np): + break + + mask = masks_np[i] + if mask.shape != (h, w): + mask = cv2.resize(mask.astype(np.float32), (w, h)) > 0.5 + + # Check if mask is still valid + if is_mask_valid(mask, frame_shape): + # Update bounding box from tracked mask + new_box = get_bounding_box_from_mask(mask) + if new_box: + det = det.copy() # Don't modify original + det["box"] = new_box + det["tracked"] = True # Mark as being tracked (not fresh detection) + updated_detections.append(det) + valid_mask_indices.append(i) + else: + # Object has left the frame or tracking failed + label = det.get("label", "object") + obj_id = det.get("id", i) + cc.log(f"Object #{obj_id} ({label}) left frame", "INFO") + + # Update global state + with cc.lock: + cc.current_detections = updated_detections + + return valid_mask_indices + + +# ===== MEMORY TRACKING FUNCTIONS ===== + +def update_memory_bank(object_id: int, mask_features: torch.Tensor): + """Update memory bank for an object.""" + if object_id not in cc.memory_bank: + cc.memory_bank[object_id] = [] + + cc.memory_bank[object_id].append(mask_features) + + # Keep only recent frames + if len(cc.memory_bank[object_id]) > cc.memory_max_frames: + cc.memory_bank[object_id].pop(0) + + +# ===== ADVANCED MONOCULAR DEPTH ESTIMATION ===== +# Proprietary: LIDAR-like depth from single RGB camera using AI + +_depth_model = None +_depth_transform = None +_depth_available = False +_depth_device = None + +# Optical flow state for motion-based collision detection +_prev_flow_frame = None +_obstacle_tracking = {} # Track obstacles over time for approach detection + + +def load_depth_model(): + """ + Load monocular depth estimation model. + Provides LIDAR-like depth perception from a single RGB camera. + + Tries models in order of quality: + 1. Depth Anything (state-of-the-art 2024) + 2. 
MiDaS (widely compatible) + """ + global _depth_model, _depth_transform, _depth_available, _depth_device + + if _depth_available: + return True + + _depth_device = torch.device("cuda" if torch.cuda.is_available() else + "mps" if torch.backends.mps.is_available() else "cpu") + + # Try Depth Anything first (best quality, 2024 state-of-the-art) + try: + from transformers import pipeline + _depth_model = pipeline("depth-estimation", + model="LiheYoung/depth-anything-small-hf", + device=0 if torch.cuda.is_available() else -1) + _depth_available = True + print(f"✓ Loaded Depth Anything for monocular depth estimation") + return True + except Exception as e: + print(f" Depth Anything not available: {e}") + + # Try MiDaS (more compatible) + try: + _depth_model = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", trust_repo=True) + _depth_model.to(_depth_device) + _depth_model.eval() + + midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True) + _depth_transform = midas_transforms.small_transform + + _depth_available = True + print(f"✓ Loaded MiDaS for monocular depth estimation on {_depth_device}") + return True + except Exception as e: + print(f" MiDaS not available: {e}") + + print(" No depth model available - using edge-based detection only") + return False + + +def estimate_depth(frame: np.ndarray) -> Optional[np.ndarray]: + """ + Estimate depth map from a single RGB image. + + Returns depth map where HIGHER values = CLOSER to camera. + This mimics LIDAR point cloud distance measurement. + """ + global _depth_model, _depth_transform, _depth_available, _depth_device + + if not _depth_available or _depth_model is None: + return None + + try: + # Depth Anything (pipeline-based) + if hasattr(_depth_model, '__call__') and hasattr(_depth_model, 'task'): + pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + result = _depth_model(pil_image) + depth = np.array(result["depth"]) + + # Resize to match frame + depth = cv2.resize(depth, (frame.shape[1], frame.shape[0])) + + # Normalize and invert (so closer = higher value) + depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6) + depth = 1.0 - depth # Invert + depth = (depth * 255).astype(np.uint8) + + return depth + + # MiDaS model + else: + img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + input_batch = _depth_transform(img_rgb).to(_depth_device) + + with torch.no_grad(): + prediction = _depth_model(input_batch) + prediction = torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=frame.shape[:2], + mode="bicubic", + align_corners=False, + ).squeeze() + + depth = prediction.cpu().numpy() + + # Normalize (MiDaS: higher = further, so we invert) + depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-6) + depth = 1.0 - depth # Invert so closer = higher + depth = (depth * 255).astype(np.uint8) + + return depth + + except Exception as e: + cc.log(f"Depth estimation error: {e}", "ERROR") + return None + + +def detect_obstacles_depth(frame: np.ndarray, depth_map: np.ndarray) -> List[Dict]: + """ + Detect obstacles using AI-generated depth map. + + This is MORE ACCURATE than edge detection because it knows actual distance, + not just "there's something there". 
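+
+    The search is limited to the assumed walking path (lower two thirds of the
+    frame, central two thirds horizontally), and proximity is classified on the
+    normalized 0-255 depth: above 200 is reported as "very_close" with "high"
+    severity, above 150 as "close" with "medium" severity.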
+ """ + if depth_map is None: + return [] + + h, w = frame.shape[:2] + obstacles = [] + + try: + # Focus on walking path (center and bottom of frame) + path_mask = np.zeros_like(depth_map) + path_mask[h // 3:, w // 6:5 * w // 6] = 1 + + path_depth = depth_map * path_mask + + # Thresholds for proximity (calibrated for normalized 0-255 depth) + very_close_thresh = 200 # Within arm's reach + close_thresh = 150 # Few steps away + medium_thresh = 100 # Room distance + + # Find very close obstacles + very_close_mask = (path_depth > very_close_thresh).astype(np.uint8) * 255 + close_mask = ((path_depth > close_thresh) & (path_depth <= very_close_thresh)).astype(np.uint8) * 255 + + # Morphological cleanup + kernel = np.ones((7, 7), np.uint8) + very_close_mask = cv2.morphologyEx(very_close_mask, cv2.MORPH_CLOSE, kernel) + very_close_mask = cv2.morphologyEx(very_close_mask, cv2.MORPH_OPEN, kernel) + + # Find contours for very close obstacles + contours, _ = cv2.findContours(very_close_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + min_area = (h * w) * 0.01 # 1% of frame minimum + + for contour in contours: + area = cv2.contourArea(contour) + if area < min_area: + continue + + x, y, cw, ch = cv2.boundingRect(contour) + + # Get depth stats for this region + region_depth = depth_map[y:y+ch, x:x+cw] + avg_depth = np.mean(region_depth) + max_depth = np.max(region_depth) + + # Classify severity + if max_depth > very_close_thresh: + severity = "high" + distance = "very_close" + elif max_depth > close_thresh: + severity = "medium" + distance = "close" + else: + severity = "low" + distance = "medium" + + # Position classification + center_x = x + cw // 2 + if center_x < w // 3: + position = "left" + elif center_x > 2 * w // 3: + position = "right" + else: + position = "center" + + # Track this obstacle over time + obstacle_id = f"depth_{position}_{int(avg_depth)}" + approach_info = track_obstacle_approach(obstacle_id, avg_depth, position) + + obstacles.append({ + "label": "obstacle (depth)", + "type": severity, + "position": position, + "distance": distance, + "box": [x, y, x + cw, y + ch], + "depth_value": float(avg_depth), + "max_depth": float(max_depth), + "area_pct": float(area / (h * w) * 100), + "approaching": approach_info.get("approaching", False), + "approach_rate": approach_info.get("rate", 0), + "time_to_collision": approach_info.get("ttc"), + "source": "depth_ai", + "reason": f"Depth AI: {avg_depth:.0f}/255 proximity" + }) + + except Exception as e: + cc.log(f"Depth obstacle detection error: {e}", "ERROR") + + return obstacles + + +def detect_collision_optical_flow(frame: np.ndarray) -> List[Dict]: + """ + Detect approaching obstacles using optical flow expansion. + + PROPRIETARY TECHNIQUE: Objects approaching you EXPAND in the frame. + This mimics how flying insects detect and avoid collisions! + + Physics: An object moving toward you at constant speed will appear to + grow larger. The rate of expansion indicates approach speed. 
+ """ + global _prev_flow_frame, _obstacle_tracking + + h, w = frame.shape[:2] + obstacles = [] + + try: + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + gray = cv2.GaussianBlur(gray, (5, 5), 0) + + if _prev_flow_frame is None: + _prev_flow_frame = gray + return [] + + # Dense optical flow (Farneback method) + flow = cv2.calcOpticalFlowFarneback( + _prev_flow_frame, gray, None, + pyr_scale=0.5, levels=3, winsize=15, + iterations=3, poly_n=5, poly_sigma=1.2, flags=0 + ) + + _prev_flow_frame = gray + + # Analyze expansion in different regions + regions = { + "left": (0, h // 4, w // 3, 3 * h // 4), + "center": (w // 3, h // 4, 2 * w // 3, 3 * h // 4), + "right": (2 * w // 3, h // 4, w, 3 * h // 4), + "floor": (w // 4, 2 * h // 3, 3 * w // 4, h), + } + + for region_name, (x1, y1, x2, y2) in regions.items(): + region_flow = flow[y1:y2, x1:x2] + rh, rw = region_flow.shape[:2] + + if rh < 10 or rw < 10: + continue + + fx = region_flow[:, :, 0] + fy = region_flow[:, :, 1] + + # Flow magnitude + magnitude = np.sqrt(fx**2 + fy**2) + avg_magnitude = np.mean(magnitude) + + # Skip if no significant motion + if avg_magnitude < 1.0: + continue + + # Calculate EXPANSION: do flow vectors point outward from center? + # This is the key insight - approaching objects expand! + center_y, center_x = rh // 2, rw // 2 + y_coords, x_coords = np.meshgrid(np.arange(rh) - center_y, + np.arange(rw) - center_x, indexing='ij') + + # Outward direction from center + dist = np.sqrt(x_coords**2 + y_coords**2) + 1e-6 + out_x = x_coords / dist + out_y = y_coords / dist + + # Dot product: positive = expanding (approaching) + expansion = fx * out_x + fy * out_y + avg_expansion = np.mean(expansion) + + # Temporal smoothing + key = f"flow_{region_name}" + if key not in _obstacle_tracking: + _obstacle_tracking[key] = [] + _obstacle_tracking[key].append(avg_expansion) + _obstacle_tracking[key] = _obstacle_tracking[key][-15:] + + smoothed = np.mean(_obstacle_tracking[key]) + + # Threshold for collision warning + if smoothed > 0.8 and avg_magnitude > 1.5: + if smoothed > 2.0: + severity = "high" + distance = "very_close" + elif smoothed > 1.2: + severity = "medium" + distance = "close" + else: + severity = "low" + distance = "medium" + + obstacles.append({ + "label": "approaching", + "type": severity, + "position": region_name, + "distance": distance, + "box": [x1, y1, x2, y2], + "expansion_rate": float(smoothed), + "flow_magnitude": float(avg_magnitude), + "source": "optical_flow", + "reason": f"Motion expansion: {smoothed:.1f}x (collision trajectory)" + }) + + except Exception as e: + cc.log(f"Optical flow error: {e}", "ERROR") + + return obstacles + + +def segment_walkable_ground(frame: np.ndarray, depth_map: np.ndarray = None) -> Dict: + """ + Segment walkable floor area from obstacles. + + PROPRIETARY: Combines color consistency + depth + geometry to find safe walking path. 
+ """ + h, w = frame.shape[:2] + + try: + # Sample ground color from bottom center (assumed floor) + sample = frame[h - 60:h - 20, w // 3:2 * w // 3] + ground_mean = np.mean(sample, axis=(0, 1)) + ground_std = np.std(sample, axis=(0, 1)) + + # Color-based ground mask + lower = np.clip(ground_mean - 2.5 * ground_std, 0, 255).astype(np.uint8) + upper = np.clip(ground_mean + 2.5 * ground_std, 0, 255).astype(np.uint8) + color_mask = cv2.inRange(frame, lower, upper) + + # If depth available, refine with depth consistency + if depth_map is not None: + ground_depth = np.median(depth_map[3 * h // 4:, w // 3:2 * w // 3]) + depth_tolerance = 40 + depth_mask = np.abs(depth_map.astype(float) - ground_depth) < depth_tolerance + combined_mask = cv2.bitwise_and(color_mask, depth_mask.astype(np.uint8) * 255) + else: + combined_mask = color_mask + + # Morphological cleanup + kernel = np.ones((7, 7), np.uint8) + combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_CLOSE, kernel) + combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_OPEN, kernel) + + # Analyze walkability per region + left_walk = np.mean(combined_mask[h // 2:, :w // 3] > 0) + center_walk = np.mean(combined_mask[h // 2:, w // 3:2 * w // 3] > 0) + right_walk = np.mean(combined_mask[h // 2:, 2 * w // 3:] > 0) + + paths = {"left": left_walk, "center": center_walk, "right": right_walk} + best = max(paths, key=paths.get) + worst = min(paths, key=paths.get) + + return { + "walkable_mask": combined_mask, + "left": float(left_walk), + "center": float(center_walk), + "right": float(right_walk), + "best_path": best, + "blocked_path": worst, + "confidence": float(max(paths.values())) + } + + except Exception as e: + cc.log(f"Ground segmentation error: {e}", "ERROR") + return {"best_path": "center", "confidence": 0.0} + + +def track_obstacle_approach(obstacle_id: str, current_depth: float, position: str) -> Dict: + """ + Track obstacle over time to detect if it's getting closer. + + Returns approach rate and estimated time-to-collision. 
+ """ + global _obstacle_tracking + + current_time = time.time() + key = f"approach_{obstacle_id}" + + if key not in _obstacle_tracking: + _obstacle_tracking[key] = {"history": [], "first_seen": current_time} + + _obstacle_tracking[key]["history"].append({ + "time": current_time, + "depth": current_depth + }) + + # Keep last 30 readings + _obstacle_tracking[key]["history"] = _obstacle_tracking[key]["history"][-30:] + + history = _obstacle_tracking[key]["history"] + + if len(history) < 5: + return {"approaching": False, "rate": 0, "ttc": None} + + # Calculate approach rate + time_span = history[-1]["time"] - history[0]["time"] + if time_span < 0.1: + return {"approaching": False, "rate": 0, "ttc": None} + + depth_change = history[-1]["depth"] - history[0]["depth"] + rate = depth_change / time_span + + # Positive rate = getting closer (depth increasing) + approaching = rate > 3 + + # Time to collision estimate + ttc = None + if approaching and rate > 0: + remaining = 255 - history[-1]["depth"] + ttc = remaining / rate if rate > 0 else None + + return { + "approaching": approaching, + "rate": float(rate), + "ttc": float(ttc) if ttc and ttc > 0 else None + } + + +# ===== OBSTACLE DETECTION ===== + +# Timing control for Claude obstacle analysis (don't call too frequently) +_last_obstacle_analysis_time = 0 +_obstacle_analysis_interval = 3.0 # Seconds between Claude calls +_cached_obstacles = [] +_cached_opencv_obstacles = [] + + +# ===== OPENCV REAL-TIME OBSTACLE DETECTION ===== +# Fast, runs every frame - detects "something is there" +# Complements Claude AI which understands "what is it and is it dangerous" + +def detect_obstacles_opencv(frame: np.ndarray) -> List[Dict]: + """ + Real-time obstacle detection using OpenCV techniques. + + This runs every frame and detects: + 1. Large objects/edges in the path (via Canny edge detection) + 2. Proximity based on edge density in regions + 3. Floor-level obstacles (via bottom-region analysis) + + This is FAST but DUMB - it detects "something is there" but doesn't know what. + Claude AI provides the smart context about whether it's actually dangerous. 
+ """ + global cc + + h, w = frame.shape[:2] + obstacles = [] + + try: + # Convert to grayscale + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + + # Apply bilateral filter to reduce noise while preserving edges + # This is key for obstacle detection - keeps edges sharp + filtered = cv2.bilateralFilter(gray, 9, 75, 75) + + # Canny edge detection + edges = cv2.Canny(filtered, 50, 150) + + # Dilate edges to connect nearby edges + kernel = np.ones((3, 3), np.uint8) + edges_dilated = cv2.dilate(edges, kernel, iterations=2) + + # Define regions of interest (ROI) for obstacle detection + # Focus on center and bottom of frame (where obstacles matter for walking) + regions = { + "center_close": (w // 4, h // 2, 3 * w // 4, h - 50), # Center, bottom half + "left_path": (0, h // 2, w // 4, h - 50), # Left side path + "right_path": (3 * w // 4, h // 2, w, h - 50), # Right side path + "floor_immediate": (w // 6, 2 * h // 3, 5 * w // 6, h), # Immediate floor area + } + + for region_name, (x1, y1, x2, y2) in regions.items(): + # Extract region + roi = edges_dilated[y1:y2, x1:x2] + + if roi.size == 0: + continue + + # Calculate edge density in this region + edge_density = np.sum(roi > 0) / roi.size + + # Find contours in this region + contours, _ = cv2.findContours(roi, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + # Filter significant contours (large enough to be obstacles) + min_contour_area = (x2 - x1) * (y2 - y1) * 0.05 # At least 5% of region + significant_contours = [c for c in contours if cv2.contourArea(c) > min_contour_area] + + # Determine if this region has an obstacle + # High edge density + significant contours = likely obstacle + if edge_density > 0.15 and len(significant_contours) > 0: + # Estimate proximity based on position in frame + # Objects lower in frame = closer + vertical_position = (y1 + y2) / 2 / h + + if vertical_position > 0.8: # Very low in frame + distance = "very_close" + severity = "high" + elif vertical_position > 0.65: + distance = "close" + severity = "medium" + else: + distance = "medium" + severity = "low" + + # Map region to position + if "left" in region_name: + position = "left" + elif "right" in region_name: + position = "right" + elif "floor" in region_name: + position = "floor" + else: + position = "center" + + # Get the largest contour for this obstacle + largest_contour = max(significant_contours, key=cv2.contourArea) + contour_box = cv2.boundingRect(largest_contour) + + # Adjust box coordinates to full frame + box = [ + x1 + contour_box[0], + y1 + contour_box[1], + x1 + contour_box[0] + contour_box[2], + y1 + contour_box[1] + contour_box[3] + ] + + obstacles.append({ + "label": f"obstacle ({position})", + "type": severity, + "position": position, + "distance": distance, + "box": box, + "edge_density": edge_density, + "contour_count": len(significant_contours), + "source": "opencv", + "reason": f"Edge detection: {edge_density:.0%} density" + }) + + # Also check for sudden large objects in center (collision imminent) + center_roi = edges_dilated[h // 3:, w // 4:3 * w // 4] + center_density = np.sum(center_roi > 0) / center_roi.size + + if center_density > 0.25: # Very high edge density in center + # This suggests a large object directly ahead + obstacles.append({ + "label": "large obstacle ahead", + "type": "high", + "position": "center", + "distance": "close", + "box": [w // 4, h // 3, 3 * w // 4, h], + "edge_density": center_density, + "source": "opencv", + "reason": "High edge density directly ahead - possible collision" + }) + + except Exception as e: + 
cc.log(f"OpenCV obstacle detection error: {e}", "ERROR") + + return obstacles + + +def analyze_floor_clearance(frame: np.ndarray) -> Dict: + """ + Analyze if the immediate floor area is clear for walking. + + Uses color consistency and edge analysis of the floor region + to detect trip hazards, steps, or objects on the ground. + """ + h, w = frame.shape[:2] + + # Focus on bottom third of frame (floor area) + floor_region = frame[2 * h // 3:, :] + + # Convert to grayscale + gray = cv2.cvtColor(floor_region, cv2.COLOR_BGR2GRAY) + + # Calculate standard deviation - uniform floor has low std dev + std_dev = np.std(gray) + + # Edge detection on floor + edges = cv2.Canny(gray, 30, 100) + edge_ratio = np.sum(edges > 0) / edges.size + + # Analyze left, center, right paths + third = w // 3 + left_edges = np.sum(edges[:, :third] > 0) / (edges[:, :third].size + 1) + center_edges = np.sum(edges[:, third:2*third] > 0) / (edges[:, third:2*third].size + 1) + right_edges = np.sum(edges[:, 2*third:] > 0) / (edges[:, 2*third:].size + 1) + + # Determine clearest path + paths = {"left": left_edges, "center": center_edges, "right": right_edges} + clearest = min(paths, key=paths.get) + + return { + "floor_uniformity": 1.0 - min(std_dev / 80, 1.0), # Higher = more uniform + "edge_ratio": edge_ratio, + "path_analysis": paths, + "suggested_path": clearest, + "floor_clear": edge_ratio < 0.1 and std_dev < 40 + } + + +def detect_obstacles(frame: np.ndarray, pil_image: Image.Image) -> List[Dict]: + """ + PROPRIETARY 4-LAYER OBSTACLE DETECTION SYSTEM + + Combines multiple techniques for comprehensive obstacle detection + using only a single RGB camera (no LIDAR/radar needed): + + Layer 1: OpenCV Edge Detection (every frame, ~20ms) + - Canny edges, contours, bilateral filtering + - Immediate response for sudden obstacles + + Layer 2: AI Depth Estimation (every frame if available, ~50ms) + - MiDaS or Depth Anything for LIDAR-like depth + - Knows actual distance, not just "something is there" + + Layer 3: Optical Flow Collision Detection (every frame, ~30ms) + - Detects APPROACHING objects via motion expansion + - Biomimetic: same technique insects use! 
+ + Layer 4: Claude AI Analysis (every 3 seconds, ~1-2s) + - Semantic understanding of obstacles + - Knows target is NOT an obstacle + - Explains WHY something is dangerous + + Plus: Ground Plane Segmentation, Temporal Tracking, Time-to-Collision + """ + global cc, _last_obstacle_analysis_time, _cached_obstacles, _cached_opencv_obstacles, _depth_available + + if not cc.obstacle_detection_active: + return [] + + current_time = time.time() + all_obstacles = [] + + # ===== LAYER 1: OpenCV Edge Detection (every frame) ===== + # Fast, detects "something is there" + opencv_obstacles = detect_obstacles_opencv(frame) + + for obs in opencv_obstacles: + if obs["distance"] in ["very_close", "close"] and obs.get("edge_density", 0) > 0.2: + obs["timestamp"] = current_time + cooldown_key = f"opencv_{obs['position']}_{obs['distance']}" + last_alert = cc.obstacle_alert_cooldown.get(cooldown_key, 0) + + if current_time - last_alert > 2.0: + obs["should_alert"] = True + cc.obstacle_alert_cooldown[cooldown_key] = current_time + else: + obs["should_alert"] = False + + all_obstacles.append(obs) + + _cached_opencv_obstacles = opencv_obstacles + + # ===== LAYER 2: AI Depth Estimation (LIDAR-like) ===== + # Uses MiDaS or Depth Anything for real distance measurement + depth_map = None + if _depth_available: + depth_map = estimate_depth(frame) + if depth_map is not None: + depth_obstacles = detect_obstacles_depth(frame, depth_map) + + for obs in depth_obstacles: + obs["timestamp"] = current_time + cooldown_key = f"depth_{obs['position']}_{obs['distance']}" + last_alert = cc.obstacle_alert_cooldown.get(cooldown_key, 0) + + # Depth is more reliable, use slightly shorter cooldown + if current_time - last_alert > 1.5: + obs["should_alert"] = True + cc.obstacle_alert_cooldown[cooldown_key] = current_time + + # Alert on approaching objects with TTC + if obs.get("approaching") and obs.get("time_to_collision"): + ttc = obs["time_to_collision"] + if ttc < 2.0: + obs["type"] = "high" + obs["reason"] = f"Approaching! 
{ttc:.1f}s to collision" + cc.log(f"COLLISION WARNING: {obs['label']} in {ttc:.1f}s", "ERROR") + else: + obs["should_alert"] = False + + all_obstacles.append(obs) + + # ===== LAYER 3: Optical Flow Collision Detection ===== + # Biomimetic: detects approaching objects via expansion + flow_obstacles = detect_collision_optical_flow(frame) + + for obs in flow_obstacles: + obs["timestamp"] = current_time + cooldown_key = f"flow_{obs['position']}" + last_alert = cc.obstacle_alert_cooldown.get(cooldown_key, 0) + + if current_time - last_alert > 1.5 and obs.get("expansion_rate", 0) > 1.0: + obs["should_alert"] = True + cc.obstacle_alert_cooldown[cooldown_key] = current_time + cc.log(f"MOTION: {obs['label']} expanding at {obs['expansion_rate']:.1f}x", "WARN") + else: + obs["should_alert"] = False + + all_obstacles.append(obs) + + # ===== Ground Plane & Walkable Path Analysis ===== + walkable = segment_walkable_ground(frame, depth_map) + cc.navigation_context = cc.navigation_context or {} + cc.navigation_context["walkable"] = walkable + cc.navigation_context["best_path"] = walkable.get("best_path", "center") + + # Also run simpler floor analysis + floor_analysis = analyze_floor_clearance(frame) + if not floor_analysis["floor_clear"]: + cc.navigation_context["floor_analysis"] = floor_analysis + cc.navigation_context["suggested_path"] = floor_analysis["suggested_path"] + + # ===== LAYER 4: Claude AI Analysis (every few seconds) ===== + # Smart contextual understanding + if current_time - _last_obstacle_analysis_time >= _obstacle_analysis_interval: + try: + # Encode frame for Claude + _, buffer = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 70]) + image_data = base64.b64encode(buffer).decode('utf-8') + + # Get target box if available (for spatial context) + target_box = None + if cc.navigation_target: + for det in cc.current_detections: + if det.get("label", "").lower() == cc.navigation_target.lower(): + target_box = det.get("box") + break + + # Call Claude for intelligent obstacle analysis + claude_obstacles = analyze_obstacles_with_claude( + image_data, + cc.navigation_target or "the object", + target_box + ) + + _last_obstacle_analysis_time = current_time + + # Process Claude's obstacles + for obs in claude_obstacles: + obstacle = { + "label": obs["label"], + "type": obs["type"], + "position": obs.get("position", "ahead"), + "distance": obs["distance"], + "reason": obs.get("reason", ""), + "timestamp": current_time, + "box": None, + "mask": None, + "source": "claude" + } + + # Check cooldown for alerts + cooldown_key = f"claude_{obs['label']}_{obs['distance']}" + last_alert = cc.obstacle_alert_cooldown.get(cooldown_key, 0) + + if current_time - last_alert > cc.obstacle_alert_interval: + obstacle["should_alert"] = True + cc.obstacle_alert_cooldown[cooldown_key] = current_time + + # Log the obstacle with reason + cc.log(f"OBSTACLE: {obs['label']} ({obs['distance']}) - {obs.get('reason', '')}", "WARN") + + # Save to database + if cc.navigation_db_id: + db.save_obstacle( + cc.navigation_db_id, + obs["label"], + obs["type"], + [], + obs["distance"], + alert_sent=True + ) + else: + obstacle["should_alert"] = False + + all_obstacles.append(obstacle) + + # Log safe direction if available + if cc.navigation_context and cc.navigation_context.get("safe_direction"): + cc.log(f"Safe path: {cc.navigation_context['safe_direction']}", "INFO") + + _cached_obstacles = [o for o in all_obstacles if o.get("source") == "claude"] + + except Exception as e: + cc.log(f"Claude obstacle analysis error: {e}", "ERROR") 
+ # Fall back to cached Claude results + all_obstacles.extend(_cached_obstacles) + + else: + # Use cached Claude results between API calls + all_obstacles.extend(_cached_obstacles) + + return all_obstacles + + +def get_obstacle_segmentation(frame: np.ndarray, obstacle_label: str) -> Optional[np.ndarray]: + """ + Optional: Get SAM3 segmentation mask for an obstacle identified by Claude. + This can be used if we want to visually highlight the obstacle. + """ + global cc + + if cc.processor is None: + return None + + try: + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_image = Image.fromarray(frame_rgb) + + state = cc.processor.set_image(pil_image, {}) + state = cc.processor.set_text_prompt(obstacle_label, state) + + masks = state.get("masks") + if masks is not None and masks.numel() > 0: + return masks[0].squeeze().cpu().numpy() + + except Exception as e: + cc.log(f"Obstacle segmentation failed: {e}", "ERROR") + + return None + + +def overlay_obstacles(display: np.ndarray, obstacles: List[Dict]) -> np.ndarray: + """ + Overlay obstacle alerts on the display frame. + + Handles both: + - OpenCV obstacles (have precise bounding boxes from edge detection) + - Claude obstacles (have position-based info like left/center/right) + """ + if not obstacles: + return display + + h, w = display.shape[:2] + + # Obstacle color (orange/red based on severity) + colors = { + "high": (0, 0, 255), # Red + "medium": (0, 165, 255), # Orange + "low": (0, 255, 255) # Yellow + } + + # Position to screen region mapping (for Claude obstacles without boxes) + position_regions = { + "left": (10, h // 3, w // 3, 2 * h // 3), + "center": (w // 3, h // 3, 2 * w // 3, 2 * h // 3), + "right": (2 * w // 3, h // 3, w - 10, 2 * h // 3), + "floor": (w // 4, 2 * h // 3, 3 * w // 4, h - 10), + "ahead": (w // 4, h // 4, 3 * w // 4, 3 * h // 4), + } + + for i, obstacle in enumerate(obstacles): + severity = obstacle.get("type", "medium") + label = obstacle.get("label", "Obstacle") + distance = obstacle.get("distance", "medium") + position = obstacle.get("position", "ahead") + reason = obstacle.get("reason", "") + source = obstacle.get("source", "unknown") + box = obstacle.get("box") + + color = colors.get(severity, (0, 165, 255)) + + # Determine region - use box if available (OpenCV), otherwise use position (Claude) + if box and len(box) == 4 and all(v is not None for v in box): + # OpenCV obstacle with precise box + rx1, ry1, rx2, ry2 = [int(v) for v in box] + + # Draw bounding box with dashed lines for OpenCV detections + if source == "opencv": + # Dashed rectangle effect + for j in range(rx1, rx2, 10): + cv2.line(display, (j, ry1), (min(j + 5, rx2), ry1), color, 2) + cv2.line(display, (j, ry2), (min(j + 5, rx2), ry2), color, 2) + for j in range(ry1, ry2, 10): + cv2.line(display, (rx1, j), (rx1, min(j + 5, ry2)), color, 2) + cv2.line(display, (rx2, j), (rx2, min(j + 5, ry2)), color, 2) + else: + cv2.rectangle(display, (rx1, ry1), (rx2, ry2), color, 2) + else: + # Claude obstacle - use position-based region + region = position_regions.get(position, position_regions["ahead"]) + rx1, ry1, rx2, ry2 = region + + # Draw semi-transparent warning zone for close obstacles + if distance in ["very_close", "close"]: + overlay = display.copy() + alpha = 0.25 if severity == "high" else 0.15 + cv2.rectangle(overlay, (rx1, ry1), (rx2, ry2), color, -1) + display = cv2.addWeighted(overlay, alpha, display, 1 - alpha, 0) + + # Draw thick border + cv2.rectangle(display, (rx1, ry1), (rx2, ry2), color, 3) + + # Draw warning icon + icon_size = 
35 if severity == "high" else 25 + icon_x = (rx1 + rx2) // 2 - icon_size // 2 + icon_y = max(ry1 - icon_size - 5, 5) + + # Warning triangle + triangle = np.array([ + [icon_x + icon_size // 2, icon_y], + [icon_x, icon_y + icon_size], + [icon_x + icon_size, icon_y + icon_size] + ], np.int32) + cv2.fillPoly(display, [triangle], color) + cv2.polylines(display, [triangle], True, (0, 0, 0), 2) + + # Exclamation mark + cv2.line(display, (icon_x + icon_size // 2, icon_y + 8), + (icon_x + icon_size // 2, icon_y + icon_size - 12), (0, 0, 0), 2) + cv2.circle(display, (icon_x + icon_size // 2, icon_y + icon_size - 6), 2, (0, 0, 0), -1) + + # Label text + if distance in ["very_close"]: + label_text = f"STOP! {label}" + elif distance == "close": + label_text = f"WARNING: {label}" + else: + label_text = f"CAUTION: {label}" + + # Add source indicator for debugging + if source == "opencv": + label_text += " [CV]" + + text_x = rx1 + 5 + text_y = ry2 + 20 if ry2 + 25 < h else ry1 - 40 + + # Text with background + (text_w, text_h), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.55, 2) + cv2.rectangle(display, (text_x - 2, text_y - text_h - 3), + (text_x + text_w + 2, text_y + 3), (0, 0, 0), -1) + cv2.putText(display, label_text, (text_x, text_y), + cv2.FONT_HERSHEY_SIMPLEX, 0.55, color, 2) + + # Distance text + text_y += 18 + distance_text = distance.replace("_", " ").upper() + cv2.putText(display, distance_text, (text_x, text_y), + cv2.FONT_HERSHEY_SIMPLEX, 0.45, (255, 255, 255), 1) + + # Reason (if from Claude) + if reason and source == "claude" and len(reason) < 40: + text_y += 16 + cv2.putText(display, reason, (text_x, text_y), + cv2.FONT_HERSHEY_SIMPLEX, 0.35, (200, 200, 200), 1) + + # Draw path status indicator + if cc.navigation_context: + if cc.navigation_context.get("path_clear"): + cv2.putText(display, "PATH CLEAR", (w // 2 - 60, 30), + cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2) + elif cc.navigation_context.get("safe_direction"): + safe_text = f"Go: {cc.navigation_context['safe_direction']}" + cv2.putText(display, safe_text, (10, h - 20), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2) + elif cc.navigation_context.get("suggested_path"): + # From floor analysis + path_text = f"Clearest path: {cc.navigation_context['suggested_path']}" + cv2.putText(display, path_text, (10, h - 20), + cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 200, 255), 1) + + return display + + +# ===== FRAME PROCESSING ===== + +def process_frame(frame: np.ndarray) -> np.ndarray: + """Process a frame through SAM3 and overlay results.""" + global cc + + cc.frame_count += 1 + is_keyframe = cc.frame_count % cc.skip_frames == 0 + + # Handle geometric prompts (draw to search) + if cc.pending_box_prompt is not None or cc.pending_point_prompt is not None: + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_image = Image.fromarray(frame_rgb) + + cc.state = cc.processor.set_image(pil_image, cc.state) + + if cc.pending_box_prompt is not None: + # Box prompt + x1, y1, x2, y2 = cc.pending_box_prompt + cc.state["geometric_prompt"] = { + "type": "box", + "box": [x1, y1, x2, y2] + } + cc.log(f"Processing box prompt: ({x1:.0f},{y1:.0f}) to ({x2:.0f},{y2:.0f})") + + elif cc.pending_point_prompt is not None: + # Point prompt + x, y = cc.pending_point_prompt + cc.state["geometric_prompt"] = { + "type": "point", + "point": [x, y], + "label": 1 # 1 = foreground, 0 = background + } + cc.log(f"Processing point prompt: ({x:.0f},{y:.0f})") + + # Get mask from geometric prompt + try: + # Use the processor's segment method with geometric 
prompt + masks = cc.state.get("masks") + if masks is not None and len(masks) > 0: + mask_np = masks[0].squeeze().cpu().numpy() + box = get_bounding_box_from_mask(mask_np) + + cc.last_masks = masks[:1] + cc.last_boxes = torch.tensor([box]) if box else None + cc.last_scores = torch.tensor([1.0]) + cc.last_labels = ["selected object"] + + # Add to detections + with cc.lock: + cc.current_detections = [{ + "id": 0, + "label": "selected object", + "confidence": 1.0, + "box": box, + "tracked": False, + }] + + cc.log("Object selected via drawing", "SUCCESS") + + except Exception as e: + cc.log(f"Geometric prompt failed: {e}", "ERROR") + + # Clear the pending prompts + cc.pending_box_prompt = None + cc.pending_point_prompt = None + cc.draw_mode = None + cc.prev_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + + elif is_keyframe and not cc.paused: + # Full inference + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_image = Image.fromarray(frame_rgb) + + cc.state = cc.processor.set_image(pil_image, cc.state) + + # Build new detections list (don't clear until we have new ones) + new_detections = [] + cc.last_poses = {} + + all_masks = [] + all_boxes = [] + all_scores = [] + all_labels = [] + all_object_ids = [] + + for prompt in cc.prompts: + if "geometric_prompt" in cc.state: + del cc.state["geometric_prompt"] + + cc.state = cc.processor.set_text_prompt(prompt.strip(), cc.state) + + masks = cc.state.get("masks") + boxes = cc.state.get("boxes") + scores = cc.state.get("scores") + + if masks is not None and masks.numel() > 0: + for i in range(len(masks)): + mask_np = masks[i].squeeze().cpu().numpy() + box = boxes[i].cpu().numpy().tolist() if boxes is not None and i < len(boxes) else None + score = float(scores[i].cpu()) if scores is not None and i < len(scores) else 0.0 + + # Boundary suppression + if cc.enable_boundary_suppression and box: + if is_near_boundary(box, frame.shape, cc.boundary_margin): + continue + + # Hotstart + if cc.enable_hotstart: + det_hash = f"{prompt}_{int(box[0]) if box else 0}_{int(box[1]) if box else 0}" + if det_hash not in cc.pending_detections: + cc.pending_detections[det_hash] = {"frames": 1, "data": None} + continue + else: + cc.pending_detections[det_hash]["frames"] += 1 + if cc.pending_detections[det_hash]["frames"] < cc.hotstart_frames: + continue + del cc.pending_detections[det_hash] + + # Fill holes + if cc.enable_fill_holes: + mask_np = fill_holes_in_mask(mask_np, cc.fill_hole_area) + + # Smooth edges + if cc.enable_smooth_edges: + mask_np = smooth_mask_edges(mask_np, cc.smooth_kernel_size) + + # Persistent object IDs + object_id = len(all_masks) + if cc.enable_persistent_ids: + if cc.tracked_objects: + match_id = match_detection_to_object( + mask_np, + {oid: obj["last_mask"] for oid, obj in cc.tracked_objects.items() + if "last_mask" in obj}, + cc.iou_threshold + ) + if match_id is not None: + object_id = match_id + else: + object_id = cc.next_object_id + cc.next_object_id += 1 + + if object_id not in cc.tracked_objects: + cc.tracked_objects[object_id] = { + "label": prompt.strip(), + "first_seen": cc.frame_count, + "color": COLORS[object_id % len(COLORS)], + } + cc.object_colors[object_id] = COLORS[object_id % len(COLORS)] + + cc.tracked_objects[object_id]["last_seen"] = cc.frame_count + cc.tracked_objects[object_id]["last_mask"] = mask_np + cc.tracked_objects[object_id]["confidence"] = score + + # Memory tracking + if cc.enable_memory_tracking: + mask_tensor = torch.from_numpy(mask_np).unsqueeze(0) + update_memory_bank(object_id, mask_tensor) + + # ===== 
YOLO INTEGRATION ===== + yolo_info = {} + + # YOLO Classification + if cc.enable_yolo_classify and box and cc.yolo_classify_model is not None: + if cc.frame_count % cc.yolo_classify_every_n == 0: + classify_result = classify_region(frame, box, prompt.strip()) + if classify_result: + yolo_info["classify"] = classify_result + if classify_result["yolo_confidence"] >= cc.yolo_classify_threshold: + cc.log(f"YOLO: {classify_result['yolo_class']} ({classify_result['yolo_confidence']:.0%})") + + # YOLO Pose Estimation (only for person-like labels) + if cc.enable_yolo_pose and box and cc.yolo_pose_model is not None: + if is_person_label(prompt.strip()): + pose_result = estimate_pose(frame, box) + if pose_result and pose_result["confidence"] >= cc.yolo_pose_threshold: + yolo_info["pose"] = pose_result + cc.last_poses[object_id] = pose_result + cc.log(f"Pose detected for {prompt} (conf: {pose_result['confidence']:.0%})") + + detection = { + "id": object_id, + "label": prompt.strip(), + "confidence": score, + "box": box, + "persistent_id": object_id if cc.enable_persistent_ids else None, + "yolo": yolo_info if yolo_info else None, + "tracked": False, # Fresh detection from SAM3 + } + new_detections.append(detection) + + all_masks.append(mask_np) + all_object_ids.append(object_id) + if box: + all_boxes.append(box) + all_scores.append(score) + all_labels.append(prompt.strip()) + + # Remove overlapping masks + if cc.enable_non_overlap and len(all_masks) > 1: + all_masks = remove_mask_overlaps(all_masks, all_scores) + + # Occlusion suppression + if cc.enable_occlusion_suppression and len(all_masks) > 1: + keep_indices = [] + for i, mask_i in enumerate(all_masks): + is_occluded = False + for j, mask_j in enumerate(all_masks): + if i != j and all_scores[j] > all_scores[i]: + overlap = np.logical_and(mask_i, mask_j).sum() / (mask_i.sum() + 1e-6) + if overlap > cc.occlusion_threshold: + is_occluded = True + break + if not is_occluded: + keep_indices.append(i) + + all_masks = [all_masks[i] for i in keep_indices] + all_boxes = [all_boxes[i] for i in keep_indices if i < len(all_boxes)] + all_scores = [all_scores[i] for i in keep_indices] + all_labels = [all_labels[i] for i in keep_indices] + all_object_ids = [all_object_ids[i] for i in keep_indices] + + # Store for tracking + if all_masks: + cc.last_masks = torch.stack([torch.from_numpy(m).unsqueeze(0) for m in all_masks]) + cc.last_boxes = torch.tensor(all_boxes) if all_boxes else None + cc.last_scores = torch.tensor(all_scores) if all_scores else None + cc.last_labels = all_labels + cc.prev_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + else: + cc.last_masks = None + cc.last_boxes = None + cc.last_scores = None + cc.last_labels = None + + # CLIP-based visual matching filter + if cc.visual_match_enabled and cc.reference_embedding is not None and new_detections: + matched_detections = [] + matched_indices = [] + + for i, det in enumerate(new_detections): + box = det.get("box") + if box: + # Crop the detected region + x1, y1, x2, y2 = [int(v) for v in box] + h, w = frame.shape[:2] + x1, y1 = max(0, x1), max(0, y1) + x2, y2 = min(w, x2), min(h, y2) + + if x2 > x1 and y2 > y1: + crop = frame[y1:y2, x1:x2] + crop_pil = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)) + + # Get CLIP embedding + crop_embedding = get_clip_embedding(crop_pil) + if crop_embedding is not None: + similarity = compute_clip_similarity(cc.reference_embedding, crop_embedding) + det["clip_similarity"] = similarity + + if similarity >= cc.visual_match_threshold: + 
matched_detections.append(det) + matched_indices.append(i) + cc.log(f"Visual match: {det['label']} (sim: {similarity:.2f})") + + if matched_detections: + new_detections = matched_detections + # Also filter masks + if all_masks and matched_indices: + all_masks = [all_masks[i] for i in matched_indices if i < len(all_masks)] + cc.last_masks = torch.stack([torch.from_numpy(m).unsqueeze(0) for m in all_masks]) if all_masks else None + else: + cc.log("No visual matches found", "WARN") + new_detections = [] + + # Atomically update detections (only update if we have new detections, + # otherwise keep the existing tracked detections) + if new_detections: + with cc.lock: + cc.current_detections = new_detections + # Note: If SAM3 found nothing but we have tracked objects, keep them + # They will be removed by tracking when they actually leave the frame + + if all_labels: + cc.log(f"Detected: {', '.join(all_labels)}") + + elif cc.enable_tracking and cc.last_masks is not None and not cc.paused: + # Track with optical flow + tracked = track_frame(frame) + if tracked is not None: + cc.last_masks = tracked + + # Update detections based on tracked masks and remove objects that left frame + valid_indices = update_detections_from_tracked_masks(tracked, frame.shape) + + # If some masks were invalidated, update the mask list too + if valid_indices is not None and len(valid_indices) < len(tracked): + # Keep only valid masks + valid_masks = [tracked[i] for i in valid_indices] + if valid_masks: + cc.last_masks = torch.stack(valid_masks) + else: + cc.last_masks = None + + # Also update labels, scores, boxes to stay in sync + if cc.last_labels: + cc.last_labels = [cc.last_labels[i] for i in valid_indices if i < len(cc.last_labels)] + if cc.last_scores is not None and len(cc.last_scores) > 0: + try: + idx_tensor = torch.tensor(valid_indices, dtype=torch.long) + cc.last_scores = cc.last_scores[idx_tensor] if len(valid_indices) > 0 else None + except Exception: + cc.last_scores = None + if cc.last_boxes is not None and len(cc.last_boxes) > 0: + try: + idx_tensor = torch.tensor(valid_indices, dtype=torch.long) + cc.last_boxes = cc.last_boxes[idx_tensor] if len(valid_indices) > 0 else None + except Exception: + cc.last_boxes = None + + # Overlay masks on frame + display = frame.copy() + if cc.last_masks is not None: + display = overlay_masks(display, cc.last_masks, cc.last_boxes, cc.last_scores, cc.last_labels) + + # Draw pose overlays + if cc.enable_yolo_pose and cc.last_poses: + for obj_id, pose_data in cc.last_poses.items(): + display = draw_pose_overlay(display, pose_data, obj_id) + + # Obstacle detection during navigation (run on keyframes) + if cc.obstacle_detection_active and is_keyframe and not cc.paused: + try: + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_image = Image.fromarray(frame_rgb) + obstacles = detect_obstacles(frame, pil_image) + + if obstacles: + cc.current_obstacles = obstacles + display = overlay_obstacles(display, obstacles) + + # Log high-severity obstacles that should alert + for obs in obstacles: + if obs.get("should_alert") and obs.get("type") in ["high", "medium"]: + cc.log(f"OBSTACLE: {obs['label']} ({obs['distance']})", "WARN") + except Exception as e: + cc.log(f"Obstacle overlay error: {e}", "ERROR") + + return display + + +def track_frame(frame: np.ndarray) -> Optional[torch.Tensor]: + """Track masks using optical flow.""" + if cc.last_masks is None or cc.prev_gray is None: + return None + + try: + curr_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + + flow = 
cv2.calcOpticalFlowFarneback( + cc.prev_gray, curr_gray, None, + pyr_scale=0.5, levels=3, winsize=15, + iterations=3, poly_n=5, poly_sigma=1.2, flags=0 + ) + + h, w = curr_gray.shape + flow_map_x = np.arange(w).reshape(1, -1).repeat(h, axis=0).astype(np.float32) + flow_map_y = np.arange(h).reshape(-1, 1).repeat(w, axis=1).astype(np.float32) + flow_map_x += flow[:, :, 0] + flow_map_y += flow[:, :, 1] + + tracked_masks = [] + for mask in cc.last_masks: + if isinstance(mask, torch.Tensor): + mask_np = mask.cpu().numpy().squeeze() + else: + mask_np = mask.squeeze() + + if mask_np.shape != (h, w): + mask_np = cv2.resize(mask_np.astype(np.float32), (w, h)) + + warped = cv2.remap( + mask_np.astype(np.float32), + flow_map_x, flow_map_y, + interpolation=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=0 + ) + warped = (warped > 0.5).astype(np.float32) + + if cc.enable_fill_holes: + warped = fill_holes_in_mask(warped, cc.fill_hole_area) + if cc.enable_smooth_edges: + warped = smooth_mask_edges(warped, cc.smooth_kernel_size) + + tracked_masks.append(torch.from_numpy(warped).unsqueeze(0).to(cc.device_str)) + + cc.prev_gray = curr_gray + + if tracked_masks: + return torch.stack(tracked_masks) + + except Exception as e: + cc.log(f"Tracking error: {e}", "ERROR") + + return None + + +def overlay_masks(frame: np.ndarray, masks: torch.Tensor, boxes=None, scores=None, labels=None, alpha=0.5) -> np.ndarray: + """Overlay masks on frame.""" + if masks is None or masks.numel() == 0: + return frame + + overlay = frame.copy() + h, w = frame.shape[:2] + masks_np = masks.squeeze(1).cpu().numpy() + + scores_np = scores.cpu().numpy() if scores is not None and isinstance(scores, torch.Tensor) else scores + + for i, mask in enumerate(masks_np): + if mask.shape != (h, w): + mask = cv2.resize(mask.astype(np.float32), (w, h)) > 0.5 + + # Use persistent color if available + if cc.enable_persistent_ids and i < len(cc.current_detections): + det = cc.current_detections[i] + obj_id = det.get("persistent_id") + color = cc.object_colors.get(obj_id, COLORS[i % len(COLORS)]) + else: + color = COLORS[i % len(COLORS)] + + mask_region = mask.astype(bool) + overlay[mask_region] = ( + overlay[mask_region] * (1 - alpha) + np.array(color) * alpha + ).astype(np.uint8) + + # Draw contour + contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + cv2.drawContours(overlay, contours, -1, color, 2) + + # Draw label + if len(contours) > 0: + largest = max(contours, key=cv2.contourArea) + x, y, cw, ch = cv2.boundingRect(largest) + + label = labels[i] if labels and i < len(labels) else "object" + conf = scores_np[i] if scores_np is not None and i < len(scores_np) else 0.0 + + # Add persistent ID and YOLO info to label + text_parts = [] + if cc.enable_persistent_ids and i < len(cc.current_detections): + obj_id = cc.current_detections[i].get("persistent_id") + text_parts.append(f"#{obj_id}") + + text_parts.append(f"{label} {conf:.0%}") + + # Add YOLO classification if available + if i < len(cc.current_detections): + det = cc.current_detections[i] + yolo_info = det.get("yolo") + if yolo_info and "classify" in yolo_info: + yolo_class = yolo_info["classify"]["yolo_class"] + yolo_conf = yolo_info["classify"]["yolo_confidence"] + text_parts.append(f"[{yolo_class} {yolo_conf:.0%}]") + + text = " ".join(text_parts) + + font = cv2.FONT_HERSHEY_SIMPLEX + (tw, th), _ = cv2.getTextSize(text, font, 0.5, 1) + + cv2.rectangle(overlay, (x, y - th - 4), (x + tw + 4, y), color, -1) + cv2.putText(overlay, 
text, (x + 2, y - 2), font, 0.5, (255, 255, 255), 1) + + return overlay + + +def generate_frames(): + """Generator for video streaming.""" + global cc + + while cc.running: + if cc.camera is None or not cc.camera.isOpened(): + time.sleep(0.1) + continue + + ret, frame = cc.camera.read() + if not ret: + time.sleep(0.1) + continue + + # Apply flip transformations + if cc.flip_horizontal and cc.flip_vertical: + frame = cv2.flip(frame, -1) # Flip both + elif cc.flip_horizontal: + frame = cv2.flip(frame, 1) # Flip horizontally (mirror) + elif cc.flip_vertical: + frame = cv2.flip(frame, 0) # Flip vertically + + start = time.time() + + # Store raw frame (without overlays) for Claude analysis + cc.current_raw_frame = frame.copy() + + # Process frame (adds overlays) + display = process_frame(frame) + + # Calculate FPS + elapsed = time.time() - start + cc.fps = 1.0 / elapsed if elapsed > 0 else 0 + + # Encode to JPEG + _, buffer = cv2.imencode('.jpg', display, [cv2.IMWRITE_JPEG_QUALITY, 85]) + cc.current_frame = display + cc.current_frame_jpeg = buffer.tobytes() + + yield (b'--frame\r\n' + b'Content-Type: image/jpeg\r\n\r\n' + cc.current_frame_jpeg + b'\r\n') + + +def analyze_with_claude(image_data: str, label: str) -> str: + """Send image to Claude for analysis.""" + global ANTHROPIC_API_KEY + + if not ANTHROPIC_API_KEY: + return "Error: ANTHROPIC_API_KEY not set. Set it via environment variable or --api-key argument." + + try: + import anthropic + + client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY) + + if image_data.startswith("data:"): + image_data = image_data.split(",", 1)[1] + + message = client.messages.create( + model="claude-sonnet-4-20250514", + max_tokens=500, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": image_data, + }, + }, + { + "type": "text", + "text": f"This is a cropped image of a detected '{label}'. Please provide a brief, detailed description of what you see. Focus on: appearance, distinctive features, actions/pose, and any notable details. Keep it concise (2-3 sentences)." 
+ } + ], + } + ], + ) + + return message.content[0].text + + except Exception as e: + return f"Analysis error: {str(e)}" + + +def analysis_worker(): + """Background worker for Claude analysis.""" + global cc + + while cc.running: + if cc.analysis_queue: + with cc.lock: + if cc.analysis_queue: + item = cc.analysis_queue.pop(0) + cc.analyzing = True + else: + item = None + + if item: + cc.log(f"Analyzing object #{item['id']}...", "INFO") + + detections = cc.current_detections + label = "object" + for det in detections: + if det.get("id") == item["id"]: + label = det.get("label", "object") + break + + result = analyze_with_claude(item["image_data"], label) + cc.add_analysis_result(item["id"], result) + cc.log(f"Analysis complete for #{item['id']}", "SUCCESS") + cc.analyzing = False + else: + time.sleep(0.5) + + +# ===== FLASK ROUTES ===== + +@app.route('/') +def index(): + """Main command center page.""" + return render_template('index.html', + prompts=cc.prompts, + threshold=cc.confidence_threshold, + skip_frames=cc.skip_frames, + tracking=cc.enable_tracking, + features=cc.get_feature_status(), + yolo_available=cc.yolo_available) + + +@app.route('/video_feed') +def video_feed(): + """Video streaming route.""" + return Response(generate_frames(), + mimetype='multipart/x-mixed-replace; boundary=frame') + + +@app.route('/api/status') +def api_status(): + """Get current status.""" + filtered, hidden = cc.get_filtered_detections() + return jsonify({ + "running": cc.running, + "paused": cc.paused, + "fps": round(cc.fps, 1), + "frame_count": cc.frame_count, + "device": cc.device_str, + "detections": filtered, + "hidden_counts": hidden, + "prompts": cc.prompts, + "max_objects": cc.max_objects_per_prompt, + "show_all": cc.show_all_matches, + "analyzing": cc.analyzing, + "analysis_queue_size": len(cc.analysis_queue), + "features": cc.get_feature_status(), + "tracked_objects_count": len(cc.tracked_objects), + "memory_bank_size": len(cc.memory_bank), + "yolo_available": cc.yolo_available, + "poses_count": len(cc.last_poses), + }) + + +@app.route('/api/logs') +def api_logs(): + """Get recent logs.""" + return jsonify({"logs": cc.get_logs()}) + + +@app.route('/api/analysis_results') +def api_analysis_results(): + """Get analysis results.""" + with cc.lock: + results = list(cc.analysis_results) + return jsonify({"results": results}) + + +@app.route('/api/set_prompts', methods=['POST']) +def api_set_prompts(): + """Set detection prompts.""" + data = request.json + prompts_str = data.get("prompts", "object") + cc.prompts = [p.strip() for p in prompts_str.split(",") if p.strip()] + cc.state = None + cc.last_masks = None + cc.last_boxes = None + cc.last_scores = None + cc.last_labels = None + cc.tracked_objects = {} + cc.memory_bank = {} + cc.last_poses = {} + cc.log(f"Prompts updated: {', '.join(cc.prompts)}") + return jsonify({"success": True, "prompts": cc.prompts}) + + +@app.route('/api/set_limit', methods=['POST']) +def api_set_limit(): + """Set max objects limit for a prompt.""" + data = request.json + prompt = data.get("prompt") + limit = data.get("limit") + + if limit is not None: + cc.max_objects_per_prompt[prompt] = int(limit) + elif prompt in cc.max_objects_per_prompt: + del cc.max_objects_per_prompt[prompt] + + cc.log(f"Limit for '{prompt}': {limit if limit else 'unlimited'}") + return jsonify({"success": True}) + + +@app.route('/api/toggle_show_all', methods=['POST']) +def api_toggle_show_all(): + """Toggle show all matches for a prompt.""" + data = request.json + prompt = data.get("prompt") + 
cc.show_all_matches[prompt] = not cc.show_all_matches.get(prompt, False) + cc.log(f"Show all for '{prompt}': {cc.show_all_matches[prompt]}") + return jsonify({"success": True, "show_all": cc.show_all_matches[prompt]}) + + +@app.route('/api/toggle_pause', methods=['POST']) +def api_toggle_pause(): + """Toggle pause state.""" + cc.paused = not cc.paused + cc.log("Paused" if cc.paused else "Resumed") + return jsonify({"success": True, "paused": cc.paused}) + + +@app.route('/api/reset', methods=['POST']) +def api_reset(): + """Reset detection state.""" + cc.state = None + cc.last_masks = None + cc.last_boxes = None + cc.last_scores = None + cc.last_labels = None + cc.tracked_objects = {} + cc.memory_bank = {} + cc.object_colors = {} + cc.next_object_id = 1 + cc.pending_detections = {} + cc.last_poses = {} + cc.clear_detections() + cc.log("Detection state reset") + return jsonify({"success": True}) + + +@app.route('/api/set_threshold', methods=['POST']) +def api_set_threshold(): + """Set confidence threshold.""" + data = request.json + cc.confidence_threshold = float(data.get("threshold", 0.3)) + if cc.processor: + cc.processor.confidence_threshold = cc.confidence_threshold + cc.log(f"Threshold set to {cc.confidence_threshold:.2f}") + return jsonify({"success": True}) + + +@app.route('/api/set_skip_frames', methods=['POST']) +def api_set_skip_frames(): + """Set skip frames value.""" + data = request.json + cc.skip_frames = max(1, int(data.get("skip_frames", 3))) + cc.log(f"Skip frames set to {cc.skip_frames}") + return jsonify({"success": True}) + + +# ===== FEATURE TOGGLE ROUTES ===== + +@app.route('/api/toggle_feature', methods=['POST']) +def api_toggle_feature(): + """Toggle a feature on/off.""" + data = request.json + feature = data.get("feature") + + feature_map = { + "tracking": "enable_tracking", + "memory_tracking": "enable_memory_tracking", + "persistent_ids": "enable_persistent_ids", + "fill_holes": "enable_fill_holes", + "non_overlap": "enable_non_overlap", + "smooth_edges": "enable_smooth_edges", + "boundary_suppression": "enable_boundary_suppression", + "occlusion_suppression": "enable_occlusion_suppression", + "hotstart": "enable_hotstart", + "yolo_classify": "enable_yolo_classify", + "yolo_pose": "enable_yolo_pose", + "show_keypoint_labels": "show_keypoint_labels", + "show_skeleton": "show_skeleton", + "label_spoofing": "enable_label_spoofing", + } + + if feature in feature_map: + attr = feature_map[feature] + current = getattr(cc, attr) + setattr(cc, attr, not current) + new_val = getattr(cc, attr) + cc.log(f"{feature}: {'ON' if new_val else 'OFF'}") + return jsonify({"success": True, "feature": feature, "enabled": new_val}) + + return jsonify({"success": False, "error": "Unknown feature"}) + + +@app.route('/api/set_feature_param', methods=['POST']) +def api_set_feature_param(): + """Set a feature parameter value.""" + data = request.json + param = data.get("param") + value = data.get("value") + + param_map = { + "fill_hole_area": ("fill_hole_area", int), + "smooth_kernel_size": ("smooth_kernel_size", int), + "boundary_margin": ("boundary_margin", int), + "occlusion_threshold": ("occlusion_threshold", float), + "hotstart_frames": ("hotstart_frames", int), + "iou_threshold": ("iou_threshold", float), + "memory_max_frames": ("memory_max_frames", int), + "yolo_classify_threshold": ("yolo_classify_threshold", float), + "yolo_pose_threshold": ("yolo_pose_threshold", float), + "yolo_classify_every_n": ("yolo_classify_every_n", int), + "keypoint_radius": ("keypoint_radius", int), + 
"skeleton_thickness": ("skeleton_thickness", int), + } + + if param in param_map: + attr, type_fn = param_map[param] + setattr(cc, attr, type_fn(value)) + cc.log(f"{param} set to {value}") + return jsonify({"success": True}) + + return jsonify({"success": False, "error": "Unknown parameter"}) + + +@app.route('/api/analyze_object', methods=['POST']) +def api_analyze_object(): + """Queue an object for Claude analysis with mask-based cropping.""" + data = request.json + detection_id = data.get("detection_id") + box = data.get("box") + mask_index = data.get("mask_index") # Index into cc.last_masks + + # Use raw frame (without overlays) for analysis + if cc.current_raw_frame is None: + return jsonify({"success": False, "error": "No frame available"}) + + try: + frame = cc.current_raw_frame.copy() + h, w = frame.shape[:2] + + # Try to use mask for better cropping + mask = None + if mask_index is not None and cc.last_masks is not None: + try: + if mask_index < len(cc.last_masks): + mask = cc.last_masks[mask_index].squeeze().cpu().numpy() + if mask.shape != (h, w): + mask = cv2.resize(mask.astype(np.float32), (w, h)) > 0.5 + except Exception as e: + cc.log(f"Could not get mask: {e}", "WARN") + mask = None + + if mask is not None and mask.sum() > 0: + # Use mask to create a clean crop with transparent/white background + # Get bounding box from mask + rows = np.any(mask, axis=1) + cols = np.any(mask, axis=0) + y_min, y_max = np.where(rows)[0][[0, -1]] + x_min, x_max = np.where(cols)[0][[0, -1]] + + # Add padding + pad = 15 + x1 = max(0, x_min - pad) + y1 = max(0, y_min - pad) + x2 = min(w, x_max + pad) + y2 = min(h, y_max + pad) + + # Crop the region + crop = frame[y1:y2, x1:x2].copy() + mask_crop = mask[y1:y2, x1:x2] + + # Apply mask - set background to white for cleaner analysis + mask_3ch = np.stack([mask_crop] * 3, axis=-1) + crop = np.where(mask_3ch, crop, 255).astype(np.uint8) + + elif box: + # Fallback to box-based cropping + x1, y1, x2, y2 = [int(v) for v in box] + pad = 20 + x1 = max(0, x1 - pad) + y1 = max(0, y1 - pad) + x2 = min(w, x2 + pad) + y2 = min(h, y2 + pad) + crop = frame[y1:y2, x1:x2] + else: + crop = frame + + _, buffer = cv2.imencode('.jpg', crop, [cv2.IMWRITE_JPEG_QUALITY, 90]) + image_data = base64.b64encode(buffer).decode('utf-8') + + cc.queue_analysis(detection_id, image_data) + cc.log(f"Queued object #{detection_id} for analysis (mask-cropped: {mask is not None})") + + return jsonify({"success": True}) + + except Exception as e: + cc.log(f"Failed to queue analysis: {e}", "ERROR") + return jsonify({"success": False, "error": str(e)}) + + +@app.route('/api/describe_scene', methods=['POST']) +def api_describe_scene(): + """Send full scene to Claude for description.""" + global ANTHROPIC_API_KEY + + if not ANTHROPIC_API_KEY: + return jsonify({"success": False, "error": "ANTHROPIC_API_KEY not set"}) + + if cc.current_raw_frame is None: + return jsonify({"success": False, "error": "No frame available"}) + + try: + import anthropic + + frame = cc.current_raw_frame.copy() + _, buffer = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 85]) + image_data = base64.b64encode(buffer).decode('utf-8') + + cc.log("Analyzing full scene with Claude...") + + client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY) + + message = client.messages.create( + model="claude-sonnet-4-20250514", + max_tokens=800, + messages=[ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": image_data, + }, + }, + { + "type": 
"text", + "text": "Please describe this scene in detail. Include: the setting/environment, all visible objects and people, their positions and relationships, any activities or actions taking place, lighting conditions, and any notable details. Be comprehensive but concise (3-5 sentences)." + } + ], + } + ], + ) + + result = message.content[0].text + cc.log("Scene analysis complete", "SUCCESS") + + # Add to analysis results + cc.add_analysis_result(-1, f"[SCENE] {result}") + + return jsonify({ + "success": True, + "description": result + }) + + except Exception as e: + cc.log(f"Scene analysis failed: {e}", "ERROR") + return jsonify({"success": False, "error": str(e)}) + + +@app.route('/api/tracked_objects') +def api_tracked_objects(): + """Get list of tracked objects with persistent IDs.""" + objects = [] + for obj_id, data in cc.tracked_objects.items(): + objects.append({ + "id": obj_id, + "label": data.get("label"), + "first_seen": data.get("first_seen"), + "last_seen": data.get("last_seen"), + "confidence": data.get("confidence", 0), + "frames_tracked": data.get("last_seen", 0) - data.get("first_seen", 0), + }) + return jsonify({"objects": objects}) + + +@app.route('/api/poses') +def api_poses(): + """Get current pose data for all detected persons.""" + poses = [] + for obj_id, pose_data in cc.last_poses.items(): + poses.append({ + "object_id": obj_id, + "confidence": pose_data.get("confidence", 0), + "keypoints": [ + {"name": name, "x": kp[0], "y": kp[1], "confidence": kp[2]} + for name, kp in zip(POSE_KEYPOINTS, pose_data.get("keypoints", [])) + ] + }) + return jsonify({"poses": poses}) + + +@app.route('/api/coco_mapping') +def api_coco_mapping(): + """Get SAM3 to COCO label mapping.""" + return jsonify({ + "mapping": SAM3_TO_COCO, + "coco_classes": COCO_CLASSES + }) + + +# ===== VOICE SEARCH ROUTES ===== + +def parse_voice_query_with_claude(voice_text: str) -> Dict: + """ + Use Claude to parse a voice query into search prompts. + + Handles queries like: + - "help me find a red car" + - "can you search for a person and a dog" + - "look for my phone, keys, and wallet" + - "find the blue cup on the table" + + Returns dict with: + - prompts: List of parsed object prompts (comma-separated format) + - is_multi: Whether multiple objects were requested + - feedback: Human-readable feedback message + """ + global ANTHROPIC_API_KEY + + if not ANTHROPIC_API_KEY: + return { + "success": False, + "error": "ANTHROPIC_API_KEY not set", + "prompts": [voice_text], # Fallback: use raw text + "feedback": f"API key not set. Searching for: {voice_text}" + } + + try: + import anthropic + + client = anthropic.Anthropic(api_key=ANTHROPIC_API_KEY) + + message = client.messages.create( + model="claude-sonnet-4-20250514", + max_tokens=300, + messages=[ + { + "role": "user", + "content": f"""Parse this voice command for an object detection system. Extract the objects the user wants to find. + +Voice command: "{voice_text}" + +Rules: +1. Extract object names/descriptions that can be detected visually +2. If multiple objects are mentioned, list them all +3. Include color/size descriptors if mentioned (e.g., "red car", "large dog") +4. Ignore filler words like "help me find", "can you search for", "look for" +5. 
Return ONLY a JSON object, no other text + +Return JSON format: +{{"prompts": ["object1", "object2"], "feedback": "Searching for object1 and object2"}} + +Examples: +- "help me find a red car" -> {{"prompts": ["red car"], "feedback": "Searching for red car"}} +- "search for people and dogs" -> {{"prompts": ["person", "dog"], "feedback": "Searching for person and dog"}} +- "find my phone and keys" -> {{"prompts": ["phone", "keys"], "feedback": "Searching for phone and keys"}} +- "look for a blue cup" -> {{"prompts": ["blue cup"], "feedback": "Searching for blue cup"}}""" + } + ], + ) + + response_text = message.content[0].text.strip() + + # Parse JSON from response + # Handle potential markdown code blocks + if "```json" in response_text: + response_text = response_text.split("```json")[1].split("```")[0].strip() + elif "```" in response_text: + response_text = response_text.split("```")[1].split("```")[0].strip() + + result = json.loads(response_text) + + prompts = result.get("prompts", []) + feedback = result.get("feedback", f"Searching for {', '.join(prompts)}") + + return { + "success": True, + "prompts": prompts, + "is_multi": len(prompts) > 1, + "feedback": feedback, + "raw_query": voice_text + } + + except json.JSONDecodeError as e: + cc.log(f"Failed to parse Claude response as JSON: {e}", "ERROR") + # Fallback: just use the voice text directly + return { + "success": True, + "prompts": [voice_text], + "is_multi": False, + "feedback": f"Searching for {voice_text}", + "raw_query": voice_text + } + except Exception as e: + cc.log(f"Voice query parsing error: {e}", "ERROR") + return { + "success": False, + "error": str(e), + "prompts": [], + "feedback": "Failed to parse voice command" + } + + +def check_describe_command(voice_text: str) -> Optional[Dict]: + """ + Check if voice command is a describe command. + Returns dict with action info if it's a describe command, None otherwise. + + Handles: + - "describe scene" / "describe the scene" / "what do you see" + - "describe object 1" / "describe the first object" / "tell me about object 2" + - "analyze object 3" / "what is object 1" + """ + text_lower = voice_text.lower().strip() + + # Scene describe patterns + scene_patterns = [ + "describe scene", "describe the scene", "describe this scene", + "what do you see", "what's in the scene", "describe everything", + "describe the view", "describe what you see", "analyze scene", + "tell me about the scene", "what's happening" + ] + + for pattern in scene_patterns: + if pattern in text_lower: + return { + "action": "describe_scene", + "feedback": "Describing the scene..." + } + + # Object describe patterns - extract object number + import re + + # Patterns like "describe object 1", "analyze object 2", "tell me about object 3" + object_patterns = [ + r"describe (?:the )?(?:object|item|thing) (\d+)", + r"analyze (?:the )?(?:object|item|thing) (\d+)", + r"what is (?:object|item|thing) (\d+)", + r"tell me about (?:object|item|thing) (\d+)", + r"describe (?:the )?(\d+)(?:st|nd|rd|th)? (?:object|item|thing)", + r"describe number (\d+)", + r"object (\d+) describe", + ] + + for pattern in object_patterns: + match = re.search(pattern, text_lower) + if match: + obj_num = int(match.group(1)) + return { + "action": "describe_object", + "object_id": obj_num, + "feedback": f"Describing object {obj_num}..." 
+ } + + # Ordinal patterns like "describe the first object", "analyze the second item" + ordinals = { + "first": 0, "1st": 0, + "second": 1, "2nd": 1, + "third": 2, "3rd": 2, + "fourth": 3, "4th": 3, + "fifth": 4, "5th": 4, + } + + for ordinal, idx in ordinals.items(): + if ordinal in text_lower and ("object" in text_lower or "item" in text_lower or "thing" in text_lower): + if "describe" in text_lower or "analyze" in text_lower or "tell me" in text_lower or "what is" in text_lower: + return { + "action": "describe_object", + "object_index": idx, + "feedback": f"Describing the {ordinal} object..." + } + + return None + + +@app.route('/api/voice_search', methods=['POST']) +def api_voice_search(): + """Process a voice search query through Claude and set prompts.""" + data = request.json + voice_text = data.get("text", "").strip() + + if not voice_text: + return jsonify({"success": False, "error": "No voice text provided"}) + + cc.log(f"Voice query received: '{voice_text}'", "INFO") + cc.last_voice_query = voice_text + + # First check for describe commands + describe_cmd = check_describe_command(voice_text) + if describe_cmd: + cc.add_voice_feedback(describe_cmd["feedback"], "info") + + if describe_cmd["action"] == "describe_scene": + return jsonify({ + "success": True, + "action": "describe_scene", + "feedback": describe_cmd["feedback"], + "tts_message": describe_cmd["feedback"] + }) + + elif describe_cmd["action"] == "describe_object": + # Find the object to describe + obj_id = describe_cmd.get("object_id") + obj_index = describe_cmd.get("object_index") + + detections = cc.current_detections + + if not detections: + return jsonify({ + "success": False, + "error": "No objects detected", + "feedback": "No objects are currently detected" + }) + + # Find the detection + target_det = None + target_index = None + + if obj_id is not None: + # Look for object with this ID + for i, det in enumerate(detections): + if det.get("id") == obj_id: + target_det = det + target_index = i + break + elif obj_index is not None: + # Use index directly + if obj_index < len(detections): + target_det = detections[obj_index] + target_index = obj_index + + if target_det is None: + return jsonify({ + "success": False, + "error": f"Object not found", + "feedback": f"Could not find the specified object" + }) + + return jsonify({ + "success": True, + "action": "describe_object", + "detection": target_det, + "mask_index": target_index, + "feedback": describe_cmd["feedback"], + "tts_message": describe_cmd["feedback"] + }) + + # Parse the voice query with Claude for search + result = parse_voice_query_with_claude(voice_text) + + if result["success"] and result["prompts"]: + # Update prompts + cc.prompts = result["prompts"] + cc.last_parsed_prompts = result["prompts"] + + # Reset detection state for new search + cc.state = None + cc.last_masks = None + cc.last_boxes = None + cc.last_scores = None + cc.last_labels = None + cc.tracked_objects = {} + cc.memory_bank = {} + cc.last_poses = {} + + prompt_str = ", ".join(result["prompts"]) + cc.log(f"Voice search: {prompt_str}", "SUCCESS") + cc.add_voice_feedback(result["feedback"], "success") + + return jsonify({ + "success": True, + "action": "search", + "prompts": result["prompts"], + "prompt_string": prompt_str, + "is_multi": result["is_multi"], + "feedback": result["feedback"], + "tts_message": result["feedback"] + }) + else: + error_msg = result.get("error", "Could not understand the voice command") + cc.add_voice_feedback(f"Error: {error_msg}", "error") + return jsonify({ + 
"success": False, + "error": error_msg, + "feedback": result.get("feedback", "Failed to process voice command") + }) + + +@app.route('/api/voice_feedback') +def api_voice_feedback(): + """Get recent voice feedback messages.""" + with cc.lock: + messages = list(cc.voice_feedback_messages) + return jsonify({ + "messages": messages, + "last_query": cc.last_voice_query, + "last_prompts": cc.last_parsed_prompts + }) + + +@app.route('/api/toggle_voice', methods=['POST']) +def api_toggle_voice(): + """Toggle voice features.""" + data = request.json + feature = data.get("feature", "voice") + + if feature == "voice": + cc.voice_enabled = not cc.voice_enabled + cc.log(f"Voice input: {'ON' if cc.voice_enabled else 'OFF'}") + return jsonify({"success": True, "enabled": cc.voice_enabled}) + elif feature == "tts": + cc.tts_enabled = not cc.tts_enabled + cc.log(f"TTS output: {'ON' if cc.tts_enabled else 'OFF'}") + return jsonify({"success": True, "enabled": cc.tts_enabled}) + + return jsonify({"success": False, "error": "Unknown feature"}) + + +# ===== CAMERA ROUTES ===== + +@app.route('/api/cameras') +def api_cameras(): + """Get list of available cameras.""" + cameras = detect_available_cameras() + cc.available_cameras = cameras + return jsonify({ + "cameras": cameras, + "current_camera": cc.current_camera_id, + "flip_horizontal": cc.flip_horizontal, + "flip_vertical": cc.flip_vertical + }) + + +@app.route('/api/switch_camera', methods=['POST']) +def api_switch_camera(): + """Switch to a different camera.""" + data = request.json + camera_id = data.get("camera_id") + + if camera_id is None: + return jsonify({"success": False, "error": "No camera_id provided"}) + + camera_id = int(camera_id) + + success = switch_camera(camera_id) + + return jsonify({ + "success": success, + "current_camera": cc.current_camera_id, + "message": f"Switched to camera {camera_id}" if success else f"Failed to switch to camera {camera_id}" + }) + + +@app.route('/api/flip_camera', methods=['POST']) +def api_flip_camera(): + """Toggle camera flip (horizontal/vertical).""" + data = request.json + direction = data.get("direction", "horizontal") + + if direction == "horizontal": + cc.flip_horizontal = not cc.flip_horizontal + cc.log(f"Horizontal flip: {'ON' if cc.flip_horizontal else 'OFF'}") + # Reset detection state when flip changes + reset_detection_state() + return jsonify({ + "success": True, + "flip_horizontal": cc.flip_horizontal, + "flip_vertical": cc.flip_vertical + }) + elif direction == "vertical": + cc.flip_vertical = not cc.flip_vertical + cc.log(f"Vertical flip: {'ON' if cc.flip_vertical else 'OFF'}") + # Reset detection state when flip changes + reset_detection_state() + return jsonify({ + "success": True, + "flip_horizontal": cc.flip_horizontal, + "flip_vertical": cc.flip_vertical + }) + elif direction == "both": + cc.flip_horizontal = not cc.flip_horizontal + cc.flip_vertical = not cc.flip_vertical + cc.log(f"Flip both: H={'ON' if cc.flip_horizontal else 'OFF'}, V={'ON' if cc.flip_vertical else 'OFF'}") + reset_detection_state() + return jsonify({ + "success": True, + "flip_horizontal": cc.flip_horizontal, + "flip_vertical": cc.flip_vertical + }) + + return jsonify({"success": False, "error": "Invalid direction"}) + + +@app.route('/api/set_flip', methods=['POST']) +def api_set_flip(): + """Set flip state explicitly.""" + data = request.json + flip_h = data.get("flip_horizontal") + flip_v = data.get("flip_vertical") + + changed = False + + if flip_h is not None and flip_h != cc.flip_horizontal: + cc.flip_horizontal 
= bool(flip_h) + changed = True + + if flip_v is not None and flip_v != cc.flip_vertical: + cc.flip_vertical = bool(flip_v) + changed = True + + if changed: + cc.log(f"Flip set: H={'ON' if cc.flip_horizontal else 'OFF'}, V={'ON' if cc.flip_vertical else 'OFF'}") + reset_detection_state() + + return jsonify({ + "success": True, + "flip_horizontal": cc.flip_horizontal, + "flip_vertical": cc.flip_vertical + }) + + +# ===== REFERENCE IMAGE SEARCH API ===== + +@app.route('/api/upload_reference', methods=['POST']) +def api_upload_reference(): + """ + Upload a reference image for search. + Modes: + - 'description': Use Claude to describe, then search by text + - 'visual': Use CLIP for visual similarity matching + """ + global cc + + if 'image' not in request.files: + return jsonify({"success": False, "error": "No image provided"}) + + mode = request.form.get('mode', 'description') # 'description' or 'visual' + + try: + file = request.files['image'] + image_data = file.read() + + # Convert to PIL Image + pil_image = Image.open(io.BytesIO(image_data)).convert('RGB') + cc.reference_image = pil_image + + # Get base64 for Claude + buffered = io.BytesIO() + pil_image.save(buffered, format="JPEG", quality=90) + base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8') + + if mode == 'description': + # Use Claude to describe the image + cc.log("Analyzing reference image with Claude...") + description = describe_image_with_claude(base64_image) + + if description: + cc.reference_description = description + cc.visual_match_enabled = False + + # Set as prompt + cc.prompts = [description] + cc.state = None + cc.last_masks = None + reset_detection_state() + + cc.log(f"Reference search: '{description}'", "SUCCESS") + + return jsonify({ + "success": True, + "mode": "description", + "description": description, + "prompt": description + }) + else: + return jsonify({"success": False, "error": "Failed to describe image"}) + + elif mode == 'visual': + # Use CLIP for visual matching + if not cc.clip_available: + return jsonify({ + "success": False, + "error": "CLIP not available. 
Install with: pip install transformers" + }) + + cc.log("Computing CLIP embedding for reference image...") + embedding = get_clip_embedding(pil_image) + + if embedding is not None: + cc.reference_embedding = embedding + cc.visual_match_enabled = True + + # Also get a description for display + description = describe_image_with_claude(base64_image) + cc.reference_description = description or "Visual reference" + + # Set a generic prompt to detect objects + cc.prompts = ["object"] + cc.state = None + cc.last_masks = None + reset_detection_state() + + cc.log(f"Visual matching enabled for: {cc.reference_description}", "SUCCESS") + + return jsonify({ + "success": True, + "mode": "visual", + "description": cc.reference_description, + "message": "Visual matching enabled" + }) + else: + return jsonify({"success": False, "error": "Failed to compute CLIP embedding"}) + + else: + return jsonify({"success": False, "error": f"Unknown mode: {mode}"}) + + except Exception as e: + cc.log(f"Reference upload failed: {e}", "ERROR") + return jsonify({"success": False, "error": str(e)}) + + +@app.route('/api/clear_reference', methods=['POST']) +def api_clear_reference(): + """Clear the reference image.""" + global cc + + cc.reference_image = None + cc.reference_embedding = None + cc.reference_description = None + cc.visual_match_enabled = False + + cc.log("Reference image cleared") + + return jsonify({"success": True}) + + +@app.route('/api/reference_status') +def api_reference_status(): + """Get reference image status.""" + return jsonify({ + "has_reference": cc.reference_image is not None, + "description": cc.reference_description, + "visual_match_enabled": cc.visual_match_enabled, + "clip_available": cc.clip_available, + "threshold": cc.visual_match_threshold + }) + + +# ===== GEOMETRIC PROMPTS (DRAW TO SEARCH) API ===== + +@app.route('/api/draw_prompt', methods=['POST']) +def api_draw_prompt(): + """ + Set a geometric prompt (box or point) from user drawing. + This will be processed on the next frame. 
+ """ + global cc + + data = request.json + prompt_type = data.get('type', 'box') # 'box' or 'point' + + if prompt_type == 'box': + x1 = data.get('x1') + y1 = data.get('y1') + x2 = data.get('x2') + y2 = data.get('y2') + + if all(v is not None for v in [x1, y1, x2, y2]): + cc.pending_box_prompt = (float(x1), float(y1), float(x2), float(y2)) + cc.pending_point_prompt = None + cc.draw_mode = 'box' + cc.log(f"Box prompt set: ({x1:.0f}, {y1:.0f}) to ({x2:.0f}, {y2:.0f})") + + return jsonify({ + "success": True, + "type": "box", + "box": [x1, y1, x2, y2] + }) + else: + return jsonify({"success": False, "error": "Invalid box coordinates"}) + + elif prompt_type == 'point': + x = data.get('x') + y = data.get('y') + + if x is not None and y is not None: + cc.pending_point_prompt = (float(x), float(y)) + cc.pending_box_prompt = None + cc.draw_mode = 'point' + cc.log(f"Point prompt set: ({x:.0f}, {y:.0f})") + + return jsonify({ + "success": True, + "type": "point", + "point": [x, y] + }) + else: + return jsonify({"success": False, "error": "Invalid point coordinates"}) + + else: + return jsonify({"success": False, "error": f"Unknown prompt type: {prompt_type}"}) + + +@app.route('/api/clear_draw_prompt', methods=['POST']) +def api_clear_draw_prompt(): + """Clear any pending geometric prompts.""" + global cc + + cc.pending_box_prompt = None + cc.pending_point_prompt = None + cc.draw_mode = None + + cc.log("Draw prompt cleared") + + return jsonify({"success": True}) + + +# ===== NAVIGATION SYSTEM API ===== + +@app.route('/api/navigation/start', methods=['POST']) +def api_navigation_start(): + """Start navigation to a detected object.""" + global cc + + data = request.json + target_label = data.get("target_label") or data.get("label") + target_id = data.get("target_id") or data.get("detection_id") + + if not target_label and target_id is None: + return jsonify({"success": False, "error": "No target specified"}) + + # Check for location memory first (from SQLite) + memory = cc.recall_location(target_label) if target_label else None + memory_hint = None + if memory: + memory_hint = f"I remember finding {target_label} in the {memory.get('context', 'unknown location')} before." 
+ + cc.navigation_active = True + cc.navigation_target = target_label + cc.navigation_target_id = target_id + cc.navigation_start_time = time.time() + cc.navigation_last_seen = None + cc.navigation_reached = False + cc.navigation_target_history = [] + + # Start obstacle detection + cc.obstacle_detection_active = True + cc.current_obstacles = [] + cc.obstacle_masks = None + cc.obstacle_boxes = None + + # Create navigation session in database + if cc.session_id: + cc.navigation_db_id = db.start_navigation_session(cc.session_id, target_label, target_id) + db.log_event(cc.session_id, "navigation_start", f"Started navigation to {target_label}", + data={"target_label": target_label, "target_id": target_id}) + + # Analyze scene context + if cc.current_raw_frame is not None: + try: + _, buffer = cv2.imencode('.jpg', cc.current_raw_frame, [cv2.IMWRITE_JPEG_QUALITY, 70]) + image_data = base64.b64encode(buffer).decode('utf-8') + cc.navigation_context = analyze_scene_context(image_data) + except Exception as e: + cc.log(f"Scene context analysis failed: {e}", "WARN") + cc.navigation_context = None + + cc.log(f"Navigation started: looking for '{target_label}'", "SUCCESS") + + # Initial message + location = cc.navigation_context.get("location", "this area") if cc.navigation_context else "this area" + initial_message = f"Starting navigation to find {target_label}. You appear to be in {location}." + if memory_hint: + initial_message += f" {memory_hint}" + + return jsonify({ + "success": True, + "target": target_label, + "initial_message": initial_message, + "memory_hint": memory_hint, + "context": cc.navigation_context + }) + + +@app.route('/api/navigation/stop', methods=['POST']) +def api_navigation_stop(): + """Stop navigation.""" + global cc + + was_active = cc.navigation_active + target = cc.navigation_target + reached = cc.navigation_reached + + # If we reached the target, remember its location (in SQLite) + if reached and cc.navigation_context and target: + location = cc.navigation_context.get("location", "unknown location") + cc.remember_location(target, location) + + # End navigation session in database + if cc.navigation_db_id: + db.end_navigation_session( + cc.navigation_db_id, + reached=reached, + path_history=cc.navigation_target_history, + scene_context=cc.navigation_context + ) + if cc.session_id: + db.log_event(cc.session_id, "navigation_stop", + f"Navigation to {target} {'reached' if reached else 'cancelled'}", + data={"target": target, "reached": reached}) + + # Stop obstacle detection + cc.obstacle_detection_active = False + cc.current_obstacles = [] + cc.obstacle_masks = None + cc.obstacle_boxes = None + + cc.navigation_active = False + cc.navigation_target = None + cc.navigation_target_id = None + cc.navigation_db_id = None + cc.navigation_start_time = None + cc.navigation_last_seen = None + cc.navigation_reached = False + cc.navigation_context = None + cc.navigation_target_history = [] + + if was_active: + cc.log(f"Navigation ended for '{target}'") + + return jsonify({ + "success": True, + "reached": reached, + "show_post_nav_dialog": was_active # Tell UI to show continue/pause dialog + }) + + +@app.route('/api/navigation/status') +def api_navigation_status(): + """Get current navigation status and guidance.""" + status = get_navigation_status() + + # Add TTS guidance if needed + if status.get("active") and status.get("guidance"): + current_time = time.time() + guidance_text = status["guidance"].get("guidance_text", "") + + # Only speak if enough time has passed and guidance changed + 
if (current_time - cc.navigation_last_guidance_time > cc.navigation_guidance_interval and + guidance_text != cc.navigation_last_guidance): + status["speak_guidance"] = True + cc.navigation_last_guidance = guidance_text + cc.navigation_last_guidance_time = current_time + else: + status["speak_guidance"] = False + + # Add obstacle alerts with position for AR path routing + if cc.current_obstacles: + obstacles_for_alert = [] + for obs in cc.current_obstacles: + if obs.get("should_alert"): + obstacles_for_alert.append({ + "label": obs["label"], + "type": obs["type"], + "distance": obs["distance"], + "position": obs.get("position", "center"), # For AR path routing + "reason": obs.get("reason", ""), + "alert_text": f"Watch out! {obs['label']} {obs['distance'].replace('_', ' ')}" + }) + status["obstacles"] = obstacles_for_alert + + return jsonify(status) + + +@app.route('/api/navigation/analyze_scene', methods=['POST']) +def api_navigation_analyze_scene(): + """Analyze current scene for navigation context.""" + global cc + + if cc.current_raw_frame is None: + return jsonify({"success": False, "error": "No frame available"}) + + try: + _, buffer = cv2.imencode('.jpg', cc.current_raw_frame, [cv2.IMWRITE_JPEG_QUALITY, 70]) + image_data = base64.b64encode(buffer).decode('utf-8') + context = analyze_scene_context(image_data) + + if context: + cc.navigation_context = context + return jsonify({"success": True, "context": context}) + else: + return jsonify({"success": False, "error": "Analysis failed"}) + + except Exception as e: + return jsonify({"success": False, "error": str(e)}) + + +@app.route('/api/location_memory') +def api_location_memory(): + """Get stored location memory (from SQLite).""" + memories = cc.get_all_location_memories() + return jsonify({ + "success": True, + "memory": memories + }) + + +@app.route('/api/location_memory/recall', methods=['POST']) +def api_recall_location(): + """Recall where an object was last found (from SQLite).""" + data = request.json + label = data.get("label", "") + + memory = cc.recall_location(label) + + if memory: + return jsonify({ + "success": True, + "found": True, + "label": label, + "location": memory.get("context"), + "frequency": memory.get("frequency", 1), + "last_seen": memory.get("last_seen") + }) + else: + return jsonify({ + "success": True, + "found": False, + "label": label, + "message": f"No memory of where {label} was found" + }) + + +@app.route('/api/location_memory/clear', methods=['POST']) +def api_clear_location_memory(): + """Clear location memory.""" + data = request.json or {} + label = data.get("label") + + cc.clear_location_memory(label) + + return jsonify({ + "success": True, + "message": f"Cleared location memory" + (f" for {label}" if label else "") + }) + + +# ===== OBSTACLE DETECTION API ===== + +@app.route('/api/navigation/obstacles') +def api_navigation_obstacles(): + """Get current obstacles detected during navigation.""" + return jsonify({ + "success": True, + "obstacles": cc.current_obstacles, + "active": cc.obstacle_detection_active + }) + + +# ===== DATABASE HISTORY API ===== + +@app.route('/api/history/detections') +def api_history_detections(): + """Get detection history from database.""" + label = request.args.get('label') + limit = int(request.args.get('limit', 100)) + + history = db.get_detection_history(session_id=cc.session_id, label=label, limit=limit) + + return jsonify({ + "success": True, + "detections": history, + "count": len(history) + }) + + +@app.route('/api/history/analysis') +def api_history_analysis(): 
+ """Get analysis history from database.""" + limit = int(request.args.get('limit', 50)) + + history = db.get_analysis_history(session_id=cc.session_id, limit=limit) + + return jsonify({ + "success": True, + "analyses": history, + "count": len(history) + }) + + +@app.route('/api/history/navigation') +def api_history_navigation(): + """Get navigation history from database.""" + limit = int(request.args.get('limit', 20)) + + history = db.get_navigation_history(session_id=cc.session_id, limit=limit) + + return jsonify({ + "success": True, + "navigations": history, + "count": len(history) + }) + + +@app.route('/api/session/stats') +def api_session_stats(): + """Get statistics for the current session.""" + if not cc.session_id: + return jsonify({"success": False, "error": "No active session"}) + + stats = db.get_session_stats(cc.session_id) + + return jsonify({ + "success": True, + "session_id": cc.session_id, + "stats": stats + }) + + +def generate_self_signed_cert(cert_dir: str = None) -> Tuple[str, str]: + """Generate a self-signed SSL certificate for HTTPS.""" + try: + from cryptography import x509 + from cryptography.x509.oid import NameOID + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.hazmat.primitives import serialization + import datetime + + if cert_dir is None: + cert_dir = os.path.join(os.path.dirname(__file__), '.ssl') + + os.makedirs(cert_dir, exist_ok=True) + + key_path = os.path.join(cert_dir, 'key.pem') + cert_path = os.path.join(cert_dir, 'cert.pem') + + # Check if certs already exist + if os.path.exists(key_path) and os.path.exists(cert_path): + print(f"Using existing SSL certificates from {cert_dir}") + return cert_path, key_path + + print("Generating self-signed SSL certificate...") + + # Generate private key + key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + backend=default_backend() + ) + + # Generate certificate + subject = issuer = x509.Name([ + x509.NameAttribute(NameOID.COUNTRY_NAME, "US"), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"), + x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "SAM3 Command Center"), + x509.NameAttribute(NameOID.COMMON_NAME, "localhost"), + ]) + + cert = x509.CertificateBuilder().subject_name( + subject + ).issuer_name( + issuer + ).public_key( + key.public_key() + ).serial_number( + x509.random_serial_number() + ).not_valid_before( + datetime.datetime.utcnow() + ).not_valid_after( + datetime.datetime.utcnow() + datetime.timedelta(days=365) + ).add_extension( + x509.SubjectAlternativeName([ + x509.DNSName("localhost"), + x509.DNSName("127.0.0.1"), + x509.IPAddress(ipaddress.IPv4Address("127.0.0.1")), + ]), + critical=False, + ).sign(key, hashes.SHA256(), default_backend()) + + # Write key + with open(key_path, "wb") as f: + f.write(key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption() + )) + + # Write certificate + with open(cert_path, "wb") as f: + f.write(cert.public_bytes(serialization.Encoding.PEM)) + + print(f"SSL certificate generated: {cert_path}") + return cert_path, key_path + + except ImportError: + print("WARNING: cryptography package not installed. 
Cannot generate SSL certificate.") + print(" Install with: pip install cryptography") + print(" Or provide --ssl-cert and --ssl-key arguments") + return None, None + + +def main(): + global cc + + parser = argparse.ArgumentParser(description="SAM3 Web Command Center") + parser.add_argument("--camera", "-c", type=int, default=0, help="Camera device ID") + parser.add_argument("--device", "-d", type=str, default=None, help="Device (cuda, mps, cpu)") + parser.add_argument("--prompt", type=str, default="object", help="Initial prompts (comma-separated)") + parser.add_argument("--threshold", type=float, default=0.3, help="Confidence threshold") + parser.add_argument("--checkpoint", type=str, default=None, help="Model checkpoint path") + parser.add_argument("--port", type=int, default=5000, help="Web server port") + parser.add_argument("--skip-frames", type=int, default=3, help="Process every N frames") + parser.add_argument("--no-tracking", action="store_true", help="Disable optical flow tracking") + parser.add_argument("--no-yolo", action="store_true", help="Disable YOLO models") + parser.add_argument("--api-key", type=str, default=None, help="Anthropic API key (or set ANTHROPIC_API_KEY env var)") + parser.add_argument("--no-https", action="store_true", help="Disable HTTPS (not recommended - microphone won't work)") + parser.add_argument("--ssl-cert", type=str, default=None, help="Path to SSL certificate file") + parser.add_argument("--ssl-key", type=str, default=None, help="Path to SSL private key file") + + args = parser.parse_args() + + # Set API key from argument if provided + global ANTHROPIC_API_KEY + if args.api_key: + ANTHROPIC_API_KEY = args.api_key + print("Using API key from command line argument") + elif ANTHROPIC_API_KEY: + print("Using API key from environment variable") + else: + print("WARNING: No Anthropic API key set. 
Claude features (analysis, voice search) will not work.") + print(" Set via: --api-key YOUR_KEY or ANTHROPIC_API_KEY=YOUR_KEY") + + # Configure command center + cc.prompts = [p.strip() for p in args.prompt.split(",") if p.strip()] + cc.confidence_threshold = args.threshold + cc.skip_frames = args.skip_frames + cc.enable_tracking = not args.no_tracking + + if args.device: + cc.device_str = args.device + + # Create database session + cc.session_id = db.create_session( + device=args.device or "auto", + prompts=cc.prompts, + settings={ + "threshold": args.threshold, + "skip_frames": args.skip_frames, + "tracking": not args.no_tracking, + "yolo": not args.no_yolo + } + ) + cc.log(f"Database session started: {cc.session_id[:8]}...") + + # Load model + load_model(args.checkpoint) + + # Load depth estimation model for LIDAR-like obstacle detection + cc.log("Loading depth estimation model for advanced obstacle detection...") + load_depth_model() + if depth_model is not None: + cc.log("Depth estimation model loaded successfully", "SUCCESS") + else: + cc.log("Depth estimation unavailable - using other detection layers", "WARNING") + + # Skip YOLO if requested + if args.no_yolo: + cc.yolo_available = False + cc.log("YOLO disabled via command line") + + # Detect available cameras + cc.log("Detecting available cameras...") + cc.available_cameras = detect_available_cameras() + cc.log(f"Found {len(cc.available_cameras)} camera(s)", "SUCCESS") + for cam in cc.available_cameras: + cc.log(f" Camera {cam['id']}: {cam['description']}") + + # Open camera + cc.log(f"Opening camera {args.camera}...") + cc.camera = cv2.VideoCapture(args.camera) + cc.current_camera_id = args.camera + + if not cc.camera.isOpened(): + cc.log(f"Failed to open camera {args.camera}", "ERROR") + return + + width = int(cc.camera.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cc.camera.get(cv2.CAP_PROP_FRAME_HEIGHT)) + cc.log(f"Camera opened: {width}x{height}", "SUCCESS") + + cc.running = True + + # Start analysis worker + analysis_thread = threading.Thread(target=analysis_worker, daemon=True) + analysis_thread.start() + + print(f"\n{'='*50}") + print(f"SAM3 Web Command Center") + print(f"{'='*50}") + + # Setup SSL (HTTPS is default, use --no-https to disable) + ssl_context = None + protocol = "http" + + if not args.no_https: + if args.ssl_cert and args.ssl_key: + # Use provided certificates + if os.path.exists(args.ssl_cert) and os.path.exists(args.ssl_key): + ssl_context = (args.ssl_cert, args.ssl_key) + protocol = "https" + print(f"Using provided SSL certificates") + else: + print(f"ERROR: SSL certificate files not found") + print(f" Cert: {args.ssl_cert}") + print(f" Key: {args.ssl_key}") + return + else: + # Generate self-signed certificate + cert_path, key_path = generate_self_signed_cert() + if cert_path and key_path: + ssl_context = (cert_path, key_path) + protocol = "https" + print(f"Using auto-generated self-signed certificate") + print(f" NOTE: You may need to accept the security warning in your browser") + else: + print("WARNING: Could not setup HTTPS. Falling back to HTTP.") + print(" Microphone and navigation features may not work without HTTPS!") + else: + print("WARNING: HTTPS disabled. 
Microphone and navigation features may not work!") + + print(f"Open {protocol}://localhost:{args.port} in your browser") + print(f"YOLO: {'Available' if cc.yolo_available else 'Not available'}") + print(f"CLIP: {'Available' if cc.clip_available else 'Not available'}") + if protocol == "https": + print(f"HTTPS: Enabled (microphone and navigation available)") + else: + print(f"HTTPS: Disabled (use default or remove --no-https for full features)") + print(f"{'='*50}\n") + + try: + if ssl_context: + app.run(host='0.0.0.0', port=args.port, threaded=True, debug=False, ssl_context=ssl_context) + else: + app.run(host='0.0.0.0', port=args.port, threaded=True, debug=False) + finally: + cc.running = False + if cc.camera: + cc.camera.release() + + +if __name__ == "__main__": + main() diff --git a/examples/web_command_center/templates/index.html b/examples/web_command_center/templates/index.html new file mode 100644 index 00000000..cf931063 --- /dev/null +++ b/examples/web_command_center/templates/index.html @@ -0,0 +1,4006 @@ + + + + + + SAM3 Command Center + + + + + + +
[index.html — 4,006-line web UI template; the markup is not reproduced here, only its visible structure:
a status bar (FPS, Device, Objects, Tracked, Status), a Live Feed panel showing the camera stream with a
drawing overlay, and a Controls sidebar with Camera, Voice Search, Reference Search, and Features tabs.
The panels cover: Camera Selection and Flip / Mirror (with tips on mirroring front-facing cameras, the fact
that switching cameras resets all tracked objects, refreshing the camera list, and external-camera startup
delay); Voice-to-AI Search (e.g. "Help me find a red car") with optional Text-to-Speech feedback and a Voice
History log; Reference Image Search, Draw to Search (box or point), and Visual Match Settings (similarity
threshold slider from Loose 0.5 to Strict 0.95, default 0.75); Tracking toggles (Optical Flow Tracking,
Memory Tracking, Persistent Object IDs); Mask Refinement (Fill Holes, Smooth Edges, Non-Overlapping Masks);
Detection Controls (Boundary Suppression, Occlusion Suppression, Hotstart Mode); and YOLO Integration v12
(Classification, Pose Estimation, Show Skeleton, Show Keypoint Labels, Label Spoofing).]
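The page's JavaScript drives the Flask routes defined above over HTTPS. As a rough sketch of the same API from a Python client: the endpoint paths and JSON fields are taken from the routes above, the port (5000) and self-signed certificate (hence `verify=False`) are the server defaults shown in `main()`, and the target label and box coordinates are made-up values.

```python
# Minimal sketch of driving the command center's HTTP API from a script.
# Uses the `requests` library for illustration only; values are hypothetical.
import time
import requests

BASE = "https://localhost:5000"

# Draw-to-search: queue a box prompt (frame coordinates) for the next processed frame.
resp = requests.post(f"{BASE}/api/draw_prompt",
                     json={"type": "box", "x1": 100, "y1": 80, "x2": 360, "y2": 300},
                     verify=False)
print(resp.json())  # {"success": True, "type": "box", "box": [100, 80, 360, 300]}

# Start navigating toward a detected object, then poll for spoken guidance and obstacle alerts.
requests.post(f"{BASE}/api/navigation/start", json={"target_label": "keys"}, verify=False)
for _ in range(10):
    status = requests.get(f"{BASE}/api/navigation/status", verify=False).json()
    if status.get("speak_guidance"):
        print(status["guidance"]["guidance_text"])
    for obstacle in status.get("obstacles", []):
        print("ALERT:", obstacle["alert_text"])
    time.sleep(1)

requests.post(f"{BASE}/api/navigation/stop", verify=False)
```

Note that `/api/navigation/stop` records the target's location in SQLite when it was reached, so a later `/api/location_memory/recall` for the same label can return a hint such as the room it was last seen in.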
+ + + + diff --git a/sam3/model/decoder.py b/sam3/model/decoder.py index c8b1657e..3b0ded8b 100644 --- a/sam3/model/decoder.py +++ b/sam3/model/decoder.py @@ -11,6 +11,7 @@ import torch from sam3.sam.transformer import RoPEAttention +from sam3.utils.device import get_device from torch import nn, Tensor from torchvision.ops.roi_align import RoIAlign @@ -278,7 +279,7 @@ def __init__( if resolution is not None and stride is not None: feat_size = resolution // stride coords_h, coords_w = self._get_coords( - feat_size, feat_size, device="cuda" + feat_size, feat_size, device=get_device() ) self.compilable_cord_cache = (coords_h, coords_w) self.compilable_stored_size = (feat_size, feat_size) diff --git a/sam3/model/edt.py b/sam3/model/edt.py index 9448c1d3..65b0d4cf 100644 --- a/sam3/model/edt.py +++ b/sam3/model/edt.py @@ -1,10 +1,17 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved -"""Triton kernel for euclidean distance transform (EDT)""" +"""Euclidean distance transform (EDT) with optional Triton kernel acceleration for CUDA devices.""" import torch -import triton -import triton.language as tl + +# Try to import Triton (only available on CUDA) +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False """ Disclaimer: This implementation is not meant to be extremely efficient. A CUDA kernel would likely be more efficient. @@ -50,74 +57,193 @@ """ -@triton.jit -def edt_kernel(inputs_ptr, outputs_ptr, v, z, height, width, horizontal: tl.constexpr): - # This is a somewhat verbatim implementation of the efficient 1D EDT algorithm described above - # It can be applied horizontally or vertically depending if we're doing the first or second stage. - # It's parallelized across batch+row (or batch+col if horizontal=False) - # TODO: perhaps the implementation can be revisited if/when local gather/scatter become available in triton - batch_id = tl.program_id(axis=0) - if horizontal: - row_id = tl.program_id(axis=1) - block_start = (batch_id * height * width) + row_id * width - length = width - stride = 1 - else: - col_id = tl.program_id(axis=1) - block_start = (batch_id * height * width) + col_id - length = height - stride = width - - # This will be the index of the right most parabola in the envelope ("the top of the stack") - k = 0 - for q in range(1, length): - # Read the function value at the current location. Note that we're doing a singular read, not very efficient - cur_input = tl.load(inputs_ptr + block_start + (q * stride)) - # location of the parabola on top of the stack - r = tl.load(v + block_start + (k * stride)) - # associated boundary - z_k = tl.load(z + block_start + (k * stride)) - # value of the function at the parabola location - previous_input = tl.load(inputs_ptr + block_start + (r * stride)) - # intersection between the two parabolas - s = (cur_input - previous_input + q * q - r * r) / (q - r) / 2 - - # we'll pop as many parabolas as required - while s <= z_k and k - 1 >= 0: - k = k - 1 +# ============================================================================ +# PyTorch-based implementations (for CPU, MPS, and fallback) +# ============================================================================ + + +def edt_pytorch(data: torch.Tensor) -> torch.Tensor: + """ + Computes the Euclidean Distance Transform (EDT) of a batch of binary images using scipy. + + This is a fallback implementation for non-CUDA devices. 
It processes each image + in the batch individually using scipy's distance_transform_edt. + + Args: + data: A tensor of shape (B, H, W) representing a batch of binary images. + + Returns: + A tensor of the same shape as data containing the EDT. + It should be equivalent to a batched version of cv2.distanceTransform(input, cv2.DIST_L2, 0) + """ + from scipy.ndimage import distance_transform_edt + + assert data.dim() == 3, "Input tensor must have shape (B, H, W)" + + device = data.device + dtype = data.dtype + B, H, W = data.shape + + # Convert to numpy for scipy processing + data_np = data.cpu().numpy() + + # Allocate output + output_np = data_np.copy().astype("float32") + + # Process each image in the batch + for b in range(B): + # scipy's distance_transform_edt computes distance to nearest zero pixel + # We need to invert the mask because scipy computes distance to zero + # If data[i,j] == 0, EDT should be 0; otherwise distance to nearest 0 + mask = data_np[b] != 0 + output_np[b] = distance_transform_edt(mask) + + # Convert back to tensor and move to original device + output = torch.from_numpy(output_np).to(device=device, dtype=dtype) + return output + + +# ============================================================================ +# Triton-based implementations (CUDA only) +# ============================================================================ + +if HAS_TRITON: + + @triton.jit + def edt_kernel( + inputs_ptr, outputs_ptr, v, z, height, width, horizontal: tl.constexpr + ): + # This is a somewhat verbatim implementation of the efficient 1D EDT algorithm described above + # It can be applied horizontally or vertically depending if we're doing the first or second stage. + # It's parallelized across batch+row (or batch+col if horizontal=False) + # TODO: perhaps the implementation can be revisited if/when local gather/scatter become available in triton + batch_id = tl.program_id(axis=0) + if horizontal: + row_id = tl.program_id(axis=1) + block_start = (batch_id * height * width) + row_id * width + length = width + stride = 1 + else: + col_id = tl.program_id(axis=1) + block_start = (batch_id * height * width) + col_id + length = height + stride = width + + # This will be the index of the right most parabola in the envelope ("the top of the stack") + k = 0 + for q in range(1, length): + # Read the function value at the current location. 
Note that we're doing a singular read, not very efficient + cur_input = tl.load(inputs_ptr + block_start + (q * stride)) + # location of the parabola on top of the stack r = tl.load(v + block_start + (k * stride)) + # associated boundary z_k = tl.load(z + block_start + (k * stride)) + # value of the function at the parabola location previous_input = tl.load(inputs_ptr + block_start + (r * stride)) + # intersection between the two parabolas s = (cur_input - previous_input + q * q - r * r) / (q - r) / 2 - # Store the new one - k = k + 1 - tl.store(v + block_start + (k * stride), q) - tl.store(z + block_start + (k * stride), s) - if k + 1 < length: - tl.store(z + block_start + ((k + 1) * stride), 1e9) - - # Last step, we read the envelope to find the min in every location - k = 0 - for q in range(length): - while ( - k + 1 < length - and tl.load( - z + block_start + ((k + 1) * stride), mask=(k + 1) < length, other=q - ) - < q - ): - k += 1 - r = tl.load(v + block_start + (k * stride)) - d = q - r - old_value = tl.load(inputs_ptr + block_start + (r * stride)) - tl.store(outputs_ptr + block_start + (q * stride), old_value + d * d) - - -def edt_triton(data: torch.Tensor): + # we'll pop as many parabolas as required + while s <= z_k and k - 1 >= 0: + k = k - 1 + r = tl.load(v + block_start + (k * stride)) + z_k = tl.load(z + block_start + (k * stride)) + previous_input = tl.load(inputs_ptr + block_start + (r * stride)) + s = (cur_input - previous_input + q * q - r * r) / (q - r) / 2 + + # Store the new one + k = k + 1 + tl.store(v + block_start + (k * stride), q) + tl.store(z + block_start + (k * stride), s) + if k + 1 < length: + tl.store(z + block_start + ((k + 1) * stride), 1e9) + + # Last step, we read the envelope to find the min in every location + k = 0 + for q in range(length): + while ( + k + 1 < length + and tl.load( + z + block_start + ((k + 1) * stride), mask=(k + 1) < length, other=q + ) + < q + ): + k += 1 + r = tl.load(v + block_start + (k * stride)) + d = q - r + old_value = tl.load(inputs_ptr + block_start + (r * stride)) + tl.store(outputs_ptr + block_start + (q * stride), old_value + d * d) + + def edt_triton_impl(data: torch.Tensor) -> torch.Tensor: + """ + Computes the Euclidean Distance Transform (EDT) of a batch of binary images using Triton. + + Args: + data: A tensor of shape (B, H, W) representing a batch of binary images. + + Returns: + A tensor of the same shape as data containing the EDT. + It should be equivalent to a batched version of cv2.distanceTransform(input, cv2.DIST_L2, 0) + """ + assert data.dim() == 3 + assert data.is_cuda + B, H, W = data.shape + data = data.contiguous() + + # Allocate the "function" tensor. 
Implicitly the function is 0 if data[i,j]==0 else +infinity + output = torch.where(data, 1e18, 0.0) + assert output.is_contiguous() + + # Scratch tensors for the parabola stacks + parabola_loc = torch.zeros(B, H, W, dtype=torch.uint32, device=data.device) + parabola_inter = torch.empty(B, H, W, dtype=torch.float, device=data.device) + parabola_inter[:, :, 0] = -1e18 + parabola_inter[:, :, 1] = 1e18 + + # Grid size (number of blocks) + grid = (B, H) + + # Launch initialization kernel + edt_kernel[grid]( + output.clone(), + output, + parabola_loc, + parabola_inter, + H, + W, + horizontal=True, + ) + + # reset the parabola stacks + parabola_loc.zero_() + parabola_inter[:, :, 0] = -1e18 + parabola_inter[:, :, 1] = 1e18 + + grid = (B, W) + edt_kernel[grid]( + output.clone(), + output, + parabola_loc, + parabola_inter, + H, + W, + horizontal=False, + ) + # don't forget to take sqrt at the end + return output.sqrt() + + +# ============================================================================ +# Public API - automatically selects best implementation +# ============================================================================ + + +def edt(data: torch.Tensor) -> torch.Tensor: """ Computes the Euclidean Distance Transform (EDT) of a batch of binary images. + Uses Triton kernel on CUDA when available, falls back to scipy otherwise. + Args: data: A tensor of shape (B, H, W) representing a batch of binary images. @@ -125,49 +251,11 @@ def edt_triton(data: torch.Tensor): A tensor of the same shape as data containing the EDT. It should be equivalent to a batched version of cv2.distanceTransform(input, cv2.DIST_L2, 0) """ - assert data.dim() == 3 - assert data.is_cuda - B, H, W = data.shape - data = data.contiguous() - - # Allocate the "function" tensor. Implicitly the function is 0 if data[i,j]==0 else +infinity - output = torch.where(data, 1e18, 0.0) - assert output.is_contiguous() - - # Scratch tensors for the parabola stacks - parabola_loc = torch.zeros(B, H, W, dtype=torch.uint32, device=data.device) - parabola_inter = torch.empty(B, H, W, dtype=torch.float, device=data.device) - parabola_inter[:, :, 0] = -1e18 - parabola_inter[:, :, 1] = 1e18 - - # Grid size (number of blocks) - grid = (B, H) - - # Launch initialization kernel - edt_kernel[grid]( - output.clone(), - output, - parabola_loc, - parabola_inter, - H, - W, - horizontal=True, - ) - - # reset the parabola stacks - parabola_loc.zero_() - parabola_inter[:, :, 0] = -1e18 - parabola_inter[:, :, 1] = 1e18 - - grid = (B, W) - edt_kernel[grid]( - output.clone(), - output, - parabola_loc, - parabola_inter, - H, - W, - horizontal=False, - ) - # don't forget to take sqrt at the end - return output.sqrt() + if HAS_TRITON and data.is_cuda: + return edt_triton_impl(data) + else: + return edt_pytorch(data) + + +# Legacy alias for backward compatibility +edt_triton = edt diff --git a/sam3/model/geometry_encoders.py b/sam3/model/geometry_encoders.py index bff29172..1a2aa349 100644 --- a/sam3/model/geometry_encoders.py +++ b/sam3/model/geometry_encoders.py @@ -4,10 +4,28 @@ import torch import torch.nn as nn +import torch.nn.functional as F import torchvision from typing_extensions import override from .act_ckpt_utils import activation_ckpt_wrapper + + +def _grid_sample_mps_safe(input, grid, **kwargs): + """ + MPS-safe wrapper for grid_sample. + MPS has bugs with grid_sample on certain tensor configurations, + so we fall back to CPU for MPS devices. 
+ """ + if input.device.type == "mps": + # Move to CPU, perform operation, move back + input_cpu = input.cpu() + grid_cpu = grid.cpu() + result = F.grid_sample(input_cpu, grid_cpu, **kwargs) + return result.to(input.device) + return F.grid_sample(input, grid, **kwargs) + + from .box_ops import box_cxcywh_to_xyxy from .model_misc import get_clones @@ -44,8 +62,13 @@ def concat_padded_sequences(seq1, mask1, seq2, mask2, return_index: bool = False assert seq1_length == mask1.size(1) assert seq2_length == mask2.size(1) - torch._assert_async(is_right_padded(mask1)) - torch._assert_async(is_right_padded(mask2)) + # _assert_async is not supported on MPS, use regular assert + if mask1.device.type == "mps" or mask2.device.type == "mps": + assert is_right_padded(mask1), "mask1 must be right padded" + assert is_right_padded(mask2), "mask2 must be right padded" + else: + torch._assert_async(is_right_padded(mask1)) + torch._assert_async(is_right_padded(mask2)) actual_seq1_lengths = (~mask1).sum(dim=-1) actual_seq2_lengths = (~mask2).sum(dim=-1) @@ -613,7 +636,7 @@ def _encode_points(self, points, points_mask, points_labels, img_feats): grid = points.transpose(0, 1).unsqueeze(2) # re normalize to [-1, 1] grid = (grid * 2) - 1 - sampled = torch.nn.functional.grid_sample( + sampled = _grid_sample_mps_safe( img_feats, grid, align_corners=False ) assert list(sampled.shape) == [bs, self.d_model, n_points, 1] @@ -656,11 +679,16 @@ def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats): # We need to denormalize, and convert to [x, y, x, y] boxes_xyxy = box_cxcywh_to_xyxy(boxes) scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype) - scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True) + # pin_memory() only works with CUDA, not MPS + if boxes_xyxy.device.type == "cuda": + scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True) + else: + scale = scale.to(device=boxes_xyxy.device) scale = scale.view(1, 1, 4) boxes_xyxy = boxes_xyxy * scale + # Match boxes dtype to img_feats dtype for roi_align (needed for half precision) sampled = torchvision.ops.roi_align( - img_feats, boxes_xyxy.float().transpose(0, 1).unbind(0), self.roi_size + img_feats, boxes_xyxy.to(img_feats.dtype).transpose(0, 1).unbind(0), self.roi_size ) assert list(sampled.shape) == [ bs * n_boxes, diff --git a/sam3/model/io_utils.py b/sam3/model/io_utils.py index 0a225842..1691911b 100644 --- a/sam3/model/io_utils.py +++ b/sam3/model/io_utils.py @@ -15,6 +15,7 @@ from PIL import Image from sam3.logger import get_logger +from sam3.utils.device import get_device from tqdm import tqdm logger = get_logger(__name__) @@ -63,7 +64,7 @@ def load_resource_as_video_frames( images.append(img) images = torch.stack(images) if not offload_video_to_cpu: - images = images.cuda() + images = images.to(get_device()) return images, orig_height, orig_width is_image = ( @@ -104,9 +105,10 @@ def load_image_as_single_frame_video( img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None] img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None] if not offload_video_to_cpu: - images = images.cuda() - img_mean = img_mean.cuda() - img_std = img_std.cuda() + device = get_device() + images = images.to(device) + img_mean = img_mean.to(device) + img_std = img_std.to(device) # normalize by mean and std images -= img_mean images /= img_std @@ -201,9 +203,10 @@ def load_video_frames_from_image_folder( ): images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) if not 
offload_video_to_cpu: - images = images.cuda() - img_mean = img_mean.cuda() - img_std = img_std.cuda() + device = get_device() + images = images.to(device) + img_mean = img_mean.to(device) + img_std = img_std.to(device) # normalize by mean and std images -= img_mean images /= img_std @@ -307,9 +310,10 @@ def load_video_frames_from_video_file_using_cv2( img_mean = torch.tensor(img_mean, dtype=torch.float16).view(1, 3, 1, 1) img_std = torch.tensor(img_std, dtype=torch.float16).view(1, 3, 1, 1) if not offload_video_to_cpu: - video_tensor = video_tensor.cuda() - img_mean = img_mean.cuda() - img_std = img_std.cuda() + device = get_device() + video_tensor = video_tensor.to(device) + img_mean = img_mean.to(device) + img_std = img_std.to(device) # normalize by mean and std video_tensor -= img_mean video_tensor /= img_std @@ -323,7 +327,7 @@ def load_dummy_video(image_size, offload_video_to_cpu, num_frames=60): video_height, video_width = 480, 640 # dummy original video sizes images = torch.randn(num_frames, 3, image_size, image_size, dtype=torch.float16) if not offload_video_to_cpu: - images = images.cuda() + images = images.to(get_device()) return images, video_height, video_width @@ -392,7 +396,7 @@ def __getitem__(self, index): img -= self.img_mean img /= self.img_std if not self.offload_video_to_cpu: - img = img.cuda() + img = img.to(get_device()) self.images[index] = img return img @@ -503,16 +507,33 @@ def __init__( use_rand_seek_in_loading=False, ): # Check and possibly infer the output device (and also get its GPU id when applicable) - assert gpu_device is None or gpu_device.type == "cuda" - gpu_id = ( - gpu_device.index - if gpu_device is not None and gpu_device.index is not None - else torch.cuda.current_device() - ) + # For MPS devices, we disable GPU acceleration since TorchCodec doesn't support it + default_device = get_device() + is_mps = default_device.type == "mps" + + if gpu_device is not None: + assert gpu_device.type in ("cuda", "mps", "cpu"), f"Unsupported device type: {gpu_device.type}" + + # Disable GPU acceleration for non-CUDA devices + if is_mps or (gpu_device is not None and gpu_device.type != "cuda"): + gpu_acceleration = False + + gpu_id = 0 + if torch.cuda.is_available(): + gpu_id = ( + gpu_device.index + if gpu_device is not None and gpu_device.type == "cuda" and gpu_device.index is not None + else torch.cuda.current_device() + ) + if offload_video_to_cpu: out_device = torch.device("cpu") else: - out_device = torch.device("cuda") if gpu_device is None else gpu_device + if gpu_device is not None: + out_device = gpu_device + else: + out_device = default_device + self.out_device = out_device self.gpu_acceleration = gpu_acceleration self.gpu_id = gpu_id @@ -525,7 +546,7 @@ def __init__( img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None] self.img_std = img_std - if gpu_acceleration: + if gpu_acceleration and torch.cuda.is_available(): self.img_mean = self.img_mean.to(f"cuda:{self.gpu_id}") self.img_std = self.img_std.to(f"cuda:{self.gpu_id}") decoder_option = {"device": f"cuda:{self.gpu_id}"} diff --git a/sam3/model/position_encoding.py b/sam3/model/position_encoding.py index eb3f4055..2efbb5d1 100644 --- a/sam3/model/position_encoding.py +++ b/sam3/model/position_encoding.py @@ -6,6 +6,8 @@ import torch from torch import nn +from sam3.utils.device import get_device + class PositionEmbeddingSine(nn.Module): """ @@ -44,7 +46,7 @@ def __init__( (precompute_resolution // 32, precompute_resolution // 32), ] for size in precompute_sizes: - tensors = 
torch.zeros((1, 1) + size, device="cuda") + tensors = torch.zeros((1, 1) + size, device=get_device()) self.forward(tensors) # further clone and detach it in the cache (just to be safe) self.cache[size] = self.cache[size].clone().detach() diff --git a/sam3/model/sam3_image.py b/sam3/model/sam3_image.py index aafe520b..db961e2a 100644 --- a/sam3/model/sam3_image.py +++ b/sam3/model/sam3_image.py @@ -122,7 +122,11 @@ def _get_img_feats(self, backbone_out, img_ids): # If this assert fails, it likely means we're requesting different img_ids (perhaps a different frame?) # We currently don't expect this to happen. We could technically trigger a recompute here, # but likely at the cost of a cpu<->gpu sync point, which would deteriorate perf - torch._assert_async((img_ids >= 0).all()) + # _assert_async is not supported on MPS + if img_ids.device.type == "mps": + assert (img_ids >= 0).all(), "img_ids must be non-negative" + else: + torch._assert_async((img_ids >= 0).all()) vis_feats = backbone_out["backbone_fpn"][-self.num_feature_levels :] vis_pos_enc = backbone_out["vision_pos_enc"][-self.num_feature_levels :] diff --git a/sam3/model/sam3_image_processor.py b/sam3/model/sam3_image_processor.py index 4d98fbfb..82f410c0 100644 --- a/sam3/model/sam3_image_processor.py +++ b/sam3/model/sam3_image_processor.py @@ -8,13 +8,16 @@ from sam3.model import box_ops from sam3.model.data_misc import FindStage, interpolate +from sam3.utils.device import get_device_str from torchvision.transforms import v2 class Sam3Processor: """ """ - def __init__(self, model, resolution=1008, device="cuda", confidence_threshold=0.5): + def __init__(self, model, resolution=1008, device=None, confidence_threshold=0.5): + if device is None: + device = get_device_str() self.model = model self.resolution = resolution self.device = device @@ -54,6 +57,11 @@ def set_image(self, image, state=None): image = v2.functional.to_image(image).to(self.device) image = self.transform(image).unsqueeze(0) + # Match model dtype (for half precision support) + model_dtype = next(self.model.parameters()).dtype + if image.dtype != model_dtype: + image = image.to(model_dtype) + state["original_height"] = height state["original_width"] = width state["backbone_out"] = self.model.backbone.forward_image(image) @@ -93,6 +101,12 @@ def set_image_batch(self, images: List[np.ndarray], state=None): for image in images ] images = torch.stack(images, dim=0) + + # Match model dtype (for half precision support) + model_dtype = next(self.model.parameters()).dtype + if images.dtype != model_dtype: + images = images.to(model_dtype) + state["backbone_out"] = self.model.backbone.forward_image(images) inst_interactivity_en = self.model.inst_interactive_predictor is not None if inst_interactivity_en and "sam2_backbone_out" in state["backbone_out"]: diff --git a/sam3/model/sam3_tracker_base.py b/sam3/model/sam3_tracker_base.py index 90fbd696..94952834 100644 --- a/sam3/model/sam3_tracker_base.py +++ b/sam3/model/sam3_tracker_base.py @@ -164,10 +164,12 @@ def _get_tpos_enc(self, rel_pos_list, device, max_abs_pos=None, dummy=False): return torch.zeros(len(rel_pos_list), self.mem_dim, device=device) t_diff_max = max_abs_pos - 1 if max_abs_pos is not None else 1 - pos_enc = ( - torch.tensor(rel_pos_list).pin_memory().to(device=device, non_blocking=True) - / t_diff_max - ) + # pin_memory() only works with CUDA, not MPS + rel_pos_tensor = torch.tensor(rel_pos_list) + if device.type == "cuda" if isinstance(device, torch.device) else device == "cuda": + pos_enc = 
rel_pos_tensor.pin_memory().to(device=device, non_blocking=True) / t_diff_max + else: + pos_enc = rel_pos_tensor.to(device=device) / t_diff_max tpos_dim = self.hidden_dim pos_enc = get_1d_sine_pe(pos_enc, dim=tpos_dim) pos_enc = self.obj_ptr_tpos_proj(pos_enc) @@ -653,15 +655,15 @@ def _prepare_memory_conditioned_features( if prev is None: continue # skip padding frames # "maskmem_features" might have been offloaded to CPU in demo use cases, - # so we load it back to GPU (it's a no-op if it's already on GPU). - feats = prev["maskmem_features"].cuda(non_blocking=True) + # so we load it back to the model's device (it's a no-op if it's already there). + feats = prev["maskmem_features"].to(device, non_blocking=True) seq_len = feats.shape[-2] * feats.shape[-1] to_cat_prompt.append(feats.flatten(2).permute(2, 0, 1)) to_cat_prompt_mask.append( torch.zeros(B, seq_len, device=device, dtype=bool) ) # Spatial positional encoding (it might have been offloaded to CPU in eval) - maskmem_enc = prev["maskmem_pos_enc"][-1].cuda() + maskmem_enc = prev["maskmem_pos_enc"][-1].to(device) maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) if ( diff --git a/sam3/model/sam3_tracking_predictor.py b/sam3/model/sam3_tracking_predictor.py index b2440ef6..a5a27bb5 100644 --- a/sam3/model/sam3_tracking_predictor.py +++ b/sam3/model/sam3_tracking_predictor.py @@ -46,8 +46,16 @@ def __init__( self.max_point_num_in_prompt_enc = max_point_num_in_prompt_enc self.non_overlap_masks_for_output = non_overlap_masks_for_output - self.bf16_context = torch.autocast(device_type="cuda", dtype=torch.bfloat16) - self.bf16_context.__enter__() # keep using for the entire model process + # Set up autocast context based on device type + # MPS doesn't support bfloat16, so we skip autocast on non-CUDA devices + device_type = getattr(self, 'device', torch.device('cpu')) + if hasattr(device_type, 'type'): + device_type = device_type.type + if device_type == "cuda": + self.bf16_context = torch.autocast(device_type="cuda", dtype=torch.bfloat16) + self.bf16_context.__enter__() # keep using for the entire model process + else: + self.bf16_context = None # No autocast for MPS/CPU self.iter_use_prev_mask_pred = True self.add_all_frames_to_correct_as_cond = True @@ -78,7 +86,8 @@ def init_state( if offload_state_to_cpu: inference_state["storage_device"] = torch.device("cpu") else: - inference_state["storage_device"] = torch.device("cuda") + # Use the actual device (cuda, mps, or cpu) instead of hardcoded cuda + inference_state["storage_device"] = self.device if video_path is not None: images, video_height, video_width = load_video_frames( @@ -300,7 +309,12 @@ def add_new_points_or_box( prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) if prev_out is not None and prev_out["pred_masks"] is not None: - prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True) + device = inference_state["device"] + # Use device-agnostic transfer (cuda, mps, or cpu) + if device.type == "cuda": + prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True) + else: + prev_sam_mask_logits = prev_out["pred_masks"].to(device) # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. 
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) current_out, _ = self._run_single_frame_inference( @@ -1021,7 +1035,8 @@ def _get_image_feature(self, inference_state, frame_idx, batch_size): ) else: # Cache miss -- we will run inference on a single image - image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0) + device = inference_state["device"] + image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0) backbone_out = self.forward_image(image) # Cache the most recent frame's feature (for repeated interactions with # a frame; we can use an LRU cache for more frames in the future). diff --git a/sam3/model/sam3_video_inference.py b/sam3/model/sam3_video_inference.py index 7fb87d01..8e1b71e1 100644 --- a/sam3/model/sam3_video_inference.py +++ b/sam3/model/sam3_video_inference.py @@ -477,9 +477,13 @@ def _postprocess_output( # slice those valid entries from the original outputs keep_idx = torch.nonzero(keep, as_tuple=True)[0] - keep_idx_gpu = keep_idx.pin_memory().to( - device=out_binary_masks.device, non_blocking=True - ) + # pin_memory() only works with CUDA, not MPS + if out_binary_masks.device.type == "cuda": + keep_idx_gpu = keep_idx.pin_memory().to( + device=out_binary_masks.device, non_blocking=True + ) + else: + keep_idx_gpu = keep_idx.to(device=out_binary_masks.device) out_obj_ids = torch.index_select(out_obj_ids, 0, keep_idx) out_probs = torch.index_select(out_probs, 0, keep_idx) diff --git a/sam3/model/sam3_video_predictor.py b/sam3/model/sam3_video_predictor.py index c639e1d0..ccd7a009 100644 --- a/sam3/model/sam3_video_predictor.py +++ b/sam3/model/sam3_video_predictor.py @@ -16,6 +16,7 @@ import torch from sam3.logger import get_logger +from sam3.utils.device import get_device logger = get_logger(__name__) @@ -48,7 +49,7 @@ def __init__( strict_state_dict_loading=strict_state_dict_loading, apply_temporal_disambiguation=apply_temporal_disambiguation, ) - .cuda() + .to(get_device()) .eval() ) @@ -275,11 +276,17 @@ def _get_session_stats(self): return session_stats_str def _get_torch_and_gpu_properties(self): - """Get a string for PyTorch and GPU properties (for logging and debugging).""" - torch_and_gpu_str = ( - f"torch: {torch.__version__} with CUDA arch {torch.cuda.get_arch_list()}, " - f"GPU device: {torch.cuda.get_device_properties(torch.cuda.current_device())}" - ) + """Get a string for PyTorch and device properties (for logging and debugging).""" + device = get_device() + if device.type == "cuda": + torch_and_gpu_str = ( + f"torch: {torch.__version__} with CUDA arch {torch.cuda.get_arch_list()}, " + f"GPU device: {torch.cuda.get_device_properties(torch.cuda.current_device())}" + ) + elif device.type == "mps": + torch_and_gpu_str = f"torch: {torch.__version__} with MPS (Apple Silicon)" + else: + torch_and_gpu_str = f"torch: {torch.__version__} on CPU" return torch_and_gpu_str def shutdown(self): @@ -428,7 +435,8 @@ def _start_nccl_process_group(self): device_id=self.device, ) # warm-up the NCCL process group by running a dummy all-reduce - tensor = torch.ones(1024, 1024).cuda() + # Note: NCCL backend requires CUDA tensors + tensor = torch.ones(1024, 1024, device=self.device) torch.distributed.all_reduce(tensor) logger.debug(f"started NCCL process group on {rank=} with {world_size=}") diff --git a/sam3/model/vl_combiner.py b/sam3/model/vl_combiner.py index 43bc7bd5..ae8bc405 100644 --- a/sam3/model/vl_combiner.py +++ b/sam3/model/vl_combiner.py @@ -10,6 +10,7 @@ from torch.nn.attention import sdpa_kernel, SDPBackend 
+from sam3.utils.device import get_device_str from .act_ckpt_utils import activation_ckpt_wrapper from .necks import Sam3DualViTDetNeck @@ -119,8 +120,10 @@ def _forward_image_no_act_ckpt(self, samples): return output def forward_text( - self, captions, input_boxes=None, additional_text=None, device="cuda" + self, captions, input_boxes=None, additional_text=None, device=None ): + if device is None: + device = get_device_str() return activation_ckpt_wrapper(self._forward_text_no_ack_ckpt)( captions=captions, input_boxes=input_boxes, @@ -134,8 +137,10 @@ def _forward_text_no_ack_ckpt( captions, input_boxes=None, additional_text=None, - device="cuda", + device=None, ): + if device is None: + device = get_device_str() output = {} # Forward through text_encoder diff --git a/sam3/model_builder.py b/sam3/model_builder.py index 1a3bdecf..3d588ffb 100644 --- a/sam3/model_builder.py +++ b/sam3/model_builder.py @@ -44,17 +44,11 @@ from sam3.sam.transformer import RoPEAttention -# Setup TensorFloat-32 for Ampere GPUs if available -def _setup_tf32() -> None: - """Enable TensorFloat-32 for Ampere GPUs if available.""" - if torch.cuda.is_available(): - device_props = torch.cuda.get_device_properties(0) - if device_props.major >= 8: - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True +# Import device utilities +from sam3.utils.device import get_device_str, setup_device_optimizations - -_setup_tf32() +# Setup device-specific optimizations (TF32 for Ampere GPUs, etc.) +setup_device_optimizations() def _create_position_encoding(precompute_resolution=None): @@ -549,8 +543,7 @@ def _load_checkpoint(model, checkpoint_path): def _setup_device_and_mode(model, device, eval_mode): """Setup model device and evaluation mode.""" - if device == "cuda": - model = model.cuda() + model = model.to(device=device) if eval_mode: model.eval() return model @@ -558,7 +551,7 @@ def _setup_device_and_mode(model, device, eval_mode): def build_sam3_image_model( bpe_path=None, - device="cuda" if torch.cuda.is_available() else "cpu", + device=None, # Will use get_device_str() if None eval_mode=True, checkpoint_path=None, load_from_HF=True, @@ -571,7 +564,7 @@ def build_sam3_image_model( Args: bpe_path: Path to the BPE tokenizer vocabulary - device: Device to load the model on ('cuda' or 'cpu') + device: Device to load the model on ('cuda', 'mps', or 'cpu'). If None, auto-detects best available device. 
eval_mode: Whether to set the model to evaluation mode checkpoint_path: Optional path to model checkpoint enable_segmentation: Whether to enable segmentation head @@ -586,6 +579,10 @@ def build_sam3_image_model( "sam3", "assets/bpe_simple_vocab_16e6.txt.gz" ) + # Set default device if not specified + if device is None: + device = get_device_str() + # Create visual components compile_mode = "default" if compile else None vision_encoder = _create_vision_backbone( @@ -657,7 +654,7 @@ def build_sam3_video_model( geo_encoder_use_img_cross_attn: bool = True, strict_state_dict_loading: bool = True, apply_temporal_disambiguation: bool = True, - device="cuda" if torch.cuda.is_available() else "cpu", + device=None, # Will use get_device_str() if None compile=False, ) -> Sam3VideoInferenceWithInstanceInteractivity: """ @@ -675,6 +672,10 @@ def build_sam3_video_model( "sam3", "assets/bpe_simple_vocab_16e6.txt.gz" ) + # Set default device if not specified + if device is None: + device = get_device_str() + # Build Tracker module tracker = build_tracker(apply_temporal_disambiguation=apply_temporal_disambiguation) diff --git a/sam3/perflib/connected_components.py b/sam3/perflib/connected_components.py index c96932a4..f212263b 100644 --- a/sam3/perflib/connected_components.py +++ b/sam3/perflib/connected_components.py @@ -54,6 +54,8 @@ def connected_components(input_tensor: torch.Tensor): """ Computes connected components labeling on a batch of 2D tensors, using the best available backend. + Supports CUDA (with optional Triton acceleration), MPS (Apple Silicon), and CPU. + Args: input_tensor (torch.Tensor): A BxHxW integer tensor or Bx1xHxW. Non-zero values are considered foreground. Bool tensor also accepted @@ -69,7 +71,10 @@ def connected_components(input_tensor: torch.Tensor): input_tensor.dim() == 4 and input_tensor.shape[1] == 1 ), "Input tensor must be (B, H, W) or (B, 1, H, W)." - if input_tensor.is_cuda: + # Check device type + device_type = input_tensor.device.type + + if device_type == "cuda": if HAS_CC_TORCH: return get_connected_components(input_tensor.to(torch.uint8)) else: @@ -80,5 +85,6 @@ def connected_components(input_tensor: torch.Tensor): return connected_components_triton(input_tensor) - # CPU fallback + # For MPS (Apple Silicon) and CPU, use the CPU implementation + # MPS tensors are handled in connected_components_cpu via .cpu() conversion return connected_components_cpu(input_tensor) diff --git a/sam3/perflib/nms.py b/sam3/perflib/nms.py index b3efc599..f50cb800 100644 --- a/sam3/perflib/nms.py +++ b/sam3/perflib/nms.py @@ -55,12 +55,18 @@ def nms_masks( def generic_nms( ious: torch.Tensor, scores: torch.Tensor, iou_threshold=0.5 ) -> torch.Tensor: - """A generic version of `torchvision.ops.nms` that takes a pairwise IoU matrix.""" + """A generic version of `torchvision.ops.nms` that takes a pairwise IoU matrix. + + Supports CUDA (with optional Triton acceleration), MPS (Apple Silicon), and CPU. 
+ """ assert ious.dim() == 2 and ious.size(0) == ious.size(1) assert scores.dim() == 1 and scores.size(0) == ious.size(0) - if ious.is_cuda: + # Check device type + device_type = ious.device.type + + if device_type == "cuda": if GENERIC_NMS_AVAILABLE: return generic_nms_cuda(ious, scores, iou_threshold, use_iou_matrix=True) else: @@ -68,6 +74,8 @@ def generic_nms( return nms_triton(ious, scores, iou_threshold) + # For MPS (Apple Silicon) and CPU, use the CPU implementation + # MPS tensors need to be moved to CPU for numpy operations return generic_nms_cpu(ious, scores, iou_threshold) diff --git a/sam3/sam/transformer.py b/sam3/sam/transformer.py index 3e96c283..5d4a4ce9 100644 --- a/sam3/sam/transformer.py +++ b/sam3/sam/transformer.py @@ -252,9 +252,11 @@ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) ).transpose(1, 2) else: - torch.backends.cuda.enable_flash_sdp(True) - torch.backends.cuda.enable_math_sdp(True) - torch.backends.cuda.enable_mem_efficient_sdp(True) + # Only configure CUDA backends when on CUDA device + if q.is_cuda: + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_math_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(True) out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) out = self._recombine_heads(out) @@ -282,9 +284,9 @@ def __init__( self.compute_cis = partial( compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta ) - device = torch.device("cuda") if torch.cuda.is_available() else None + # Use None for device - will be set on first forward pass based on input tensor self.freqs_cis = self.compute_cis( - end_x=feat_sizes[0], end_y=feat_sizes[1], device=device + end_x=feat_sizes[0], end_y=feat_sizes[1], device=None ) if self.use_rope_real: self.freqs_cis_real = self.freqs_cis.real @@ -347,9 +349,11 @@ def forward( q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) ).transpose(1, 2) else: - torch.backends.cuda.enable_flash_sdp(True) - torch.backends.cuda.enable_math_sdp(True) - torch.backends.cuda.enable_mem_efficient_sdp(True) + # Only configure CUDA backends when on CUDA device + if q.is_cuda: + torch.backends.cuda.enable_flash_sdp(True) + torch.backends.cuda.enable_math_sdp(True) + torch.backends.cuda.enable_mem_efficient_sdp(True) out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) out = self._recombine_heads(out) diff --git a/sam3/train/data/sam3_image_dataset.py b/sam3/train/data/sam3_image_dataset.py index 97efb1d1..f8b0d634 100644 --- a/sam3/train/data/sam3_image_dataset.py +++ b/sam3/train/data/sam3_image_dataset.py @@ -15,7 +15,7 @@ import torch import torch.utils.data import torchvision -from decord import cpu, VideoReader +# decord is imported lazily when needed for video loading from iopath.common.file_io import g_pathmgr from PIL import Image as PILImage @@ -202,6 +202,7 @@ def _load_images( try: if ".mp4" in path and path[-4:] == ".mp4": # Going to load a video frame + from decord import cpu, VideoReader video_path, frame = path.split("@") video = VideoReader(video_path, ctx=cpu(0)) # Convert to PIL image diff --git a/sam3/train/loss/sigmoid_focal_loss.py b/sam3/train/loss/sigmoid_focal_loss.py index 15e6db43..48f3b811 100644 --- a/sam3/train/loss/sigmoid_focal_loss.py +++ b/sam3/train/loss/sigmoid_focal_loss.py @@ -1,11 +1,19 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. 
All Rights Reserved -"""Triton kernel for faster and memory efficient sigmoid focal loss""" +"""Sigmoid focal loss with optional Triton kernel acceleration for CUDA devices.""" import torch -import triton -import triton.language as tl -from torch._inductor.runtime.triton_helpers import libdevice +import torch.nn.functional as F + +# Try to import Triton (only available on CUDA) +try: + import triton + import triton.language as tl + from torch._inductor.runtime.triton_helpers import libdevice + + HAS_TRITON = True +except ImportError: + HAS_TRITON = False """ @@ -32,290 +40,410 @@ """ -@triton.jit -def _inner_focal_loss_fwd(inputs, targets, alpha, gamma): - inv_targets = 1 - targets - # Sigmoid - sig = tl.sigmoid(inputs) - - # Binary cross entropy with logits - # In practice, we want the following: - # bce_loss = -targets * tl.log(sig) - (1 - targets) * tl.log(1 - sig) - # However, the above is not numerically stable. - # We're also not directly taking the sum here, so the usual log-sum-exp trick doesn't apply - # The bce can be reformulated, after algebraic manipulation, to - # bce_loss = log(1 + exp(-x)) + x * (1-y) - # This is still not stable, because for large (-x) the exponential will blow up. - # We'll use the following alternate formulation: - # bce_loss = max(x, 0) - x * y + log(1 + exp(-abs(x))) - # Let's show that it's equivalent: - # Case x>=0: abs(x) = x , max(x, 0) = x - # so we get x - x * y + log(1 + exp(-x)) which is equivalent - # Case x<0: abs(x) = -x, max(x, 0) = 0 - # we have log(1 + exp(-abs(x))) = log(1 + exp(x)) = log(exp(x)(1 + exp(-x))) = x+log(1 + exp(-x)) - # plugging it in, we get - # 0 - x * y + x + log(1 + exp(-x)), which is also equivalent - # Note that this is stable because now the exponent are guaranteed to be below 0. 
- max_val = tl.clamp(inputs, min=0, max=1e9) - bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs))) - - # Modulating factor - p_t = sig * targets + (1 - sig) * inv_targets - mod_factor = libdevice.pow(1 - p_t, gamma) - - # Alpha factor - alpha_t = alpha * targets + (1 - alpha) * inv_targets - - # Final loss calculation - return alpha_t * mod_factor * bce_loss - - -# Non-reduced version -@triton.jit -def sigmoid_focal_loss_fwd_kernel( - inputs_ptr, - targets_ptr, - loss_ptr, - alpha: float, - gamma: float, - n_elements: int, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offset = block_start + tl.arange(0, BLOCK_SIZE) - mask = offset < n_elements - - # Load data - inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32) - targets = tl.load(targets_ptr + offset, mask=mask) - - final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma) - - # Store result - tl.store(loss_ptr + offset, final_loss, mask=mask) - - -# version with reduction -@triton.jit -def sigmoid_focal_loss_fwd_kernel_reduce( - inputs_ptr, - targets_ptr, - loss_ptr, - alpha: float, - gamma: float, - n_elements: int, - BLOCK_SIZE: tl.constexpr, - REDUCE_SIZE: tl.constexpr, -): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - reduce_loc = pid % REDUCE_SIZE - offset = block_start + tl.arange(0, BLOCK_SIZE) - mask = offset < n_elements - # Load data - inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32) - targets = tl.load(targets_ptr + offset, mask=mask) - - final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma) * mask - - fl = tl.sum(final_loss) - - # Store result - tl.atomic_add(loss_ptr + reduce_loc, fl) - - -@triton.jit -def _inner_focal_loss_bwd(inputs, targets, alpha, gamma): - inv_targets = 1 - targets - - # Recompute forward - max_val = tl.clamp(inputs, min=0, max=1e9) - bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs))) - - # Sigmoid - sig = tl.sigmoid(inputs) - inv_sig = 1 - sig - - # Modulating factor - p_t = sig * targets + inv_sig * inv_targets - tmp = libdevice.pow(1 - p_t, gamma - 1) - mod_factor = tmp * (1 - p_t) - - # Alpha factor - alpha_t = alpha * targets + (1 - alpha) * inv_targets - - # Now computing the derivatives - d_pt = (2 * targets - 1) * sig * inv_sig - d_mod_factor = -gamma * d_pt * tmp - - d_bce_loss = sig - targets - - return alpha_t * (d_bce_loss * mod_factor + d_mod_factor * bce_loss) - - -@triton.jit -def sigmoid_focal_loss_bwd_kernel( - inputs_ptr, - targets_ptr, - grad_inputs_ptr, - grad_out_ptr, - alpha: float, - gamma: float, - n_elements: int, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offset = block_start + tl.arange(0, BLOCK_SIZE) - mask = offset < n_elements - input_ptrs = inputs_ptr + offset - target_ptrs = targets_ptr + offset - grad_input_ptrs = grad_inputs_ptr + offset - grad_out_ptrs = grad_out_ptr + offset - # Load data - inputs = tl.load(input_ptrs, mask=mask).to(tl.float32) - targets = tl.load(target_ptrs, mask=mask) - grad_out = tl.load(grad_out_ptrs, mask=mask) - d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma) - tl.store(grad_input_ptrs, d_loss, mask=mask) - - -@triton.jit -def sigmoid_focal_loss_bwd_kernel_reduce( - inputs_ptr, - targets_ptr, - grad_inputs_ptr, - grad_out_ptr, - alpha: float, - gamma: float, - n_elements: int, - BLOCK_SIZE: tl.constexpr, -): - # The only difference is that the gradient is now a single scalar - pid = 
tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offset = block_start + tl.arange(0, BLOCK_SIZE) - mask = offset < n_elements - input_ptrs = inputs_ptr + offset - target_ptrs = targets_ptr + offset - grad_input_ptrs = grad_inputs_ptr + offset - # Load data - inputs = tl.load(input_ptrs, mask=mask).to(tl.float32) - targets = tl.load(target_ptrs, mask=mask) - grad_out = tl.load(grad_out_ptr) - d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma) - tl.store(grad_input_ptrs, d_loss, mask=mask) - - -class SigmoidFocalLoss(torch.autograd.Function): - BLOCK_SIZE = 256 - - @staticmethod - def forward(ctx, inputs, targets, alpha=0.25, gamma=2): - n_elements = inputs.numel() - assert targets.numel() == n_elements - input_shape = inputs.shape - inputs = inputs.view(-1).contiguous() - targets = targets.view(-1).contiguous() - loss = torch.empty(inputs.shape, dtype=torch.float32, device=inputs.device) - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - sigmoid_focal_loss_fwd_kernel[grid]( - inputs, targets, loss, alpha, gamma, n_elements, SigmoidFocalLoss.BLOCK_SIZE - ) - ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape)) - ctx.alpha = alpha - ctx.gamma = gamma - return loss.view(input_shape) - - @staticmethod - def backward(ctx, grad_output): - inputs, targets = ctx.saved_tensors - alpha = ctx.alpha - gamma = ctx.gamma - n_elements = inputs.numel() - input_shape = inputs.shape - grad_inputs = torch.empty( - inputs.shape, dtype=grad_output.dtype, device=grad_output.device - ) - inputs_ptr = inputs.view(-1).contiguous() - targets_ptr = targets.view(-1).contiguous() - grad_output_ptr = grad_output.view(-1).contiguous() - grad_inputs_ptr = grad_inputs - assert grad_output.numel() == n_elements - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - sigmoid_focal_loss_bwd_kernel[grid]( - inputs_ptr, - targets_ptr, - grad_inputs_ptr, - grad_output_ptr, - alpha, - gamma, - n_elements, - SigmoidFocalLoss.BLOCK_SIZE, - ) - return grad_inputs.view(input_shape), None, None, None - - -triton_sigmoid_focal_loss = SigmoidFocalLoss.apply - - -class SigmoidFocalLossReduced(torch.autograd.Function): - BLOCK_SIZE = 256 - REDUCE_SIZE = 32 - - @staticmethod - def forward(ctx, inputs, targets, alpha=0.25, gamma=2): - n_elements = inputs.numel() - input_shape = inputs.shape - inputs = inputs.view(-1).contiguous() - targets = targets.view(-1).contiguous() - loss = torch.zeros( - SigmoidFocalLossReduced.REDUCE_SIZE, - device=inputs.device, - dtype=torch.float32, - ) - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - sigmoid_focal_loss_fwd_kernel_reduce[grid]( - inputs, - targets, - loss, - alpha, - gamma, - n_elements, - SigmoidFocalLossReduced.BLOCK_SIZE, - SigmoidFocalLossReduced.REDUCE_SIZE, - ) - ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape)) - ctx.alpha = alpha - ctx.gamma = gamma - return loss.sum() - - @staticmethod - def backward(ctx, grad_output): - inputs, targets = ctx.saved_tensors - alpha = ctx.alpha - gamma = ctx.gamma - n_elements = inputs.numel() - input_shape = inputs.shape - grad_inputs = torch.empty( - inputs.shape, dtype=grad_output.dtype, device=grad_output.device - ) - inputs_ptr = inputs.view(-1).contiguous() - targets_ptr = targets.reshape(-1).contiguous() - assert grad_output.numel() == 1 - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - sigmoid_focal_loss_bwd_kernel_reduce[grid]( - inputs_ptr, - targets_ptr, - grad_inputs, - grad_output, - alpha, 
- gamma, - n_elements, - SigmoidFocalLossReduced.BLOCK_SIZE, - ) - return grad_inputs.view(input_shape), None, None, None - - -triton_sigmoid_focal_loss_reduce = SigmoidFocalLossReduced.apply +# ============================================================================ +# PyTorch-based implementations (for CPU, MPS, and fallback) +# ============================================================================ + + +def sigmoid_focal_loss_pytorch( + inputs: torch.Tensor, + targets: torch.Tensor, + alpha: float = 0.25, + gamma: float = 2.0, +) -> torch.Tensor: + """ + Pure PyTorch implementation of sigmoid focal loss (no reduction). + + Args: + inputs: Tensor of any shape, containing logits + targets: Tensor of the same shape as inputs, containing float targets + alpha: Weighting factor in range (0,1) to balance positive vs negative examples + gamma: Exponent of the modulating factor (1 - p_t) ** gamma + + Returns: + Tensor of the same shape as inputs, containing the focal loss for each element + """ + # Compute sigmoid and BCE loss + prob = torch.sigmoid(inputs) + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") + + # Compute p_t and alpha_t + p_t = prob * targets + (1 - prob) * (1 - targets) + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + + # Compute focal loss + focal_weight = (1 - p_t) ** gamma + loss = alpha_t * focal_weight * ce_loss + + return loss + + +def sigmoid_focal_loss_reduced_pytorch( + inputs: torch.Tensor, + targets: torch.Tensor, + alpha: float = 0.25, + gamma: float = 2.0, +) -> torch.Tensor: + """ + Pure PyTorch implementation of sigmoid focal loss with sum reduction. + + Args: + inputs: Tensor of any shape, containing logits + targets: Tensor of the same shape as inputs, containing float targets + alpha: Weighting factor in range (0,1) to balance positive vs negative examples + gamma: Exponent of the modulating factor (1 - p_t) ** gamma + + Returns: + Scalar tensor containing the sum of focal losses + """ + return sigmoid_focal_loss_pytorch(inputs, targets, alpha, gamma).sum() + + +# ============================================================================ +# Triton-based implementations (CUDA only) +# ============================================================================ + +if HAS_TRITON: + + @triton.jit + def _inner_focal_loss_fwd(inputs, targets, alpha, gamma): + inv_targets = 1 - targets + # Sigmoid + sig = tl.sigmoid(inputs) + + # Binary cross entropy with logits + # In practice, we want the following: + # bce_loss = -targets * tl.log(sig) - (1 - targets) * tl.log(1 - sig) + # However, the above is not numerically stable. + # We're also not directly taking the sum here, so the usual log-sum-exp trick doesn't apply + # The bce can be reformulated, after algebraic manipulation, to + # bce_loss = log(1 + exp(-x)) + x * (1-y) + # This is still not stable, because for large (-x) the exponential will blow up. + # We'll use the following alternate formulation: + # bce_loss = max(x, 0) - x * y + log(1 + exp(-abs(x))) + # Let's show that it's equivalent: + # Case x>=0: abs(x) = x , max(x, 0) = x + # so we get x - x * y + log(1 + exp(-x)) which is equivalent + # Case x<0: abs(x) = -x, max(x, 0) = 0 + # we have log(1 + exp(-abs(x))) = log(1 + exp(x)) = log(exp(x)(1 + exp(-x))) = x+log(1 + exp(-x)) + # plugging it in, we get + # 0 - x * y + x + log(1 + exp(-x)), which is also equivalent + # Note that this is stable because now the exponent are guaranteed to be below 0. 
+ max_val = tl.clamp(inputs, min=0, max=1e9) + bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs))) + + # Modulating factor + p_t = sig * targets + (1 - sig) * inv_targets + mod_factor = libdevice.pow(1 - p_t, gamma) + + # Alpha factor + alpha_t = alpha * targets + (1 - alpha) * inv_targets + + # Final loss calculation + return alpha_t * mod_factor * bce_loss + + # Non-reduced version + @triton.jit + def sigmoid_focal_loss_fwd_kernel( + inputs_ptr, + targets_ptr, + loss_ptr, + alpha: float, + gamma: float, + n_elements: int, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offset = block_start + tl.arange(0, BLOCK_SIZE) + mask = offset < n_elements + + # Load data + inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32) + targets = tl.load(targets_ptr + offset, mask=mask) + + final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma) + + # Store result + tl.store(loss_ptr + offset, final_loss, mask=mask) + + # version with reduction + @triton.jit + def sigmoid_focal_loss_fwd_kernel_reduce( + inputs_ptr, + targets_ptr, + loss_ptr, + alpha: float, + gamma: float, + n_elements: int, + BLOCK_SIZE: tl.constexpr, + REDUCE_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + reduce_loc = pid % REDUCE_SIZE + offset = block_start + tl.arange(0, BLOCK_SIZE) + mask = offset < n_elements + # Load data + inputs = tl.load(inputs_ptr + offset, mask=mask).to(tl.float32) + targets = tl.load(targets_ptr + offset, mask=mask) + + final_loss = _inner_focal_loss_fwd(inputs, targets, alpha, gamma) * mask + + fl = tl.sum(final_loss) + + # Store result + tl.atomic_add(loss_ptr + reduce_loc, fl) + + @triton.jit + def _inner_focal_loss_bwd(inputs, targets, alpha, gamma): + inv_targets = 1 - targets + + # Recompute forward + max_val = tl.clamp(inputs, min=0, max=1e9) + bce_loss = max_val - inputs * targets + tl.log(1 + tl.exp(-tl.abs(inputs))) + + # Sigmoid + sig = tl.sigmoid(inputs) + inv_sig = 1 - sig + + # Modulating factor + p_t = sig * targets + inv_sig * inv_targets + tmp = libdevice.pow(1 - p_t, gamma - 1) + mod_factor = tmp * (1 - p_t) + + # Alpha factor + alpha_t = alpha * targets + (1 - alpha) * inv_targets + + # Now computing the derivatives + d_pt = (2 * targets - 1) * sig * inv_sig + d_mod_factor = -gamma * d_pt * tmp + + d_bce_loss = sig - targets + + return alpha_t * (d_bce_loss * mod_factor + d_mod_factor * bce_loss) + + @triton.jit + def sigmoid_focal_loss_bwd_kernel( + inputs_ptr, + targets_ptr, + grad_inputs_ptr, + grad_out_ptr, + alpha: float, + gamma: float, + n_elements: int, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offset = block_start + tl.arange(0, BLOCK_SIZE) + mask = offset < n_elements + input_ptrs = inputs_ptr + offset + target_ptrs = targets_ptr + offset + grad_input_ptrs = grad_inputs_ptr + offset + grad_out_ptrs = grad_out_ptr + offset + # Load data + inputs = tl.load(input_ptrs, mask=mask).to(tl.float32) + targets = tl.load(target_ptrs, mask=mask) + grad_out = tl.load(grad_out_ptrs, mask=mask) + d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma) + tl.store(grad_input_ptrs, d_loss, mask=mask) + + @triton.jit + def sigmoid_focal_loss_bwd_kernel_reduce( + inputs_ptr, + targets_ptr, + grad_inputs_ptr, + grad_out_ptr, + alpha: float, + gamma: float, + n_elements: int, + BLOCK_SIZE: tl.constexpr, + ): + # The only difference is that the gradient is now a single scalar + pid = 
tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offset = block_start + tl.arange(0, BLOCK_SIZE) + mask = offset < n_elements + input_ptrs = inputs_ptr + offset + target_ptrs = targets_ptr + offset + grad_input_ptrs = grad_inputs_ptr + offset + # Load data + inputs = tl.load(input_ptrs, mask=mask).to(tl.float32) + targets = tl.load(target_ptrs, mask=mask) + grad_out = tl.load(grad_out_ptr) + d_loss = grad_out * _inner_focal_loss_bwd(inputs, targets, alpha, gamma) + tl.store(grad_input_ptrs, d_loss, mask=mask) + + class SigmoidFocalLossTriton(torch.autograd.Function): + BLOCK_SIZE = 256 + + @staticmethod + def forward(ctx, inputs, targets, alpha=0.25, gamma=2): + n_elements = inputs.numel() + assert targets.numel() == n_elements + input_shape = inputs.shape + inputs = inputs.view(-1).contiguous() + targets = targets.view(-1).contiguous() + loss = torch.empty(inputs.shape, dtype=torch.float32, device=inputs.device) + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + sigmoid_focal_loss_fwd_kernel[grid]( + inputs, + targets, + loss, + alpha, + gamma, + n_elements, + SigmoidFocalLossTriton.BLOCK_SIZE, + ) + ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape)) + ctx.alpha = alpha + ctx.gamma = gamma + return loss.view(input_shape) + + @staticmethod + def backward(ctx, grad_output): + inputs, targets = ctx.saved_tensors + alpha = ctx.alpha + gamma = ctx.gamma + n_elements = inputs.numel() + input_shape = inputs.shape + grad_inputs = torch.empty( + inputs.shape, dtype=grad_output.dtype, device=grad_output.device + ) + inputs_ptr = inputs.view(-1).contiguous() + targets_ptr = targets.view(-1).contiguous() + grad_output_ptr = grad_output.view(-1).contiguous() + grad_inputs_ptr = grad_inputs + assert grad_output.numel() == n_elements + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + sigmoid_focal_loss_bwd_kernel[grid]( + inputs_ptr, + targets_ptr, + grad_inputs_ptr, + grad_output_ptr, + alpha, + gamma, + n_elements, + SigmoidFocalLossTriton.BLOCK_SIZE, + ) + return grad_inputs.view(input_shape), None, None, None + + class SigmoidFocalLossReducedTriton(torch.autograd.Function): + BLOCK_SIZE = 256 + REDUCE_SIZE = 32 + + @staticmethod + def forward(ctx, inputs, targets, alpha=0.25, gamma=2): + n_elements = inputs.numel() + input_shape = inputs.shape + inputs = inputs.view(-1).contiguous() + targets = targets.view(-1).contiguous() + loss = torch.zeros( + SigmoidFocalLossReducedTriton.REDUCE_SIZE, + device=inputs.device, + dtype=torch.float32, + ) + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + sigmoid_focal_loss_fwd_kernel_reduce[grid]( + inputs, + targets, + loss, + alpha, + gamma, + n_elements, + SigmoidFocalLossReducedTriton.BLOCK_SIZE, + SigmoidFocalLossReducedTriton.REDUCE_SIZE, + ) + ctx.save_for_backward(inputs.view(input_shape), targets.view(input_shape)) + ctx.alpha = alpha + ctx.gamma = gamma + return loss.sum() + + @staticmethod + def backward(ctx, grad_output): + inputs, targets = ctx.saved_tensors + alpha = ctx.alpha + gamma = ctx.gamma + n_elements = inputs.numel() + input_shape = inputs.shape + grad_inputs = torch.empty( + inputs.shape, dtype=grad_output.dtype, device=grad_output.device + ) + inputs_ptr = inputs.view(-1).contiguous() + targets_ptr = targets.reshape(-1).contiguous() + assert grad_output.numel() == 1 + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + sigmoid_focal_loss_bwd_kernel_reduce[grid]( + inputs_ptr, + targets_ptr, + grad_inputs, + grad_output, + alpha, + 
gamma, + n_elements, + SigmoidFocalLossReducedTriton.BLOCK_SIZE, + ) + return grad_inputs.view(input_shape), None, None, None + + +# ============================================================================ +# Public API - automatically selects best implementation +# ============================================================================ + + +def sigmoid_focal_loss( + inputs: torch.Tensor, + targets: torch.Tensor, + alpha: float = 0.25, + gamma: float = 2.0, +) -> torch.Tensor: + """ + Sigmoid focal loss without reduction. + + Uses Triton kernel on CUDA when available, falls back to PyTorch otherwise. + + Args: + inputs: Tensor of any shape, containing logits + targets: Tensor of the same shape as inputs, containing float targets + alpha: Weighting factor in range (0,1) to balance positive vs negative examples + gamma: Exponent of the modulating factor (1 - p_t) ** gamma + + Returns: + Tensor of the same shape as inputs, containing the focal loss for each element + """ + if HAS_TRITON and inputs.is_cuda: + return SigmoidFocalLossTriton.apply(inputs, targets, alpha, gamma) + else: + return sigmoid_focal_loss_pytorch(inputs, targets, alpha, gamma) + + +def sigmoid_focal_loss_reduce( + inputs: torch.Tensor, + targets: torch.Tensor, + alpha: float = 0.25, + gamma: float = 2.0, +) -> torch.Tensor: + """ + Sigmoid focal loss with sum reduction. + + Uses Triton kernel on CUDA when available, falls back to PyTorch otherwise. + + Args: + inputs: Tensor of any shape, containing logits + targets: Tensor of the same shape as inputs, containing float targets + alpha: Weighting factor in range (0,1) to balance positive vs negative examples + gamma: Exponent of the modulating factor (1 - p_t) ** gamma + + Returns: + Scalar tensor containing the sum of focal losses + """ + if HAS_TRITON and inputs.is_cuda: + return SigmoidFocalLossReducedTriton.apply(inputs, targets, alpha, gamma) + else: + return sigmoid_focal_loss_reduced_pytorch(inputs, targets, alpha, gamma) + + +# Legacy aliases for backward compatibility +triton_sigmoid_focal_loss = sigmoid_focal_loss +triton_sigmoid_focal_loss_reduce = sigmoid_focal_loss_reduce diff --git a/sam3/train/utils/distributed.py b/sam3/train/utils/distributed.py index 3c87a911..de41d724 100644 --- a/sam3/train/utils/distributed.py +++ b/sam3/train/utils/distributed.py @@ -190,6 +190,9 @@ def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, s For some backends, such as NCCL, communication only works if the tensor is on the GPU. This helper function converts to the correct device and returns the tensor + original device. + + Note: NCCL backend only works with CUDA. For MPS or CPU distributed training, + use a different backend like 'gloo'. """ orig_device = "cpu" if not tensor.is_cuda else "gpu" if ( @@ -197,6 +200,7 @@ def convert_to_distributed_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, s and torch.distributed.get_backend() == torch.distributed.Backend.NCCL and not tensor.is_cuda ): + # NCCL requires CUDA tensors tensor = tensor.cuda() return (tensor, orig_device) diff --git a/sam3/utils/__init__.py b/sam3/utils/__init__.py new file mode 100644 index 00000000..0136676b --- /dev/null +++ b/sam3/utils/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
All Rights Reserved + +from sam3.utils.device import ( + get_device, + get_device_str, + is_cuda_available, + is_gpu_available, + is_mps_available, + move_model_to_device, + setup_device_optimizations, + tensor_is_on_cuda, + tensor_is_on_gpu, + tensor_is_on_mps, + to_device, +) + +__all__ = [ + "get_device", + "get_device_str", + "is_cuda_available", + "is_mps_available", + "is_gpu_available", + "to_device", + "setup_device_optimizations", + "tensor_is_on_gpu", + "tensor_is_on_cuda", + "tensor_is_on_mps", + "move_model_to_device", +] diff --git a/sam3/utils/device.py b/sam3/utils/device.py new file mode 100644 index 00000000..fb413394 --- /dev/null +++ b/sam3/utils/device.py @@ -0,0 +1,171 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +""" +Device utilities for supporting CUDA, MPS (Apple Silicon), and CPU backends. +""" + +import logging +from functools import lru_cache +from typing import Optional, Union + +import torch + +logger = logging.getLogger(__name__) + + +@lru_cache(maxsize=1) +def get_device() -> torch.device: + """ + Get the best available device for computation. + + Priority: CUDA > MPS > CPU + + Returns: + torch.device: The best available device + """ + if torch.cuda.is_available(): + return torch.device("cuda") + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return torch.device("mps") + else: + return torch.device("cpu") + + +def get_device_str() -> str: + """ + Get the best available device as a string. + + Returns: + str: Device string ("cuda", "mps", or "cpu") + """ + return str(get_device()) + + +def is_cuda_available() -> bool: + """Check if CUDA is available.""" + return torch.cuda.is_available() + + +def is_mps_available() -> bool: + """Check if MPS (Apple Silicon GPU) is available.""" + return hasattr(torch.backends, "mps") and torch.backends.mps.is_available() + + +def is_gpu_available() -> bool: + """Check if any GPU (CUDA or MPS) is available.""" + return is_cuda_available() or is_mps_available() + + +def to_device( + tensor: torch.Tensor, + device: Optional[Union[str, torch.device]] = None, + non_blocking: bool = False, +) -> torch.Tensor: + """ + Move tensor to the specified device, or to the best available device if not specified. + + Args: + tensor: The tensor to move + device: Target device. If None, uses get_device() + non_blocking: Whether to perform the transfer asynchronously + + Returns: + torch.Tensor: Tensor on the target device + """ + if device is None: + device = get_device() + return tensor.to(device=device, non_blocking=non_blocking) + + +def setup_device_optimizations() -> None: + """ + Setup device-specific optimizations. 
+ + - For CUDA Ampere+ GPUs: Enable TensorFloat-32 + - For MPS: Enable high water mark ratio for memory management + - For CPU: Currently no special optimizations + """ + if torch.cuda.is_available(): + try: + device_props = torch.cuda.get_device_properties(0) + if device_props.major >= 8: + # Enable TF32 for Ampere GPUs (compute capability >= 8.0) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + logger.debug("Enabled TensorFloat-32 for Ampere GPU") + except Exception as e: + logger.debug(f"Could not set up CUDA optimizations: {e}") + elif is_mps_available(): + # MPS optimizations for Apple Silicon + try: + # Set high water mark ratio to allow more GPU memory usage + # This can improve performance by reducing memory pressure + torch.mps.set_per_process_memory_fraction(0.0) # No limit + logger.debug("Using MPS (Apple Silicon GPU) with optimizations") + except Exception as e: + logger.debug(f"MPS optimization setup: {e}") + else: + logger.debug("Using CPU") + + +def mps_synchronize() -> None: + """ + Synchronize MPS operations. + + Call this when you need to ensure all MPS operations are complete, + such as before timing or when switching between GPU and CPU operations. + """ + if is_mps_available(): + torch.mps.synchronize() + + +def empty_cache() -> None: + """ + Empty the GPU cache to free memory. + + Works for both CUDA and MPS backends. + """ + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif is_mps_available(): + torch.mps.empty_cache() + + +def get_device_for_tensor(tensor: torch.Tensor) -> torch.device: + """Get the device of a tensor.""" + return tensor.device + + +def tensor_is_on_gpu(tensor: torch.Tensor) -> bool: + """Check if tensor is on a GPU (CUDA or MPS).""" + device_type = tensor.device.type + return device_type in ("cuda", "mps") + + +def tensor_is_on_cuda(tensor: torch.Tensor) -> bool: + """Check if tensor is specifically on CUDA.""" + return tensor.device.type == "cuda" + + +def tensor_is_on_mps(tensor: torch.Tensor) -> bool: + """Check if tensor is specifically on MPS.""" + return tensor.device.type == "mps" + + +def move_model_to_device( + model: torch.nn.Module, + device: Optional[Union[str, torch.device]] = None, +) -> torch.nn.Module: + """ + Move a model to the specified device. + + Args: + model: The model to move + device: Target device. If None, uses get_device() + + Returns: + The model on the target device + """ + if device is None: + device = get_device() + return model.to(device) diff --git a/tests/test_device_support.py b/tests/test_device_support.py new file mode 100644 index 00000000..f0246de5 --- /dev/null +++ b/tests/test_device_support.py @@ -0,0 +1,324 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved + +""" +Tests for CPU and MPS (Apple Silicon) device support. 
+ +Run with: pytest tests/test_device_support.py -v +""" + +import pytest +import torch + + +class TestDeviceUtilities: + """Test the device utility module.""" + + def test_device_module_imports(self): + """Test that device utilities can be imported.""" + from sam3.utils.device import ( + get_device, + get_device_str, + is_cuda_available, + is_gpu_available, + is_mps_available, + setup_device_optimizations, + tensor_is_on_cuda, + tensor_is_on_gpu, + tensor_is_on_mps, + to_device, + ) + + # All functions should be callable + assert callable(get_device) + assert callable(get_device_str) + assert callable(is_cuda_available) + assert callable(is_mps_available) + assert callable(is_gpu_available) + assert callable(to_device) + assert callable(setup_device_optimizations) + + def test_get_device_returns_valid_device(self): + """Test that get_device returns a valid torch.device.""" + from sam3.utils.device import get_device + + device = get_device() + assert isinstance(device, torch.device) + assert device.type in ("cuda", "mps", "cpu") + + def test_get_device_str_returns_string(self): + """Test that get_device_str returns a string.""" + from sam3.utils.device import get_device_str + + device_str = get_device_str() + assert isinstance(device_str, str) + assert device_str in ("cuda", "mps", "cpu") + + def test_device_detection_consistency(self): + """Test that device detection functions are consistent.""" + from sam3.utils.device import ( + get_device, + is_cuda_available, + is_gpu_available, + is_mps_available, + ) + + device = get_device() + + # If CUDA is available, device should be CUDA + if is_cuda_available(): + assert device.type == "cuda" + assert is_gpu_available() + # If MPS is available and CUDA is not, device should be MPS + elif is_mps_available(): + assert device.type == "mps" + assert is_gpu_available() + # Otherwise, device should be CPU + else: + assert device.type == "cpu" + + def test_to_device_moves_tensor(self): + """Test that to_device correctly moves tensors.""" + from sam3.utils.device import get_device, to_device + + tensor = torch.randn(3, 3) + moved_tensor = to_device(tensor) + + expected_device = get_device() + assert moved_tensor.device.type == expected_device.type + + def test_tensor_device_checks(self): + """Test tensor device check functions.""" + from sam3.utils.device import ( + tensor_is_on_cuda, + tensor_is_on_gpu, + tensor_is_on_mps, + ) + + cpu_tensor = torch.randn(3, 3, device="cpu") + assert not tensor_is_on_cuda(cpu_tensor) + assert not tensor_is_on_mps(cpu_tensor) + assert not tensor_is_on_gpu(cpu_tensor) + + +class TestCPUSupport: + """Test that operations work on CPU.""" + + def test_sigmoid_focal_loss_cpu(self): + """Test sigmoid focal loss works on CPU.""" + from sam3.train.loss.sigmoid_focal_loss import ( + sigmoid_focal_loss, + sigmoid_focal_loss_reduce, + ) + + inputs = torch.randn(10, 5, device="cpu", requires_grad=True) + targets = torch.rand(10, 5, device="cpu") + + # Test unreduced version + loss = sigmoid_focal_loss(inputs, targets) + assert loss.device.type == "cpu" + assert loss.shape == inputs.shape + + # Test reduced version + loss_reduced = sigmoid_focal_loss_reduce(inputs, targets) + assert loss_reduced.device.type == "cpu" + assert loss_reduced.dim() == 0 # scalar + + # Test backward pass + loss_reduced.backward() + assert inputs.grad is not None + assert inputs.grad.shape == inputs.shape + + def test_edt_cpu(self): + """Test EDT (Euclidean Distance Transform) works on CPU.""" + from sam3.model.edt import edt + + # Create a batch of 
binary masks + data = torch.zeros(2, 64, 64, device="cpu") + data[:, 20:40, 20:40] = 1 # Square in the middle + + result = edt(data) + assert result.device.type == "cpu" + assert result.shape == data.shape + # EDT of zeros should be zero + assert (result[data == 0] == 0).all() + + def test_nms_cpu(self): + """Test NMS works on CPU.""" + from sam3.perflib.nms import generic_nms + + n = 10 + # Create a symmetric IoU matrix + ious = torch.rand(n, n, device="cpu") + ious = (ious + ious.T) / 2 # Make symmetric + ious.fill_diagonal_(1.0) # Diagonal should be 1 + + scores = torch.rand(n, device="cpu") + + kept = generic_nms(ious, scores, iou_threshold=0.5) + assert kept.device.type == "cpu" + assert kept.dim() == 1 + assert len(kept) <= n + + def test_connected_components_cpu(self): + """Test connected components works on CPU.""" + from sam3.perflib.connected_components import connected_components + + # Create a batch of binary masks with distinct components + data = torch.zeros(2, 1, 64, 64, device="cpu", dtype=torch.uint8) + data[0, 0, 10:20, 10:20] = 1 # Component 1 + data[0, 0, 40:50, 40:50] = 1 # Component 2 + data[1, 0, 5:15, 5:15] = 1 # Component in second batch + + labels, counts = connected_components(data) + assert labels.device.type == "cpu" + assert counts.device.type == "cpu" + assert labels.shape == data.shape + assert counts.shape == data.shape + + +@pytest.mark.skipif( + not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()), + reason="MPS not available", +) +class TestMPSSupport: + """Test that operations work on MPS (Apple Silicon).""" + + def test_sigmoid_focal_loss_mps(self): + """Test sigmoid focal loss works on MPS.""" + from sam3.train.loss.sigmoid_focal_loss import ( + sigmoid_focal_loss, + sigmoid_focal_loss_reduce, + ) + + inputs = torch.randn(10, 5, device="mps", requires_grad=True) + targets = torch.rand(10, 5, device="mps") + + # Test unreduced version + loss = sigmoid_focal_loss(inputs, targets) + assert loss.device.type == "mps" + assert loss.shape == inputs.shape + + # Test reduced version + loss_reduced = sigmoid_focal_loss_reduce(inputs, targets) + assert loss_reduced.device.type == "mps" + + def test_edt_mps(self): + """Test EDT works on MPS (falls back to CPU internally).""" + from sam3.model.edt import edt + + # Create a batch of binary masks on MPS + data = torch.zeros(2, 64, 64, device="mps") + data[:, 20:40, 20:40] = 1 + + result = edt(data) + # Result should be on MPS (moved back after CPU computation) + assert result.device.type == "mps" + assert result.shape == data.shape + + def test_nms_mps(self): + """Test NMS works on MPS (falls back to CPU internally).""" + from sam3.perflib.nms import generic_nms + + n = 10 + ious = torch.rand(n, n, device="mps") + ious = (ious + ious.T) / 2 + ious.fill_diagonal_(1.0) + scores = torch.rand(n, device="mps") + + kept = generic_nms(ious, scores, iou_threshold=0.5) + # Result should be on MPS + assert kept.device.type == "mps" + + def test_connected_components_mps(self): + """Test connected components works on MPS.""" + from sam3.perflib.connected_components import connected_components + + data = torch.zeros(2, 1, 64, 64, device="mps", dtype=torch.uint8) + data[0, 0, 10:20, 10:20] = 1 + data[0, 0, 40:50, 40:50] = 1 + + labels, counts = connected_components(data) + # Results should be on MPS + assert labels.device.type == "mps" + assert counts.device.type == "mps" + + def test_device_detection_mps(self): + """Test that MPS is detected when available.""" + from sam3.utils.device import get_device, 
is_gpu_available, is_mps_available + + assert is_mps_available() + assert is_gpu_available() + # If CUDA is not available, MPS should be the default + if not torch.cuda.is_available(): + assert get_device().type == "mps" + + +class TestModelBuilderDeviceSupport: + """Test model builder device handling.""" + + def test_device_parameter_accepted(self): + """Test that build functions accept device parameter.""" + from sam3.model_builder import build_sam3_image_model, build_sam3_video_model + import inspect + + # Check that device parameter exists + image_sig = inspect.signature(build_sam3_image_model) + video_sig = inspect.signature(build_sam3_video_model) + + assert "device" in image_sig.parameters + assert "device" in video_sig.parameters + + # Check defaults are None (auto-detect) + assert image_sig.parameters["device"].default is None + assert video_sig.parameters["device"].default is None + + +class TestTransformerDeviceSupport: + """Test transformer module device handling.""" + + def test_rope_attention_cpu(self): + """Test RoPEAttention works on CPU.""" + from sam3.sam.transformer import RoPEAttention + + attention = RoPEAttention( + embedding_dim=256, + num_heads=8, + downsample_rate=1, + feat_sizes=(8, 8), + ) + attention = attention.to("cpu") + + # Create dummy inputs + batch_size = 2 + seq_len = 64 + q = torch.randn(batch_size, seq_len, 256, device="cpu") + k = torch.randn(batch_size, seq_len, 256, device="cpu") + v = torch.randn(batch_size, seq_len, 256, device="cpu") + + output = attention(q, k, v) + assert output.device.type == "cpu" + assert output.shape == (batch_size, seq_len, 256) + + def test_attention_cpu(self): + """Test base Attention works on CPU.""" + from sam3.sam.transformer import Attention + + attention = Attention( + embedding_dim=256, + num_heads=8, + ) + attention = attention.to("cpu") + + batch_size = 2 + seq_len = 64 + q = torch.randn(batch_size, seq_len, 256, device="cpu") + k = torch.randn(batch_size, seq_len, 256, device="cpu") + v = torch.randn(batch_size, seq_len, 256, device="cpu") + + output = attention(q, k, v) + assert output.device.type == "cpu" + assert output.shape == (batch_size, seq_len, 256) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])
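
A minimal usage sketch of the new device-agnostic entry points, assuming the patch is applied and `sam3` is importable from the repository root. It uses only functions introduced above (`get_device` and `setup_device_optimizations` from `sam3.utils.device`, plus the `sigmoid_focal_loss` / `sigmoid_focal_loss_reduce` wrappers), which dispatch to the Triton kernels on CUDA and to the pure-PyTorch fallback on MPS and CPU; it is a sketch, not part of the patch itself.

import torch

from sam3.train.loss.sigmoid_focal_loss import (
    sigmoid_focal_loss,
    sigmoid_focal_loss_reduce,
)
from sam3.utils.device import get_device, setup_device_optimizations

# Apply backend-specific tweaks (TF32 on Ampere CUDA, MPS memory settings),
# then pick the best available device: CUDA > MPS > CPU.
setup_device_optimizations()
device = get_device()
print(f"Running on: {device}")

# The loss wrappers use the Triton kernels only when Triton is importable and
# the inputs live on CUDA; on MPS and CPU they fall back to the pure-PyTorch
# implementation, so the same call works on all three backends.
logits = torch.randn(4, 8, device=device, requires_grad=True)
targets = torch.rand(4, 8, device=device)

per_element = sigmoid_focal_loss(logits, targets, alpha=0.25, gamma=2.0)
total = sigmoid_focal_loss_reduce(logits, targets)
total.backward()

print(per_element.shape, total.item(), logits.grad.shape)

The same dispatch pattern applies to `generic_nms` and `connected_components` in `sam3/perflib`: MPS and CPU tensors are routed to the CPU implementations and the results come back on the caller's device, as exercised by `tests/test_device_support.py`.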