def load_video_frames_from_memory(
    imgs,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.485, 0.456, 0.406),
    img_std=(0.229, 0.224, 0.225),
):
    """
    Build a normalized frame tensor from video frames that are already in memory.

    Each entry of `imgs` is passed to `_load_img_np_as_tensor`, and the result is
    written into one slot of a pre-allocated float32 tensor of shape
    `(len(imgs), 3, image_size, image_size)`. The stacked frames are then
    normalized in place as `(frame - img_mean) / img_std`.

    Args:
        imgs: a sized, iterable collection of frames; must contain at least one
            frame. (Presumably numpy images, judging by the helper's name --
            confirm against `_load_img_np_as_tensor`.)
        image_size: side length of the square frames stored in the output tensor.
        offload_video_to_cpu: if falsy, the frames and the normalization
            constants are moved to the GPU with `.cuda()` before normalizing;
            otherwise everything stays on the CPU.
        img_mean: three values subtracted from the frames, one per index of
            dimension 1.
        img_std: three values the frames are divided by, one per index of
            dimension 1.

    Returns:
        A tuple `(images, video_height, video_width)`, where `images` is the
        normalized frame tensor and the height/width are the values reported by
        `_load_img_np_as_tensor` for the last frame in `imgs`.

    Raises:
        ValueError: if `imgs` is empty.
    """
    # `len(...)` rather than truthiness: `imgs` may be an array-like whose
    # truth value is ambiguous.
    num_frames = len(imgs)
    if num_frames == 0:
        # Without this guard the loop below never runs and the final `return`
        # fails with an opaque UnboundLocalError on `video_height`.
        raise ValueError(
            "load_video_frames_from_memory() requires at least one frame, "
            "but `imgs` is empty"
        )

    # Shape (3, 1, 1) so the constants broadcast over each (3, H, W) frame.
    img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
    img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]

    images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
    for n, img in enumerate(imgs):
        # The helper returns (frame_tensor, height, width); height/width are
        # overwritten on every iteration, so the last frame's values win.
        images[n], video_height, video_width = _load_img_np_as_tensor(img, image_size)

    if not offload_video_to_cpu:
        images = images.cuda()
        img_mean = img_mean.cuda()
        img_std = img_std.cuda()

    # Normalize in place to avoid allocating a second full-size video tensor.
    images -= img_mean
    images /= img_std
    return images, video_height, video_width