Subject: [PATCH] refactor: deduplicate 'video_mask_tuple' clip loading into a shared helper

This patch extracts the near-identical inline loading / frame-sampling /
resizing logic for 'video_mask_tuple' samples out of
dataset_image_video.py and dataset_image_video_warped.py into a new
shared module videox_fun/utils/video_tuple_loader.py
(load_video_mask_tuple_clip). The mask quantization step
(binary / trimask / quadmask remapping) stays in the two datasets and
now runs on the already sampled-and-resized clip arrays instead of
per-frame inside the removed resize loop.

NOTE(review): the new helper names its resize argument
`target_short_side`, and _resize_frame scales the SHORTER image side to
that value — but both call sites pass
self.larger_side_of_image_and_video. Confirm this matches the behavior
of the resize_frame helper whose call sites the patch removes.

NOTE(review): the helper switches mp4 decoding from media.read_video to
decord.VideoReader guarded by func_timeout, and directory samples from
_read_video_from_dir to sorted "*.png" globs read via media.read_image —
presumably equivalent frame content; verify channel layout (the datasets
still expect a 3-channel mask they slice to one channel) and that the
frame-dir format really is PNG-only.

NOTE(review): this patch text appears to have had hard line breaks
collapsed during transport; run `git apply --check` against a pristine
copy before merging.

diff --git a/videox_fun/data/dataset_image_video.py b/videox_fun/data/dataset_image_video.py index ca3f8cc..e355f66 100644 --- a/videox_fun/data/dataset_image_video.py +++ b/videox_fun/data/dataset_image_video.py @@ -23,6 +23,7 @@ from torch.utils.data import BatchSampler, Sampler from torch.utils.data.dataset import Dataset from contextlib import contextmanager +from videox_fun.utils.video_tuple_loader import load_video_mask_tuple_clip VIDEO_READER_TIMEOUT = 20 @@ -595,32 +596,14 @@ def get_batch(self, idx): elif data_info.get('type', 'image') == 'video_mask_tuple': # object effect removal sample_dir = data_info['file_path'] try: - if os.path.exists(os.path.join(sample_dir, 'rgb_full.mp4')): - input_video_path = os.path.join(sample_dir, 'rgb_full.mp4') - target_video_path = os.path.join(sample_dir, 'rgb_removed.mp4') - mask_video_path = os.path.join(sample_dir, 'mask.mp4') - depth_video_path = os.path.join(sample_dir, 'depth_removed.mp4') - - input_video = media.read_video(input_video_path) - target_video = media.read_video(target_video_path) - mask_video = media.read_video(mask_video_path) - - # Load depth map if it exists - depth_video = None - if os.path.exists(depth_video_path): - depth_video = media.read_video(depth_video_path) - - else: - input_video_path = os.path.join(sample_dir, 'input') - target_video_path = os.path.join(sample_dir, 'bg') - mask_video_path = os.path.join(sample_dir, 'trimask') - - input_video = _read_video_from_dir(input_video_path) - target_video = _read_video_from_dir(target_video_path) - mask_video = _read_video_from_dir(mask_video_path) - - # Initialize depth_video as None for this path - depth_video = None + input_video, target_video, mask_video, depth_video = load_video_mask_tuple_clip( + sample_dir=sample_dir, + video_sample_n_frames=self.video_sample_n_frames, + video_length_drop_start=self.video_length_drop_start, + video_length_drop_end=self.video_length_drop_end, + video_sample_stride=self.video_sample_stride, + 
target_short_side=self.larger_side_of_image_and_video, + ) except Exception as e: print(f"Error loading video_mask_tuple from {sample_dir}: {e}") import traceback @@ -633,62 +616,26 @@ def get_batch(self, idx): mask_video = mask_video[..., None] if mask_video.shape[-1] == 3: mask_video = mask_video[..., :1] - min_sample_n_frames = min( - self.video_sample_n_frames, - int(len(input_video) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride) - ) - video_length = int(self.video_length_drop_end * len(input_video)) - clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1) - start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0 - batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int) - input_video = input_video[batch_index] - target_video = target_video[batch_index] - mask_video = mask_video[batch_index] - if depth_video is not None: - depth_video = depth_video[batch_index] - - resized_inputs = [] - resized_targets = [] - resized_masks = [] - resized_depths = [] - for i in range(len(input_video)): - resized_input = resize_frame(input_video[i], self.larger_side_of_image_and_video) - resized_target = resize_frame(target_video[i], self.larger_side_of_image_and_video) - resized_mask = resize_frame(mask_video[i], self.larger_side_of_image_and_video) - - # Apply mask quantization based on mode - if self.ablation_binary_mask: - # Ablation binary mask mode: remap [0, 63, 127, 255] to [0, 127] - # Map 0 and 63 → 0 - # Map 127 and 255 → 127 - resized_mask = np.where(resized_mask <= 95, 0, resized_mask) - resized_mask = np.where(resized_mask > 95, 127, resized_mask) - elif self.use_quadmask: - # Quadmask mode: preserve 4 values [0, 63, 127, 255] - # Quantize to nearest quadmask value for robustness - resized_mask = np.where(resized_mask <= 31, 0, resized_mask) - resized_mask = 
np.where(np.logical_and(resized_mask > 31, resized_mask <= 95), 63, resized_mask) - resized_mask = np.where(np.logical_and(resized_mask > 95, resized_mask <= 191), 127, resized_mask) - resized_mask = np.where(resized_mask > 191, 255, resized_mask) - else: - # Trimask mode: 3 values [0, 127, 255] - resized_mask = np.where(np.logical_and(resized_mask > 63, resized_mask < 192), 127, resized_mask) - resized_mask = np.where(resized_mask >= 192, 255, resized_mask) - resized_mask = np.where(resized_mask <= 63, 0, resized_mask) - resized_inputs.append(resized_input) - resized_targets.append(resized_target) - resized_masks.append(resized_mask) - - if depth_video is not None: - resized_depth = resize_frame(depth_video[i], self.larger_side_of_image_and_video) - resized_depths.append(resized_depth) - - input_video = np.array(resized_inputs) - target_video = np.array(resized_targets) - mask_video = np.array(resized_masks) - if depth_video is not None: - depth_video = np.array(resized_depths) + # Apply mask quantization after resize to preserve existing trimask / quadmask behavior. 
+ if self.ablation_binary_mask: + # Ablation binary mask mode: remap [0, 63, 127, 255] to [0, 127] + # Map 0 and 63 → 0 + # Map 127 and 255 → 127 + mask_video = np.where(mask_video <= 95, 0, mask_video) + mask_video = np.where(mask_video > 95, 127, mask_video) + elif self.use_quadmask: + # Quadmask mode: preserve 4 values [0, 63, 127, 255] + # Quantize to nearest quadmask value for robustness + mask_video = np.where(mask_video <= 31, 0, mask_video) + mask_video = np.where(np.logical_and(mask_video > 31, mask_video <= 95), 63, mask_video) + mask_video = np.where(np.logical_and(mask_video > 95, mask_video <= 191), 127, mask_video) + mask_video = np.where(mask_video > 191, 255, mask_video) + else: + # Trimask mode: 3 values [0, 127, 255] + mask_video = np.where(np.logical_and(mask_video > 63, mask_video < 192), 127, mask_video) + mask_video = np.where(mask_video >= 192, 255, mask_video) + mask_video = np.where(mask_video <= 63, 0, mask_video) if len(mask_video.shape) == 3: mask_video = mask_video[..., None] diff --git a/videox_fun/data/dataset_image_video_warped.py b/videox_fun/data/dataset_image_video_warped.py index 0928dae..e93b8be 100644 --- a/videox_fun/data/dataset_image_video_warped.py +++ b/videox_fun/data/dataset_image_video_warped.py @@ -23,6 +23,7 @@ from torch.utils.data import BatchSampler, Sampler from torch.utils.data.dataset import Dataset from contextlib import contextmanager +from videox_fun.utils.video_tuple_loader import load_video_mask_tuple_clip VIDEO_READER_TIMEOUT = 20 @@ -595,32 +596,14 @@ def get_batch(self, idx): elif data_info.get('type', 'image') == 'video_mask_tuple': # object effect removal sample_dir = data_info['file_path'] if self.data_root is None else os.path.join(self.data_root, data_info['file_path']) try: - if os.path.exists(os.path.join(sample_dir, 'rgb_full.mp4')): - input_video_path = os.path.join(sample_dir, 'rgb_full.mp4') - target_video_path = os.path.join(sample_dir, 'rgb_removed.mp4') - mask_video_path = 
os.path.join(sample_dir, 'mask.mp4') - depth_video_path = os.path.join(sample_dir, 'depth_removed.mp4') - - input_video = media.read_video(input_video_path) - target_video = media.read_video(target_video_path) - mask_video = media.read_video(mask_video_path) - - # Load depth map if it exists - depth_video = None - if os.path.exists(depth_video_path): - depth_video = media.read_video(depth_video_path) - - else: - input_video_path = os.path.join(sample_dir, 'input') - target_video_path = os.path.join(sample_dir, 'bg') - mask_video_path = os.path.join(sample_dir, 'trimask') - - input_video = _read_video_from_dir(input_video_path) - target_video = _read_video_from_dir(target_video_path) - mask_video = _read_video_from_dir(mask_video_path) - - # Initialize depth_video as None for this path - depth_video = None + input_video, target_video, mask_video, depth_video = load_video_mask_tuple_clip( + sample_dir=sample_dir, + video_sample_n_frames=self.video_sample_n_frames, + video_length_drop_start=self.video_length_drop_start, + video_length_drop_end=self.video_length_drop_end, + video_sample_stride=self.video_sample_stride, + target_short_side=self.larger_side_of_image_and_video, + ) except Exception as e: print(f"Error loading video_mask_tuple from {sample_dir}: {e}") import traceback @@ -633,62 +616,26 @@ def get_batch(self, idx): mask_video = mask_video[..., None] if mask_video.shape[-1] == 3: mask_video = mask_video[..., :1] - min_sample_n_frames = min( - self.video_sample_n_frames, - int(len(input_video) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride) - ) - video_length = int(self.video_length_drop_end * len(input_video)) - clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1) - start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0 - batch_index = np.linspace(start_idx, start_idx + clip_length - 1, 
min_sample_n_frames, dtype=int) - input_video = input_video[batch_index] - target_video = target_video[batch_index] - mask_video = mask_video[batch_index] - if depth_video is not None: - depth_video = depth_video[batch_index] - - resized_inputs = [] - resized_targets = [] - resized_masks = [] - resized_depths = [] - for i in range(len(input_video)): - resized_input = resize_frame(input_video[i], self.larger_side_of_image_and_video) - resized_target = resize_frame(target_video[i], self.larger_side_of_image_and_video) - resized_mask = resize_frame(mask_video[i], self.larger_side_of_image_and_video) - - # Apply mask quantization based on mode - if self.ablation_binary_mask: - # Ablation binary mask mode: remap [0, 63, 127, 255] to [0, 127] - # Map 0 and 63 → 0 - # Map 127 and 255 → 127 - resized_mask = np.where(resized_mask <= 95, 0, resized_mask) - resized_mask = np.where(resized_mask > 95, 127, resized_mask) - elif self.use_quadmask: - # Quadmask mode: preserve 4 values [0, 63, 127, 255] - # Quantize to nearest quadmask value for robustness - resized_mask = np.where(resized_mask <= 31, 0, resized_mask) - resized_mask = np.where(np.logical_and(resized_mask > 31, resized_mask <= 95), 63, resized_mask) - resized_mask = np.where(np.logical_and(resized_mask > 95, resized_mask <= 191), 127, resized_mask) - resized_mask = np.where(resized_mask > 191, 255, resized_mask) - else: - # Trimask mode: 3 values [0, 127, 255] - resized_mask = np.where(np.logical_and(resized_mask > 63, resized_mask < 192), 127, resized_mask) - resized_mask = np.where(resized_mask >= 192, 255, resized_mask) - resized_mask = np.where(resized_mask <= 63, 0, resized_mask) - resized_inputs.append(resized_input) - resized_targets.append(resized_target) - resized_masks.append(resized_mask) - - if depth_video is not None: - resized_depth = resize_frame(depth_video[i], self.larger_side_of_image_and_video) - resized_depths.append(resized_depth) - - input_video = np.array(resized_inputs) - target_video = 
np.array(resized_targets) - mask_video = np.array(resized_masks) - if depth_video is not None: - depth_video = np.array(resized_depths) + # Apply mask quantization after resize to preserve existing trimask / quadmask behavior. + if self.ablation_binary_mask: + # Ablation binary mask mode: remap [0, 63, 127, 255] to [0, 127] + # Map 0 and 63 → 0 + # Map 127 and 255 → 127 + mask_video = np.where(mask_video <= 95, 0, mask_video) + mask_video = np.where(mask_video > 95, 127, mask_video) + elif self.use_quadmask: + # Quadmask mode: preserve 4 values [0, 63, 127, 255] + # Quantize to nearest quadmask value for robustness + mask_video = np.where(mask_video <= 31, 0, mask_video) + mask_video = np.where(np.logical_and(mask_video > 31, mask_video <= 95), 63, mask_video) + mask_video = np.where(np.logical_and(mask_video > 95, mask_video <= 191), 127, mask_video) + mask_video = np.where(mask_video > 191, 255, mask_video) + else: + # Trimask mode: 3 values [0, 127, 255] + mask_video = np.where(np.logical_and(mask_video > 63, mask_video < 192), 127, mask_video) + mask_video = np.where(mask_video >= 192, 255, mask_video) + mask_video = np.where(mask_video <= 63, 0, mask_video) if len(mask_video.shape) == 3: mask_video = mask_video[..., None] diff --git a/videox_fun/utils/video_tuple_loader.py b/videox_fun/utils/video_tuple_loader.py new file mode 100644 index 0000000..c975047 --- /dev/null +++ b/videox_fun/utils/video_tuple_loader.py @@ -0,0 +1,181 @@ +import gc +import glob +import os +import random +from contextlib import contextmanager + +import cv2 +import mediapy as media +import numpy as np +from decord import VideoReader +from func_timeout import FunctionTimedOut, func_timeout + +VIDEO_READER_TIMEOUT = 20 +VIDEO_READER_NUM_THREADS = 2 + + +@contextmanager +def _video_reader_contextmanager(*args, **kwargs): + video_reader = VideoReader(*args, **kwargs) + try: + yield video_reader + finally: + del video_reader + gc.collect() + + +def _get_video_reader_batch(video_reader, 
batch_index): + return video_reader.get_batch(batch_index).asnumpy() + + +def _resize_frame(frame, target_short_side): + h, w, _ = frame.shape + if h < w: + if target_short_side > h: + return frame + new_h = target_short_side + new_w = int(target_short_side * w / h) + else: + if target_short_side > w: + return frame + new_w = target_short_side + new_h = int(target_short_side * h / w) + + return cv2.resize(frame, (new_w, new_h)) + + +def _resize_clip(frames, target_short_side): + return np.array([_resize_frame(frame, target_short_side) for frame in frames]) + + +def _sample_batch_index( + total_frames, + video_sample_n_frames, + video_length_drop_start, + video_length_drop_end, + video_sample_stride, +): + min_sample_n_frames = min( + video_sample_n_frames, + int(total_frames * (video_length_drop_end - video_length_drop_start) // video_sample_stride), + ) + if min_sample_n_frames == 0: + raise ValueError("No frames available after sampling constraints.") + + video_length = int(video_length_drop_end * total_frames) + clip_length = min(video_length, (min_sample_n_frames - 1) * video_sample_stride + 1) + start_idx = ( + random.randint(int(video_length_drop_start * video_length), video_length - clip_length) + if video_length != clip_length + else 0 + ) + return np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int) + + +def _read_resized_video_clip(video_path, batch_index, target_short_side): + with _video_reader_contextmanager(video_path, num_threads=VIDEO_READER_NUM_THREADS) as video_reader: + return _read_resized_video_reader_clip(video_reader, batch_index, target_short_side, video_path) + + +def _read_resized_video_reader_clip(video_reader, batch_index, target_short_side, source_name): + try: + frames = func_timeout( + VIDEO_READER_TIMEOUT, + _get_video_reader_batch, + args=(video_reader, batch_index), + ) + return _resize_clip(frames, target_short_side) + except FunctionTimedOut as exc: + raise ValueError(f"Read timeout while sampling 
frames from {source_name}.") from exc + except Exception as exc: + raise ValueError(f"Failed to extract frames from {source_name}. Error is {exc}.") from exc + + +def _get_frame_paths(frame_dir): + frame_paths = sorted(glob.glob(os.path.join(frame_dir, "*.png"))) + if not frame_paths: + raise ValueError(f"No PNG files found in directory: {frame_dir}") + return frame_paths + + +def _read_resized_frame_dir_clip(frame_dir, batch_index, target_short_side): + frame_paths = _get_frame_paths(frame_dir) + try: + selected_frames = [media.read_image(frame_paths[idx]) for idx in batch_index] + except IndexError as exc: + raise ValueError( + f"Frame selection for {frame_dir} is out of range. Requested up to index " + f"{int(np.max(batch_index))} but only found {len(frame_paths)} frames." + ) from exc + + return _resize_clip(selected_frames, target_short_side) + + +def load_video_mask_tuple_clip( + sample_dir, + video_sample_n_frames, + video_length_drop_start, + video_length_drop_end, + video_sample_stride, + target_short_side, +): + mp4_input_path = os.path.join(sample_dir, "rgb_full.mp4") + + if os.path.exists(mp4_input_path): + with _video_reader_contextmanager(mp4_input_path, num_threads=VIDEO_READER_NUM_THREADS) as video_reader: + batch_index = _sample_batch_index( + total_frames=len(video_reader), + video_sample_n_frames=video_sample_n_frames, + video_length_drop_start=video_length_drop_start, + video_length_drop_end=video_length_drop_end, + video_sample_stride=video_sample_stride, + ) + input_video = _read_resized_video_reader_clip( + video_reader, + batch_index, + target_short_side, + mp4_input_path, + ) + + target_video = _read_resized_video_clip( + os.path.join(sample_dir, "rgb_removed.mp4"), + batch_index, + target_short_side, + ) + mask_video = _read_resized_video_clip( + os.path.join(sample_dir, "mask.mp4"), + batch_index, + target_short_side, + ) + + depth_video_path = os.path.join(sample_dir, "depth_removed.mp4") + depth_video = None + if 
os.path.exists(depth_video_path): + depth_video = _read_resized_video_clip(depth_video_path, batch_index, target_short_side) + else: + input_dir = os.path.join(sample_dir, "input") + input_frame_paths = _get_frame_paths(input_dir) + batch_index = _sample_batch_index( + total_frames=len(input_frame_paths), + video_sample_n_frames=video_sample_n_frames, + video_length_drop_start=video_length_drop_start, + video_length_drop_end=video_length_drop_end, + video_sample_stride=video_sample_stride, + ) + input_video = _resize_clip( + [media.read_image(input_frame_paths[idx]) for idx in batch_index], + target_short_side, + ) + target_video = _read_resized_frame_dir_clip( + os.path.join(sample_dir, "bg"), + batch_index, + target_short_side, + ) + mask_video = _read_resized_frame_dir_clip( + os.path.join(sample_dir, "trimask"), + batch_index, + target_short_side, + ) + depth_video = None + + return input_video, target_video, mask_video, depth_video