diff --git a/pyproject.toml b/pyproject.toml
index e4998de0..62567c20 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,7 @@ classifiers = [
 ]
 dependencies = [
     "timm>=1.0.17",
-    "numpy==1.26",
+    "numpy>=1.26",
     "tqdm",
     "ftfy==6.1.1",
     "regex",
@@ -59,7 +59,7 @@ notebooks = [
     "ipycanvas",
     "ipympl",
     "pycocotools",
-    "decord",
+    "decord2",
     "opencv-python",
     "einops",
     "scikit-image",
diff --git a/sam3/eval/postprocessors.py b/sam3/eval/postprocessors.py
index 973da118..0eaa7dde 100644
--- a/sam3/eval/postprocessors.py
+++ b/sam3/eval/postprocessors.py
@@ -150,9 +150,13 @@ def _process_masks(self, target_sizes, pred_masks, consistent=True, keep=None):
         if pred_masks is None:
             return None
         if self.always_interpolate_masks_on_gpu:
-            gpu_device = target_sizes.device
-            assert gpu_device.type == "cuda"
-            pred_masks = pred_masks.to(device=gpu_device)
+            device = target_sizes.device
+            if device.type == "cpu":
+                logging.warning(
+                    "always_interpolate_masks_on_gpu=True but data is on CPU; "
+                    "falling back to CPU interpolation"
+                )
+            pred_masks = pred_masks.to(device=device)
         if consistent:
             assert keep is None, "TODO: implement?"
             # All masks should have the same shape, expected when processing a batch of size 1
@@ -454,9 +458,13 @@ def process_results(
         ]  # [P,Q,...] --> [K,...]
         meta_td = meta_td[tracked_obj_ids_idx[PROMPT_AXIS].cpu()]
         if self.always_interpolate_masks_on_gpu:
-            gpu_device = meta_td["original_size"].device
-            assert gpu_device.type == "cuda"
-            tracked_objs_outs_td = tracked_objs_outs_td.to(device=gpu_device)
+            device = meta_td["original_size"].device
+            if device.type == "cpu":
+                logging.warning(
+                    "always_interpolate_masks_on_gpu=True but data is on CPU; "
+                    "falling back to CPU interpolation"
+                )
+            tracked_objs_outs_td = tracked_objs_outs_td.to(device=device)
         frame_results_td = self(
             tracked_objs_outs_td.unsqueeze(1),
             (
diff --git a/sam3/model/decoder.py b/sam3/model/decoder.py
index c8b1657e..6e85cb9e 100644
--- a/sam3/model/decoder.py
+++ b/sam3/model/decoder.py
@@ -277,8 +277,9 @@ def __init__(
         if resolution is not None and stride is not None:
             feat_size = resolution // stride
+            device = "cuda" if torch.cuda.is_available() else "cpu"
             coords_h, coords_w = self._get_coords(
-                feat_size, feat_size, device="cuda"
+                feat_size, feat_size, device=device
             )
             self.compilable_cord_cache = (coords_h, coords_w)
             self.compilable_stored_size = (feat_size, feat_size)
diff --git a/sam3/model/geometry_encoders.py b/sam3/model/geometry_encoders.py
index bff29172..e5242487 100644
--- a/sam3/model/geometry_encoders.py
+++ b/sam3/model/geometry_encoders.py
@@ -10,7 +10,7 @@
 from .act_ckpt_utils import activation_ckpt_wrapper
 from .box_ops import box_cxcywh_to_xyxy
-from .model_misc import get_clones
+from .model_misc import get_clones, tensor_to_device


 def is_right_padded(mask):
@@ -656,7 +656,7 @@ def _encode_boxes(self, boxes, boxes_mask, boxes_labels, img_feats):
         # We need to denormalize, and convert to [x, y, x, y]
         boxes_xyxy = box_cxcywh_to_xyxy(boxes)
         scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype)
-        scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True)
+        scale = tensor_to_device(scale, boxes_xyxy.device)
         scale = scale.view(1, 1, 4)
         boxes_xyxy = boxes_xyxy * scale
         sampled = torchvision.ops.roi_align(
diff --git a/sam3/model/io_utils.py b/sam3/model/io_utils.py
index 0a225842..b6df71e2 100644
--- a/sam3/model/io_utils.py
+++ b/sam3/model/io_utils.py
@@ -26,10 +26,18 @@
 VIDEO_EXTS = [".mp4", ".mov", ".avi", ".mkv", ".webm"]


+def _get_float_dtype(device):
+    """Return appropriate float dtype for the device (float32 for CPU, float16 for GPU)."""
+    if device.type == "cpu":
+        return torch.float32
+    return torch.float16
+
+
 def load_resource_as_video_frames(
     resource_path,
     image_size,
     offload_video_to_cpu,
+    device=None,
     img_mean=(0.5, 0.5, 0.5),
     img_std=(0.5, 0.5, 0.5),
     async_loading_frames=False,
@@ -39,9 +47,12 @@
     Load video frames from either a video or an image (as a single-frame video).
     Alternatively, if input is a list of PIL images, convert its format
     """
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    float_dtype = _get_float_dtype(device)
     if isinstance(resource_path, list):
-        img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None]
-        img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None]
+        img_mean = torch.tensor(img_mean, dtype=float_dtype)[:, None, None]
+        img_std = torch.tensor(img_std, dtype=float_dtype)[:, None, None]
         assert all(isinstance(img_pil, Image.Image) for img_pil in resource_path)
         assert len(resource_path) is not None
         orig_height, orig_width = resource_path[0].size
@@ -55,15 +66,14 @@
             assert img_np.dtype == np.uint8, "np.uint8 is expected for JPEG images"
             img_np = img_np / 255.0
             img = torch.from_numpy(img_np).permute(2, 0, 1)
-            # float16 precision should be sufficient for image tensor storage
-            img = img.to(dtype=torch.float16)
+            img = img.to(dtype=float_dtype)
             # normalize by mean and std
             img -= img_mean
             img /= img_std
             images.append(img)
         images = torch.stack(images)
         if not offload_video_to_cpu:
-            images = images.cuda()
+            images = images.to(device)
         return images, orig_height, orig_width

     is_image = (
@@ -75,6 +85,7 @@
             image_path=resource_path,
             image_size=image_size,
             offload_video_to_cpu=offload_video_to_cpu,
+            device=device,
             img_mean=img_mean,
             img_std=img_std,
         )
@@ -83,6 +94,7 @@
             video_path=resource_path,
             image_size=image_size,
             offload_video_to_cpu=offload_video_to_cpu,
+            device=device,
             img_mean=img_mean,
             img_std=img_std,
             async_loading_frames=async_loading_frames,
@@ -94,19 +106,23 @@ def load_image_as_single_frame_video(
     image_path,
     image_size,
     offload_video_to_cpu,
+    device=None,
     img_mean=(0.5, 0.5, 0.5),
     img_std=(0.5, 0.5, 0.5),
 ):
     """Load an image as a single-frame video."""
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    float_dtype = _get_float_dtype(device)
     images, image_height, image_width = _load_img_as_tensor(image_path, image_size)
-    images = images.unsqueeze(0).half()
+    images = images.unsqueeze(0).to(float_dtype)

-    img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None]
-    img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None]
+    img_mean = torch.tensor(img_mean, dtype=float_dtype)[:, None, None]
+    img_std = torch.tensor(img_std, dtype=float_dtype)[:, None, None]
     if not offload_video_to_cpu:
-        images = images.cuda()
-        img_mean = img_mean.cuda()
-        img_std = img_std.cuda()
+        images = images.to(device)
+        img_mean = img_mean.to(device)
+        img_std = img_std.to(device)
     # normalize by mean and std
     images -= img_mean
     images /= img_std
@@ -117,6 +133,7 @@ def load_video_frames(
     video_path,
     image_size,
     offload_video_to_cpu,
+    device=None,
     img_mean=(0.5, 0.5, 0.5),
     img_std=(0.5, 0.5, 0.5),
     async_loading_frames=False,
@@ -126,17 +143,20 @@
     """
     Load the video frames from video_path.
    The frames are resized to image_size as in the model and are loaded to GPU
    if offload_video_to_cpu=False. This is used by the demo.
    """
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     assert isinstance(video_path, str)
     if video_path.startswith("<dummy"):
         # a dummy video path in the format "<dummy_N>" where N is an integer
         match = re.match(r"<dummy_(\d+)>", video_path)
         num_frames = int(match.group(1)) if match else 60
-        return load_dummy_video(image_size, offload_video_to_cpu, num_frames=num_frames)
+        return load_dummy_video(image_size, offload_video_to_cpu, device, num_frames=num_frames)
     elif os.path.isdir(video_path):
         return load_video_frames_from_image_folder(
             image_folder=video_path,
             image_size=image_size,
             offload_video_to_cpu=offload_video_to_cpu,
+            device=device,
             img_mean=img_mean,
             img_std=img_std,
             async_loading_frames=async_loading_frames,
@@ -146,6 +166,7 @@
             video_path=video_path,
             image_size=image_size,
             offload_video_to_cpu=offload_video_to_cpu,
+            device=device,
             img_mean=img_mean,
             img_std=img_std,
             async_loading_frames=async_loading_frames,
@@ -159,6 +180,7 @@ def load_video_frames_from_image_folder(
     image_folder,
     image_size,
     offload_video_to_cpu,
+    device,
     img_mean,
     img_std,
     async_loading_frames,
@@ -166,6 +188,7 @@
     """
    Load the video frames from a directory of image files ("<frame_index>.<ext>" format)
    """
+    float_dtype = _get_float_dtype(device)
     frame_names = [
         p
         for p in os.listdir(image_folder)
@@ -184,26 +207,25 @@
     if num_frames == 0:
         raise RuntimeError(f"no images found in {image_folder}")
     img_paths = [os.path.join(image_folder, frame_name) for frame_name in frame_names]
-    img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None]
-    img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None]
+    img_mean = torch.tensor(img_mean, dtype=float_dtype)[:, None, None]
+    img_std = torch.tensor(img_std, dtype=float_dtype)[:, None, None]

     if async_loading_frames:
         lazy_images = AsyncImageFrameLoader(
-            img_paths, image_size, offload_video_to_cpu, img_mean, img_std
+            img_paths, image_size, offload_video_to_cpu, device, img_mean, img_std
         )
         return lazy_images, lazy_images.video_height, lazy_images.video_width

-    # float16 precision should be sufficient for image tensor storage
-    images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float16)
+    images = torch.zeros(num_frames, 3, image_size, image_size, dtype=float_dtype)
     video_height, video_width = None, None
     for n, img_path in enumerate(
         tqdm(img_paths, desc=f"frame loading (image folder) [rank={RANK}]")
     ):
         images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
     if not offload_video_to_cpu:
-        images = images.cuda()
-        img_mean = img_mean.cuda()
-        img_std = img_std.cuda()
+        images = images.to(device)
+        img_mean = img_mean.to(device)
+        img_std = img_std.to(device)
     # normalize by mean and std
     images -= img_mean
     images /= img_std
@@ -214,6 +236,7 @@ def load_video_frames_from_video_file(
     video_path,
     image_size,
     offload_video_to_cpu,
+    device,
     img_mean,
     img_std,
     async_loading_frames,
@@ -229,6 +252,7 @@
             img_mean=img_mean,
             img_std=img_std,
             offload_video_to_cpu=offload_video_to_cpu,
+            device=device,
         )
     elif video_loader_type == "torchcodec":
         logger.info("Using torchcodec to load video file")
@@ -240,6 +264,7 @@
             img_std=img_std,
             gpu_acceleration=gpu_acceleration,
             gpu_device=gpu_device,
+            device=device,
         )
         # The `AsyncVideoFileLoaderWithTorchCodec` class always loads the videos asynchronously,
         # so we just wait for its loading thread to finish if async_loading_frames=False.
@@ -258,6 +283,7 @@
     img_mean: tuple = (0.5, 0.5, 0.5),
     img_std: tuple = (0.5, 0.5, 0.5),
     offload_video_to_cpu: bool = False,
+    device: torch.device = None,
 ) -> torch.Tensor:
     """
     Load video from path, convert to normalized tensor with specified preprocessing
@@ -269,10 +295,14 @@
         img_std: Normalization standard deviation (RGB)

     Returns:
-        torch.Tensor: Preprocessed video tensor in shape (T, C, H, W) with float16 dtype
+        torch.Tensor: Preprocessed video tensor in shape (T, C, H, W)
     """
     import cv2  # delay OpenCV import to avoid unnecessary dependency

+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    float_dtype = _get_float_dtype(device)
+
     # Initialize video capture
     cap = cv2.VideoCapture(video_path)
     if not cap.isOpened():
@@ -304,26 +334,27 @@
     frames_np = np.stack(frames, axis=0).astype(np.float32)  # (T, H, W, C)
     video_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2)  # (T, C, H, W)

-    img_mean = torch.tensor(img_mean, dtype=torch.float16).view(1, 3, 1, 1)
-    img_std = torch.tensor(img_std, dtype=torch.float16).view(1, 3, 1, 1)
+    img_mean = torch.tensor(img_mean, dtype=float_dtype).view(1, 3, 1, 1)
+    img_std = torch.tensor(img_std, dtype=float_dtype).view(1, 3, 1, 1)
     if not offload_video_to_cpu:
-        video_tensor = video_tensor.cuda()
-        img_mean = img_mean.cuda()
-        img_std = img_std.cuda()
+        video_tensor = video_tensor.to(device)
+        img_mean = img_mean.to(device)
+        img_std = img_std.to(device)
     # normalize by mean and std
     video_tensor -= img_mean
     video_tensor /= img_std
     return video_tensor, original_height, original_width


-def load_dummy_video(image_size, offload_video_to_cpu, num_frames=60):
+def load_dummy_video(image_size, offload_video_to_cpu, device, num_frames=60):
     """
     Load a dummy video with random frames for testing and compilation warmup purposes.
     """
+    float_dtype = _get_float_dtype(device)
     video_height, video_width = 480, 640  # dummy original video sizes
-    images = torch.randn(num_frames, 3, image_size, image_size, dtype=torch.float16)
+    images = torch.randn(num_frames, 3, image_size, image_size, dtype=float_dtype)
     if not offload_video_to_cpu:
-        images = images.cuda()
+        images = images.to(device)
     return images, video_height, video_width
@@ -341,10 +372,11 @@ class AsyncImageFrameLoader:
     A list of video frames to be load asynchronously without blocking session start.
""" - def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std): + def __init__(self, img_paths, image_size, offload_video_to_cpu, device, img_mean, img_std): self.img_paths = img_paths self.image_size = image_size self.offload_video_to_cpu = offload_video_to_cpu + self.device = device self.img_mean = img_mean self.img_std = img_std # items in `self._images` will be loaded asynchronously @@ -386,13 +418,12 @@ def __getitem__(self, index): ) self.video_height = video_height self.video_width = video_width - # float16 precision should be sufficient for image tensor storage - img = img.to(dtype=torch.float16) + img = img.to(dtype=_get_float_dtype(self.device)) # normalize by mean and std img -= self.img_mean img /= self.img_std if not self.offload_video_to_cpu: - img = img.cuda() + img = img.to(self.device) self.images[index] = img return img @@ -500,6 +531,7 @@ def __init__( img_std, gpu_acceleration=True, gpu_device=None, + device=None, use_rand_seek_in_loading=False, ): # Check and possibly infer the output device (and also get its GPU id when applicable) @@ -507,31 +539,34 @@ def __init__( gpu_id = ( gpu_device.index if gpu_device is not None and gpu_device.index is not None - else torch.cuda.current_device() + else (torch.cuda.current_device() if torch.cuda.is_available() else None) ) - if offload_video_to_cpu: + if device is not None: + out_device = device + elif offload_video_to_cpu: out_device = torch.device("cpu") else: out_device = torch.device("cuda") if gpu_device is None else gpu_device self.out_device = out_device - self.gpu_acceleration = gpu_acceleration + float_dtype = _get_float_dtype(out_device) + self.gpu_acceleration = gpu_acceleration and out_device.type == "cuda" self.gpu_id = gpu_id self.image_size = image_size self.offload_video_to_cpu = offload_video_to_cpu if not isinstance(img_mean, torch.Tensor): - img_mean = torch.tensor(img_mean, dtype=torch.float16)[:, None, None] + img_mean = torch.tensor(img_mean, dtype=float_dtype)[:, None, None] self.img_mean = img_mean if not isinstance(img_std, torch.Tensor): - img_std = torch.tensor(img_std, dtype=torch.float16)[:, None, None] + img_std = torch.tensor(img_std, dtype=float_dtype)[:, None, None] self.img_std = img_std - if gpu_acceleration: + if self.gpu_acceleration: self.img_mean = self.img_mean.to(f"cuda:{self.gpu_id}") self.img_std = self.img_std.to(f"cuda:{self.gpu_id}") decoder_option = {"device": f"cuda:{self.gpu_id}"} else: - self.img_mean = self.img_mean.cpu() - self.img_std = self.img_std.cpu() + self.img_mean = self.img_mean.to(out_device) + self.img_std = self.img_std.to(out_device) decoder_option = {"num_threads": 1} # use a single thread to save memory self.rank = int(os.environ.get("RANK", "0")) @@ -551,7 +586,7 @@ def __init__( 3, self.image_size, self.image_size, - dtype=torch.float16, + dtype=float_dtype, device=self.out_device, ) # catch and raise any exceptions in the async loading thread @@ -652,15 +687,14 @@ def _transform_frame(self, frame): mode="bicubic", align_corners=False, )[0] - # float16 precision should be sufficient for image tensor storage - frame_resized = frame_resized.half() # uint8 -> float16 + frame_resized = frame_resized.to(dtype=_get_float_dtype(self.out_device)) frame_resized /= 255 frame_resized -= self.img_mean frame_resized /= self.img_std if self.offload_video_to_cpu: frame_resized = frame_resized.cpu() elif frame_resized.device != self.out_device: - frame_resized = frame_resized.to(device=self.out_device, non_blocking=True) + frame_resized = 
+            frame_resized = frame_resized.to(device=self.out_device, non_blocking=torch.cuda.is_available())

         return frame_resized

     def __getitem__(self, index):
diff --git a/sam3/model/model_misc.py b/sam3/model/model_misc.py
index 2cb44b3f..879e21d7 100644
--- a/sam3/model/model_misc.py
+++ b/sam3/model/model_misc.py
@@ -28,6 +28,26 @@
     return torch.log(x1 / x2)


+def tensor_to_device(tensor: Tensor, device: torch.device) -> Tensor:
+    """
+    Transfer a tensor to the target device efficiently.
+
+    For CUDA devices, uses pin_memory() for faster CPU→GPU transfers.
+    For other devices (MPS, CPU), uses direct .to() transfer.
+
+    Args:
+        tensor: The tensor to transfer (should be on CPU)
+        device: The target device
+
+    Returns:
+        The tensor on the target device
+    """
+    if device.type == "cuda":
+        return tensor.pin_memory().to(device=device, non_blocking=True)
+    else:
+        return tensor.to(device=device)
+
+
 class MultiheadAttentionWrapper(nn.MultiheadAttention):
     def forward(self, *args, **kwargs):
         kwargs["need_weights"] = False
diff --git a/sam3/model/position_encoding.py b/sam3/model/position_encoding.py
index eb3f4055..270327ea 100644
--- a/sam3/model/position_encoding.py
+++ b/sam3/model/position_encoding.py
@@ -43,8 +43,9 @@ def __init__(
             (precompute_resolution // 16, precompute_resolution // 16),
             (precompute_resolution // 32, precompute_resolution // 32),
         ]
+        device = "cuda" if torch.cuda.is_available() else "cpu"
         for size in precompute_sizes:
-            tensors = torch.zeros((1, 1) + size, device="cuda")
+            tensors = torch.zeros((1, 1) + size, device=device)
             self.forward(tensors)
             # further clone and detach it in the cache (just to be safe)
             self.cache[size] = self.cache[size].clone().detach()
diff --git a/sam3/model/sam3_image.py b/sam3/model/sam3_image.py
index aafe520b..8e1a8c6b 100644
--- a/sam3/model/sam3_image.py
+++ b/sam3/model/sam3_image.py
@@ -836,7 +836,10 @@ def _build_multigpu_buffer_next_chunk(
         assert len(feats["backbone_fpn"]) == 3  # SAM2 backbone always have 3 levels
         # cast the SAM2 backbone features to bfloat16 for all-gather (this is usually
         # a no-op, SAM2 backbone features are likely already in bfloat16 due to AMP)
-        backbone_fpn_bf16 = [x.to(torch.bfloat16) for x in feats["backbone_fpn"]]
+        if torch.cuda.is_available():
+            backbone_fpn_bf16 = [x.to(torch.bfloat16) for x in feats["backbone_fpn"]]
+        else:
+            backbone_fpn_bf16 = list(feats["backbone_fpn"])
         fpn0, fpn_handle0 = self._gather_tensor(backbone_fpn_bf16[0])
         fpn1, fpn_handle1 = self._gather_tensor(backbone_fpn_bf16[1])
         fpn2, fpn_handle2 = self._gather_tensor(backbone_fpn_bf16[2])
diff --git a/sam3/model/sam3_image_processor.py b/sam3/model/sam3_image_processor.py
index 4d98fbfb..da42569b 100644
--- a/sam3/model/sam3_image_processor.py
+++ b/sam3/model/sam3_image_processor.py
@@ -14,9 +14,11 @@ class Sam3Processor:
     """ """

-    def __init__(self, model, resolution=1008, device="cuda", confidence_threshold=0.5):
+    def __init__(self, model, resolution=1008, device=None, confidence_threshold=0.5):
         self.model = model
         self.resolution = resolution
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
         self.device = device
         self.transform = v2.Compose(
             [
diff --git a/sam3/model/sam3_tracker_base.py b/sam3/model/sam3_tracker_base.py
index 90fbd696..98a75794 100644
--- a/sam3/model/sam3_tracker_base.py
+++ b/sam3/model/sam3_tracker_base.py
@@ -6,7 +6,7 @@
 import torch.nn.functional as F

 from sam3.model.memory import SimpleMaskEncoder
-
+from sam3.model.model_misc import tensor_to_device
 from sam3.model.sam3_tracker_utils import get_1d_sine_pe, select_closest_cond_frames
 from sam3.sam.mask_decoder import MaskDecoder, MLP
@@ -165,7 +165,7 @@ def _get_tpos_enc(self, rel_pos_list, device, max_abs_pos=None, dummy=False):
         t_diff_max = max_abs_pos - 1 if max_abs_pos is not None else 1
         pos_enc = (
-            torch.tensor(rel_pos_list).pin_memory().to(device=device, non_blocking=True)
+            tensor_to_device(torch.tensor(rel_pos_list), device)
             / t_diff_max
         )
         tpos_dim = self.hidden_dim
@@ -653,15 +653,15 @@ def _prepare_memory_conditioned_features(
                 if prev is None:
                     continue  # skip padding frames
                 # "maskmem_features" might have been offloaded to CPU in demo use cases,
-                # so we load it back to GPU (it's a no-op if it's already on GPU).
-                feats = prev["maskmem_features"].cuda(non_blocking=True)
+                # so we load it back to device (it's a no-op if it's already on device).
+                feats = prev["maskmem_features"].to(device, non_blocking=torch.cuda.is_available())
                 seq_len = feats.shape[-2] * feats.shape[-1]
                 to_cat_prompt.append(feats.flatten(2).permute(2, 0, 1))
                 to_cat_prompt_mask.append(
                     torch.zeros(B, seq_len, device=device, dtype=bool)
                 )
                 # Spatial positional encoding (it might have been offloaded to CPU in eval)
-                maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
+                maskmem_enc = prev["maskmem_pos_enc"][-1].to(device)
                 maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)

                 if (
diff --git a/sam3/model/sam3_tracker_utils.py b/sam3/model/sam3_tracker_utils.py
index 7afc70aa..e2173408 100644
--- a/sam3/model/sam3_tracker_utils.py
+++ b/sam3/model/sam3_tracker_utils.py
@@ -5,7 +5,14 @@
 import torch.nn.functional as F
 from numpy.typing import NDArray

-from sam3.model.edt import edt_triton
+# Triton is only available on CUDA (not Apple Silicon/MPS)
+try:
+    from sam3.model.edt import edt_triton
+
+    _HAS_TRITON = True
+except ImportError:
+    _HAS_TRITON = False
+    edt_triton = None


 def sample_box_points(
@@ -148,6 +155,10 @@ def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
     - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
     - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks
     """
+    # Fall back to slow (OpenCV-based) implementation if Triton is not available
+    if not _HAS_TRITON or not gt_masks.is_cuda:
+        return sample_one_point_from_error_center_slow(gt_masks, pred_masks, padding)
+
     if pred_masks is None:
         pred_masks = torch.zeros_like(gt_masks)
     assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
diff --git a/sam3/model/sam3_tracking_predictor.py b/sam3/model/sam3_tracking_predictor.py
index b2440ef6..ae4584f3 100644
--- a/sam3/model/sam3_tracking_predictor.py
+++ b/sam3/model/sam3_tracking_predictor.py
@@ -46,8 +46,11 @@ def __init__(
         self.max_point_num_in_prompt_enc = max_point_num_in_prompt_enc
         self.non_overlap_masks_for_output = non_overlap_masks_for_output

-        self.bf16_context = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
-        self.bf16_context.__enter__()  # keep using for the entire model process
+        if torch.cuda.is_available():
+            self.bf16_context = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+            self.bf16_context.__enter__()  # keep using for the entire model process
+        else:
+            self.bf16_context = None

         self.iter_use_prev_mask_pred = True
         self.add_all_frames_to_correct_as_cond = True
@@ -75,7 +78,7 @@ def init_state(
         # and from 24 to 21 when tracking two objects)
         inference_state["offload_state_to_cpu"] = offload_state_to_cpu
         inference_state["device"] = self.device
-        if offload_state_to_cpu:
+        if offload_state_to_cpu or not torch.cuda.is_available():
             inference_state["storage_device"] = torch.device("cpu")
         else:
             inference_state["storage_device"] = torch.device("cuda")
@@ -300,7 +303,7 @@
         prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)

         if prev_out is not None and prev_out["pred_masks"] is not None:
-            prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
+            prev_sam_mask_logits = prev_out["pred_masks"].to(inference_state["device"], non_blocking=torch.cuda.is_available())
             # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
             prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
         current_out, _ = self._run_single_frame_inference(
@@ -469,7 +472,7 @@ def _get_orig_video_res_output(self, inference_state, any_res_masks):
         device = inference_state["device"]
         video_H = inference_state["video_height"]
         video_W = inference_state["video_width"]
-        any_res_masks = any_res_masks.to(device, non_blocking=True)
+        any_res_masks = any_res_masks.to(device, non_blocking=torch.cuda.is_available())
         if any_res_masks.shape[-2:] == (video_H, video_W):
             video_res_masks = any_res_masks
         else:
@@ -609,7 +612,7 @@ def _consolidate_temp_output_across_obj(
         if run_mem_encoder:
             device = inference_state["device"]
             high_res_masks = torch.nn.functional.interpolate(
-                consolidated_out["pred_masks"].to(device, non_blocking=True),
+                consolidated_out["pred_masks"].to(device, non_blocking=torch.cuda.is_available()),
                 size=(self.image_size, self.image_size),
                 mode="bilinear",
                 align_corners=False,
@@ -1021,7 +1024,7 @@ def _get_image_feature(self, inference_state, frame_idx, batch_size):
             )
         else:
             # Cache miss -- we will run inference on a single image
-            image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
+            image = inference_state["images"][frame_idx].to(inference_state["device"]).float().unsqueeze(0)
             backbone_out = self.forward_image(image)
             # Cache the most recent frame's feature (for repeated interactions with
             # a frame; we can use an LRU cache for more frames in the future).
@@ -1093,10 +1096,11 @@ def _run_single_frame_inference(
         storage_device = inference_state["storage_device"]
         maskmem_features = current_out["maskmem_features"]
         if maskmem_features is not None:
-            maskmem_features = maskmem_features.to(torch.bfloat16)
-            maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
+            if torch.cuda.is_available():
+                maskmem_features = maskmem_features.to(torch.bfloat16)
+            maskmem_features = maskmem_features.to(storage_device, non_blocking=torch.cuda.is_available())
         pred_masks_gpu = current_out["pred_masks"]
-        pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
+        pred_masks = pred_masks_gpu.to(storage_device, non_blocking=torch.cuda.is_available())
         # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
         maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
         # object pointer is a small tensor, so we always keep it on GPU memory for fast access
@@ -1144,8 +1148,9 @@
         # optionally offload the output to CPU memory to save GPU space
         storage_device = inference_state["storage_device"]
-        maskmem_features = maskmem_features.to(torch.bfloat16)
-        maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
+        if torch.cuda.is_available():
+            maskmem_features = maskmem_features.to(torch.bfloat16)
+        maskmem_features = maskmem_features.to(storage_device, non_blocking=torch.cuda.is_available())
         # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
         maskmem_pos_enc = self._get_maskmem_pos_enc(
             inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
diff --git a/sam3/model/sam3_video_inference.py b/sam3/model/sam3_video_inference.py
index 7fb87d01..3911a510 100644
--- a/sam3/model/sam3_video_inference.py
+++ b/sam3/model/sam3_video_inference.py
@@ -15,6 +15,7 @@
 from sam3.model.data_misc import BatchedDatapoint, convert_my_tensors, FindStage
 from sam3.model.geometry_encoders import Prompt
 from sam3.model.io_utils import IMAGE_EXTS, load_resource_as_video_frames
+from sam3.model.model_misc import tensor_to_device
 from sam3.model.sam3_tracker_utils import fill_holes_in_mask_scores
 from sam3.model.sam3_video_base import MaskletConfirmationStatus, Sam3VideoBase
 from sam3.model.utils.misc import copy_data_to_device
@@ -59,10 +60,14 @@ def init_state(
         video_loader_type="cv2",
     ):
         """Initialize an inference state from `resource_path` (an image or a video)."""
+        # Get actual current device from model parameters
+        device = next(self.parameters()).device
+
         images, orig_height, orig_width = load_resource_as_video_frames(
             resource_path=resource_path,
             image_size=self.image_size,
             offload_video_to_cpu=offload_video_to_cpu,
+            device=device,
             img_mean=self.image_mean,
             img_std=self.image_std,
             async_loading_frames=async_loading_frames,
@@ -147,7 +152,7 @@ def _construct_initial_input_batch(self, inference_state, images):
             find_targets=[None] * num_frames,
             find_metadatas=[None] * num_frames,
         )
-        input_batch = copy_data_to_device(input_batch, device, non_blocking=True)
+        input_batch = copy_data_to_device(input_batch, device, non_blocking=torch.cuda.is_available())
         inference_state["input_batch"] = input_batch

         # construct the placeholder interactive prompts and tracking queries
@@ -477,9 +482,7 @@ def _postprocess_output(
         # slice those valid entries from the original outputs
         keep_idx = torch.nonzero(keep, as_tuple=True)[0]
-        keep_idx_gpu = keep_idx.pin_memory().to(
-            device=out_binary_masks.device, non_blocking=True
-        )
+        keep_idx_gpu = tensor_to_device(keep_idx, out_binary_masks.device)
         out_obj_ids = torch.index_select(out_obj_ids, 0, keep_idx)
         out_probs = torch.index_select(out_probs, 0, keep_idx)
diff --git a/sam3/model/sam3_video_predictor.py b/sam3/model/sam3_video_predictor.py
index c639e1d0..3cdf14e5 100644
--- a/sam3/model/sam3_video_predictor.py
+++ b/sam3/model/sam3_video_predictor.py
@@ -39,6 +39,14 @@ def __init__(
         self.video_loader_type = video_loader_type

         from sam3.model_builder import build_sam3_video_model

+        # Determine device
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        else:
+            self.device = torch.device("cpu")
+
+        logger.info(f"Sam3VideoPredictor using device: {self.device}")
+
         self.model = (
             build_sam3_video_model(
                 checkpoint_path=checkpoint_path,
@@ -48,7 +56,7 @@
                 strict_state_dict_loading=strict_state_dict_loading,
                 apply_temporal_disambiguation=apply_temporal_disambiguation,
             )
-            .cuda()
+            .to(self.device)
             .eval()
         )
@@ -265,21 +273,27 @@ def _get_session_stats(self):
             f"'{session_id}' ({session['state']['num_frames']} frames)"
             for session_id, session in self._ALL_INFERENCE_STATES.items()
         ]
-        session_stats_str = (
-            f"live sessions: [{', '.join(live_session_strs)}], GPU memory: "
-            f"{torch.cuda.memory_allocated() // 1024**2} MiB used and "
-            f"{torch.cuda.memory_reserved() // 1024**2} MiB reserved"
-            f" (max over time: {torch.cuda.max_memory_allocated() // 1024**2} MiB used "
-            f"and {torch.cuda.max_memory_reserved() // 1024**2} MiB reserved)"
-        )
+        if torch.cuda.is_available():
+            mem_stats = (
+                f"GPU memory: {torch.cuda.memory_allocated() // 1024**2} MiB used and "
+                f"{torch.cuda.memory_reserved() // 1024**2} MiB reserved"
+                f" (max over time: {torch.cuda.max_memory_allocated() // 1024**2} MiB used "
+                f"and {torch.cuda.max_memory_reserved() // 1024**2} MiB reserved)"
+            )
+        else:
+            mem_stats = "Running on CPU"
+        session_stats_str = f"live sessions: [{', '.join(live_session_strs)}], {mem_stats}"
         return session_stats_str

     def _get_torch_and_gpu_properties(self):
         """Get a string for PyTorch and GPU properties (for logging and debugging)."""
-        torch_and_gpu_str = (
-            f"torch: {torch.__version__} with CUDA arch {torch.cuda.get_arch_list()}, "
-            f"GPU device: {torch.cuda.get_device_properties(torch.cuda.current_device())}"
-        )
+        if torch.cuda.is_available():
+            torch_and_gpu_str = (
+                f"torch: {torch.__version__} with CUDA arch {torch.cuda.get_arch_list()}, "
+                f"GPU device: {torch.cuda.get_device_properties(torch.cuda.current_device())}"
+            )
+        else:
+            torch_and_gpu_str = f"torch: {torch.__version__} (CPU mode)"
         return torch_and_gpu_str

     def shutdown(self):
diff --git a/sam3/model/utils/sam2_utils.py b/sam3/model/utils/sam2_utils.py
index d91ba0f1..4e4b39e1 100644
--- a/sam3/model/utils/sam2_utils.py
+++ b/sam3/model/utils/sam2_utils.py
@@ -85,7 +85,7 @@ def __getitem__(self, index):
         img -= self.img_mean
         img /= self.img_std
         if not self.offload_video_to_cpu:
-            img = img.to(self.compute_device, non_blocking=True)
+            img = img.to(self.compute_device, non_blocking=torch.cuda.is_available())
         self.images[index] = img
         return img
diff --git a/sam3/model/vl_combiner.py b/sam3/model/vl_combiner.py
index 43bc7bd5..43a5be42 100644
--- a/sam3/model/vl_combiner.py
+++ b/sam3/model/vl_combiner.py
@@ -119,8 +119,10 @@ def _forward_image_no_act_ckpt(self, samples):
         return output

     def forward_text(
-        self, captions, input_boxes=None, additional_text=None, device="cuda"
+        self, captions, input_boxes=None, additional_text=None, device=None
     ):
+        if device is None:
"cpu" return activation_ckpt_wrapper(self._forward_text_no_ack_ckpt)( captions=captions, input_boxes=input_boxes, @@ -134,8 +136,10 @@ def _forward_text_no_ack_ckpt( captions, input_boxes=None, additional_text=None, - device="cuda", + device=None, ): + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" output = {} # Forward through text_encoder diff --git a/sam3/model_builder.py b/sam3/model_builder.py index 058bbec3..8e4a4f46 100644 --- a/sam3/model_builder.py +++ b/sam3/model_builder.py @@ -34,7 +34,7 @@ from sam3.model.sam3_image import Sam3Image, Sam3ImageOnVideoMultiGPU from sam3.model.sam3_tracking_predictor import Sam3TrackerPredictor from sam3.model.sam3_video_inference import Sam3VideoInferenceWithInstanceInteractivity -from sam3.model.sam3_video_predictor import Sam3VideoPredictorMultiGPU +from sam3.model.sam3_video_predictor import Sam3VideoPredictor, Sam3VideoPredictorMultiGPU from sam3.model.text_encoder_ve import VETextEncoder from sam3.model.tokenizer_ve import SimpleTokenizer from sam3.model.vitdet import ViT @@ -547,8 +547,10 @@ def _load_checkpoint(model, checkpoint_path): def _setup_device_and_mode(model, device, eval_mode): """Setup model device and evaluation mode.""" - if device == "cuda": + if device == "cuda" and torch.cuda.is_available(): model = model.cuda() + elif device != "cpu": + model = model.to(device) if eval_mode: model.eval() return model @@ -788,6 +790,9 @@ def build_sam3_video_model( def build_sam3_video_predictor(*model_args, gpus_to_use=None, **model_kwargs): + # Use single-device predictor on CPU, multi-GPU predictor only when CUDA is available + if not torch.cuda.is_available(): + return Sam3VideoPredictor(*model_args, **model_kwargs) return Sam3VideoPredictorMultiGPU( *model_args, gpus_to_use=gpus_to_use, **model_kwargs ) diff --git a/sam3/perflib/connected_components.py b/sam3/perflib/connected_components.py index c96932a4..8a786062 100644 --- a/sam3/perflib/connected_components.py +++ b/sam3/perflib/connected_components.py @@ -39,6 +39,13 @@ def connected_components_cpu(input_tensor: torch.Tensor): ), "Input tensor must be (B, H, W) or (B, 1, H, W)." batch_size = input_tensor.shape[0] + # Handle empty batch case + if batch_size == 0: + return ( + torch.zeros(out_shape, dtype=torch.int64, device=input_tensor.device), + torch.zeros(out_shape, dtype=torch.int64, device=input_tensor.device), + ) + labels_list = [] counts_list = [] for b in range(batch_size):