Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 155 additions & 2 deletions src/flat_mae/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
# References:
# capi: https://github.com/facebookresearch/capi/blob/main/data.py

import math
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from torch.utils.data import default_collate
from jaxtyping import Float, Int
Expand Down Expand Up @@ -78,13 +82,162 @@ def __init__(
super().__init__(mask_ratio=mask_ratio, img_size=img_size, patch_size=patch_size)


# Inverse block masking: keep a single (roughly rectangular) block of patches
# visible and mask out everything else.
def _make_block_mask(x: int, y: int, h: int, w: int, shape: tuple[int, int], roll: bool = True) -> torch.Tensor:
    """
    Build a rectangular block mask on a 2D grid.

    The block is centered at column x / row y, spans h rows and w columns, and
    lives on a grid of size shape = (rows, cols). Cells inside the block are 1,
    all other cells are 0.

    With roll=True the parts of the block that fall off one edge re-enter from
    the opposite edge. With roll=False those parts are dropped, so the block
    can end up smaller than (h, w).
    """
    n_rows, n_cols = shape

    # convert the block center into the first row / column of the block
    row_start = y - h // 2
    col_start = x - w // 2

    # candidate row / column indices; may lie outside the grid at this point
    rows = torch.arange(row_start, row_start + h)
    cols = torch.arange(col_start, col_start + w)

    if not roll:
        # crop: discard indices that fall outside the grid
        rows = rows[(rows >= 0) & (rows < n_rows)]
        cols = cols[(cols >= 0) & (cols < n_cols)]
    else:
        # wrap: fold out-of-range indices back onto the grid
        rows = rows.remainder(n_rows)
        cols = cols.remainder(n_cols)

    # 1D indicators of the selected rows and columns
    row_hit = torch.zeros(n_rows)
    row_hit[rows] = 1
    col_hit = torch.zeros(n_cols)
    col_hit[cols] = 1

    # a cell belongs to the block iff both its row and its column are selected
    return row_hit[:, None] * col_hit[None, :]


def _inverse_block_masking(
    mask: torch.Tensor,
    *,
    mask_ratio: float,
    patch_size: int | tuple[int, int],
    roll: bool = True,
    min_aspect: float = 1.0,
    max_aspect: float | None = None,
) -> torch.Tensor:
    """Sample an inverse block mask on a 2D tensor of shape (H, W).

    The input is split into patches; a patch counts as valid when the sum of
    its pixels is positive. A single rectangular block of patches is sampled
    at a random position and grown until it contains enough valid patches.
    The block is then truncated to exactly the target count and returned as
    the set of *visible* patches.

    Args:
        mask: 2D tensor of shape (H, W). Presumably a binary mask of valid
            pixels -- confirm against the caller. The patchify step expects H
            and W to be multiples of the patch size.
        mask_ratio: fraction of all patches to hide. The number of visible
            patches is ``int((1 - mask_ratio) * num_patches)``, capped by the
            summed patch-level mask.
        patch_size: patch height/width, an int or a (p, q) pair.
        roll: if True the block wraps around the grid edges, otherwise it is
            cropped at the edges.
        min_aspect: lower bound of the sampled block aspect ratio (h / w).
        max_aspect: upper bound of the aspect ratio; falls back to
            ``1 / min_aspect`` when not given.

    Returns:
        Tensor with the same shape, dtype and device as ``mask`` in which every
        pixel of a visible patch is 1 and all other pixels are 0.
    """

    H, W = mask.shape  # shape of the mask
    p, q = to_2tuple(patch_size)

    grid_h = H // p
    grid_w = W // q

    # patchify mask to [grid_h * grid_w, p * q]
    mask_patches = rearrange(
        mask,
        "(h p) (w q) -> (h w) (p q)",
        h=grid_h,
        w=grid_w,
        p=p,
        q=q,
    )
    num_patches = mask_patches.shape[0]

    # patch-level validity on the [grid_h, grid_w] grid
    patch_mask = mask_patches.sum(dim=-1).clip(max=1)
    patch_mask = patch_mask.reshape(grid_h, grid_w)

    # never try to keep more patches than there are valid ones
    len_keep = int((1 - mask_ratio) * num_patches)
    total_patches = int(patch_mask.sum().item())
    len_keep = min(len_keep, total_patches)

    max_aspect = max_aspect or 1 / min_aspect
    min_lar, max_lar = (
        np.log(min_aspect),
        np.log(max_aspect),
    )  # get the aspect ratio in log space to treat the ratios symmetrically
    # sample an aspect ratio
    # note that we don't need to worry whether the aspect is too big/small to fit in the
    # image, since we scale the box below anyway.
    aspect_ratio = math.exp(np.random.uniform(min_lar, max_lar))

    # height and width of the block given the aspect ratio
    # len_keep: h * w
    # aspect: h / w
    h = math.ceil(math.sqrt(len_keep * aspect_ratio))
    w = math.ceil(math.sqrt(len_keep / aspect_ratio))

    # sample a random position for the box center
    y = random.randint(0, grid_h - 1)
    x = random.randint(0, grid_w - 1)

    # increase the block size until it covers enough valid patches
    scale = 1.2
    h_, w_ = h, w
    while True:
        block_mask = _make_block_mask(x, y, h_, w_, (grid_h, grid_w), roll=roll)
        # the block mask is created on the default device; move it next to the
        # input so the product below also works for masks living on an accelerator
        block_mask = patch_mask * block_mask.to(patch_mask.device)
        if block_mask.sum() >= len_keep:
            break
        h_ = math.ceil(scale * h_)
        w_ = math.ceil(scale * w_)

    # truncate ids to exactly len_keep
    # flip a coin to remove from top or bottom
    # note this is similar to capi, but they only remove from bottom
    # squeeze only the trailing dim: a bare squeeze() would collapse a single
    # selected patch to a 0-d tensor, which can be neither sliced nor len()'d
    ids_keep = block_mask.flatten().nonzero(as_tuple=False).squeeze(-1)
    if random.randint(0, 1):
        ids_keep = ids_keep[:len_keep]
    else:
        ids_keep = ids_keep[ids_keep.numel() - len_keep :]

    # mark every pixel of the kept patches and fold back to image layout
    visible_mask_patches = torch.zeros_like(mask_patches)
    visible_mask_patches[ids_keep] = 1
    visible_mask = rearrange(
        visible_mask_patches,
        "(h w) (p q) -> (h p) (w q)",
        h=grid_h,
        w=grid_w,
        p=p,
        q=q,
    )
    return visible_mask

class InverseBlockMasking(nn.Module):
    """Module wrapper that applies `_inverse_block_masking` with fixed settings."""

    def __init__(
        self,
        mask_ratio: float,
        img_size: int | tuple[int, int],
        patch_size: int | tuple[int, int],
        *,
        roll: bool = True,
        min_aspect: float = 1.0,
        max_aspect: float | None = None,
    ):
        """Store the masking configuration; sizes are normalized to 2-tuples."""
        super().__init__()
        self.mask_ratio = mask_ratio
        # geometry
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        # block sampling options, forwarded unchanged on every call
        self.roll, self.min_aspect, self.max_aspect = roll, min_aspect, max_aspect
        # sampling routine invoked by forward()
        self.inverse_block_masking = _inverse_block_masking

    def forward(self, img_mask: torch.Tensor) -> torch.Tensor:
        """Sample a visible-block mask for `img_mask` using the stored settings."""
        sampler_kwargs = {
            "mask_ratio": self.mask_ratio,
            "patch_size": self.patch_size,
            "roll": self.roll,
            "min_aspect": self.min_aspect,
            "max_aspect": self.max_aspect,
        }
        return self.inverse_block_masking(img_mask, **sampler_kwargs)

# Registry of masking strategies, keyed by their config name.
MASKING_DICT = dict(
    random=RandomMasking,
    tube=TubeMasking,
    inverse=InverseBlockMasking,
)


Expand Down