Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 28 additions & 81 deletions videox_fun/data/dataset_image_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from torch.utils.data import BatchSampler, Sampler
from torch.utils.data.dataset import Dataset
from contextlib import contextmanager
from videox_fun.utils.video_tuple_loader import load_video_mask_tuple_clip

VIDEO_READER_TIMEOUT = 20

Expand Down Expand Up @@ -595,32 +596,14 @@ def get_batch(self, idx):
elif data_info.get('type', 'image') == 'video_mask_tuple': # object effect removal
sample_dir = data_info['file_path']
try:
if os.path.exists(os.path.join(sample_dir, 'rgb_full.mp4')):
input_video_path = os.path.join(sample_dir, 'rgb_full.mp4')
target_video_path = os.path.join(sample_dir, 'rgb_removed.mp4')
mask_video_path = os.path.join(sample_dir, 'mask.mp4')
depth_video_path = os.path.join(sample_dir, 'depth_removed.mp4')

input_video = media.read_video(input_video_path)
target_video = media.read_video(target_video_path)
mask_video = media.read_video(mask_video_path)

# Load depth map if it exists
depth_video = None
if os.path.exists(depth_video_path):
depth_video = media.read_video(depth_video_path)

else:
input_video_path = os.path.join(sample_dir, 'input')
target_video_path = os.path.join(sample_dir, 'bg')
mask_video_path = os.path.join(sample_dir, 'trimask')

input_video = _read_video_from_dir(input_video_path)
target_video = _read_video_from_dir(target_video_path)
mask_video = _read_video_from_dir(mask_video_path)

# Initialize depth_video as None for this path
depth_video = None
input_video, target_video, mask_video, depth_video = load_video_mask_tuple_clip(
sample_dir=sample_dir,
video_sample_n_frames=self.video_sample_n_frames,
video_length_drop_start=self.video_length_drop_start,
video_length_drop_end=self.video_length_drop_end,
video_sample_stride=self.video_sample_stride,
target_short_side=self.larger_side_of_image_and_video,
)
except Exception as e:
print(f"Error loading video_mask_tuple from {sample_dir}: {e}")
import traceback
Expand All @@ -633,62 +616,26 @@ def get_batch(self, idx):
mask_video = mask_video[..., None]
if mask_video.shape[-1] == 3:
mask_video = mask_video[..., :1]
min_sample_n_frames = min(
self.video_sample_n_frames,
int(len(input_video) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
)
video_length = int(self.video_length_drop_end * len(input_video))
clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
input_video = input_video[batch_index]
target_video = target_video[batch_index]
mask_video = mask_video[batch_index]
if depth_video is not None:
depth_video = depth_video[batch_index]

resized_inputs = []
resized_targets = []
resized_masks = []
resized_depths = []
for i in range(len(input_video)):
resized_input = resize_frame(input_video[i], self.larger_side_of_image_and_video)
resized_target = resize_frame(target_video[i], self.larger_side_of_image_and_video)
resized_mask = resize_frame(mask_video[i], self.larger_side_of_image_and_video)

# Apply mask quantization based on mode
if self.ablation_binary_mask:
# Ablation binary mask mode: remap [0, 63, 127, 255] to [0, 127]
# Map 0 and 63 → 0
# Map 127 and 255 → 127
resized_mask = np.where(resized_mask <= 95, 0, resized_mask)
resized_mask = np.where(resized_mask > 95, 127, resized_mask)
elif self.use_quadmask:
# Quadmask mode: preserve 4 values [0, 63, 127, 255]
# Quantize to nearest quadmask value for robustness
resized_mask = np.where(resized_mask <= 31, 0, resized_mask)
resized_mask = np.where(np.logical_and(resized_mask > 31, resized_mask <= 95), 63, resized_mask)
resized_mask = np.where(np.logical_and(resized_mask > 95, resized_mask <= 191), 127, resized_mask)
resized_mask = np.where(resized_mask > 191, 255, resized_mask)
else:
# Trimask mode: 3 values [0, 127, 255]
resized_mask = np.where(np.logical_and(resized_mask > 63, resized_mask < 192), 127, resized_mask)
resized_mask = np.where(resized_mask >= 192, 255, resized_mask)
resized_mask = np.where(resized_mask <= 63, 0, resized_mask)

resized_inputs.append(resized_input)
resized_targets.append(resized_target)
resized_masks.append(resized_mask)

if depth_video is not None:
resized_depth = resize_frame(depth_video[i], self.larger_side_of_image_and_video)
resized_depths.append(resized_depth)

input_video = np.array(resized_inputs)
target_video = np.array(resized_targets)
mask_video = np.array(resized_masks)
if depth_video is not None:
depth_video = np.array(resized_depths)
# Apply mask quantization after resize to preserve existing trimask / quadmask behavior.
if self.ablation_binary_mask:
# Ablation binary mask mode: remap [0, 63, 127, 255] to [0, 127]
# Map 0 and 63 → 0
# Map 127 and 255 → 127
mask_video = np.where(mask_video <= 95, 0, mask_video)
mask_video = np.where(mask_video > 95, 127, mask_video)
elif self.use_quadmask:
# Quadmask mode: preserve 4 values [0, 63, 127, 255]
# Quantize to nearest quadmask value for robustness
mask_video = np.where(mask_video <= 31, 0, mask_video)
mask_video = np.where(np.logical_and(mask_video > 31, mask_video <= 95), 63, mask_video)
mask_video = np.where(np.logical_and(mask_video > 95, mask_video <= 191), 127, mask_video)
mask_video = np.where(mask_video > 191, 255, mask_video)
else:
# Trimask mode: 3 values [0, 127, 255]
mask_video = np.where(np.logical_and(mask_video > 63, mask_video < 192), 127, mask_video)
mask_video = np.where(mask_video >= 192, 255, mask_video)
mask_video = np.where(mask_video <= 63, 0, mask_video)

if len(mask_video.shape) == 3:
mask_video = mask_video[..., None]
Expand Down
109 changes: 28 additions & 81 deletions videox_fun/data/dataset_image_video_warped.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from torch.utils.data import BatchSampler, Sampler
from torch.utils.data.dataset import Dataset
from contextlib import contextmanager
from videox_fun.utils.video_tuple_loader import load_video_mask_tuple_clip

VIDEO_READER_TIMEOUT = 20

Expand Down Expand Up @@ -595,32 +596,14 @@ def get_batch(self, idx):
elif data_info.get('type', 'image') == 'video_mask_tuple': # object effect removal
sample_dir = data_info['file_path'] if self.data_root is None else os.path.join(self.data_root, data_info['file_path'])
try:
if os.path.exists(os.path.join(sample_dir, 'rgb_full.mp4')):
input_video_path = os.path.join(sample_dir, 'rgb_full.mp4')
target_video_path = os.path.join(sample_dir, 'rgb_removed.mp4')
mask_video_path = os.path.join(sample_dir, 'mask.mp4')
depth_video_path = os.path.join(sample_dir, 'depth_removed.mp4')

input_video = media.read_video(input_video_path)
target_video = media.read_video(target_video_path)
mask_video = media.read_video(mask_video_path)

# Load depth map if it exists
depth_video = None
if os.path.exists(depth_video_path):
depth_video = media.read_video(depth_video_path)

else:
input_video_path = os.path.join(sample_dir, 'input')
target_video_path = os.path.join(sample_dir, 'bg')
mask_video_path = os.path.join(sample_dir, 'trimask')

input_video = _read_video_from_dir(input_video_path)
target_video = _read_video_from_dir(target_video_path)
mask_video = _read_video_from_dir(mask_video_path)

# Initialize depth_video as None for this path
depth_video = None
input_video, target_video, mask_video, depth_video = load_video_mask_tuple_clip(
sample_dir=sample_dir,
video_sample_n_frames=self.video_sample_n_frames,
video_length_drop_start=self.video_length_drop_start,
video_length_drop_end=self.video_length_drop_end,
video_sample_stride=self.video_sample_stride,
target_short_side=self.larger_side_of_image_and_video,
)
except Exception as e:
print(f"Error loading video_mask_tuple from {sample_dir}: {e}")
import traceback
Expand All @@ -633,62 +616,26 @@ def get_batch(self, idx):
mask_video = mask_video[..., None]
if mask_video.shape[-1] == 3:
mask_video = mask_video[..., :1]
min_sample_n_frames = min(
self.video_sample_n_frames,
int(len(input_video) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
)
video_length = int(self.video_length_drop_end * len(input_video))
clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
input_video = input_video[batch_index]
target_video = target_video[batch_index]
mask_video = mask_video[batch_index]
if depth_video is not None:
depth_video = depth_video[batch_index]

resized_inputs = []
resized_targets = []
resized_masks = []
resized_depths = []
for i in range(len(input_video)):
resized_input = resize_frame(input_video[i], self.larger_side_of_image_and_video)
resized_target = resize_frame(target_video[i], self.larger_side_of_image_and_video)
resized_mask = resize_frame(mask_video[i], self.larger_side_of_image_and_video)

# Apply mask quantization based on mode
if self.ablation_binary_mask:
# Ablation binary mask mode: remap [0, 63, 127, 255] to [0, 127]
# Map 0 and 63 → 0
# Map 127 and 255 → 127
resized_mask = np.where(resized_mask <= 95, 0, resized_mask)
resized_mask = np.where(resized_mask > 95, 127, resized_mask)
elif self.use_quadmask:
# Quadmask mode: preserve 4 values [0, 63, 127, 255]
# Quantize to nearest quadmask value for robustness
resized_mask = np.where(resized_mask <= 31, 0, resized_mask)
resized_mask = np.where(np.logical_and(resized_mask > 31, resized_mask <= 95), 63, resized_mask)
resized_mask = np.where(np.logical_and(resized_mask > 95, resized_mask <= 191), 127, resized_mask)
resized_mask = np.where(resized_mask > 191, 255, resized_mask)
else:
# Trimask mode: 3 values [0, 127, 255]
resized_mask = np.where(np.logical_and(resized_mask > 63, resized_mask < 192), 127, resized_mask)
resized_mask = np.where(resized_mask >= 192, 255, resized_mask)
resized_mask = np.where(resized_mask <= 63, 0, resized_mask)

resized_inputs.append(resized_input)
resized_targets.append(resized_target)
resized_masks.append(resized_mask)

if depth_video is not None:
resized_depth = resize_frame(depth_video[i], self.larger_side_of_image_and_video)
resized_depths.append(resized_depth)

input_video = np.array(resized_inputs)
target_video = np.array(resized_targets)
mask_video = np.array(resized_masks)
if depth_video is not None:
depth_video = np.array(resized_depths)
# Apply mask quantization after resize to preserve existing trimask / quadmask behavior.
if self.ablation_binary_mask:
# Ablation binary mask mode: remap [0, 63, 127, 255] to [0, 127]
# Map 0 and 63 → 0
# Map 127 and 255 → 127
mask_video = np.where(mask_video <= 95, 0, mask_video)
mask_video = np.where(mask_video > 95, 127, mask_video)
elif self.use_quadmask:
# Quadmask mode: preserve 4 values [0, 63, 127, 255]
# Quantize to nearest quadmask value for robustness
mask_video = np.where(mask_video <= 31, 0, mask_video)
mask_video = np.where(np.logical_and(mask_video > 31, mask_video <= 95), 63, mask_video)
mask_video = np.where(np.logical_and(mask_video > 95, mask_video <= 191), 127, mask_video)
mask_video = np.where(mask_video > 191, 255, mask_video)
else:
# Trimask mode: 3 values [0, 127, 255]
mask_video = np.where(np.logical_and(mask_video > 63, mask_video < 192), 127, mask_video)
mask_video = np.where(mask_video >= 192, 255, mask_video)
mask_video = np.where(mask_video <= 63, 0, mask_video)

if len(mask_video.shape) == 3:
mask_video = mask_video[..., None]
Expand Down
Loading