diff --git a/src/flat_mae/masking.py b/src/flat_mae/masking.py index 400baf6..e010c21 100644 --- a/src/flat_mae/masking.py +++ b/src/flat_mae/masking.py @@ -3,9 +3,13 @@ # References: # capi: https://github.com/facebookresearch/capi/blob/main/data.py +import math import torch +import random +import numpy as np import torch.nn as nn import torch.nn.functional as F +from einops import rearrange from torch import Tensor from torch.utils.data import default_collate from jaxtyping import Float, Int @@ -78,13 +82,162 @@ def __init__( super().__init__(mask_ratio=mask_ratio, img_size=img_size, patch_size=patch_size) -# TODO: -# - inverse block masking +def _make_block_mask(x: int, y: int, h: int, w: int, shape: tuple[int, int], roll: bool = True) -> torch.Tensor: + """ + Make a rectangular block mask where (x, y) is the block center, (h, w) is the block + shape, and shape is the grid shape. + If roll is True, the mask is wrapped around the edges, and otherwise cropped. + """ + H, W = shape + + # (x, y) is box center + top = y - h // 2 + left = x - w // 2 + + # row and column indices of the block (possibly over the edges) + y_ids = torch.arange(top, top + h) + x_ids = torch.arange(left, left + w) + + # handle edges. wrap around if roll is true, otherwise clip. + # note that if roll is not enabled, the box can be smaller than (h, w). + if roll: + y_ids = y_ids % H + x_ids = x_ids % W + else: + y_ids = y_ids[(y_ids >= 0) & (y_ids < H)] + x_ids = x_ids[(x_ids >= 0) & (x_ids < W)] + + # create block mask as intersection of row and column mask. + y_mask = torch.zeros(shape) + y_mask[y_ids, :] = 1 + x_mask = torch.zeros(shape) + x_mask[:, x_ids] = 1 + mask = y_mask * x_mask + return mask + + +def _inverse_block_masking( + mask: torch.Tensor, + *, + mask_ratio: float, + patch_size: int | tuple[int, int], + roll: bool = True, + min_aspect: float = 1.0, + max_aspect: float | None = None, +) -> torch.Tensor: + """Sample an inverse block mask on a 2D tensor shape (H, W).""" + + H, W = mask.shape # shape of the mask + p, q = to_2tuple(patch_size) + + grid_h = H // p + grid_w = W // q + + # patchify mask to [grid_h, grid_w] + mask_patches = rearrange( + mask, + "(h p) (w q) -> (h w) (p q)", + h=grid_h, + w=grid_w, + p=p, + q=q, + ) + L, D = mask_patches.shape + + patch_mask = mask_patches.sum(dim=-1).clip(max=1) + patch_mask = patch_mask.reshape(grid_h, grid_w) + + len_keep = int((1 - mask_ratio) * L) + total_patches = int(patch_mask.sum().item()) + len_keep = min(len_keep, total_patches) + + max_aspect = max_aspect or 1 / min_aspect + min_lar, max_lar = ( + np.log(min_aspect), + np.log(max_aspect), + ) # get the aspect ratio in log space to treat the ratios symmetrically + # sample an aspect ratio + # note that we don't need to worry whether the aspect is too big/small to fit in the + # image, since we scale the box below anyway. + aspect_ratio = math.exp(np.random.uniform(min_lar, max_lar)) + + # height and width of the block given the aspect ratio + # len_keep: h * w + # aspect: h / w + h = math.ceil(math.sqrt(len_keep * aspect_ratio)) + w = math.ceil(math.sqrt(len_keep / aspect_ratio)) + + # sample a random position for the box center + y = random.randint(0, grid_h - 1) + x = random.randint(0, grid_w - 1) + + # increase the block size until it covers enough valid patches + scale = 1.2 + h_, w_ = h, w + while True: + block_mask = _make_block_mask(x, y, h_, w_, (grid_h, grid_w), roll=roll) + block_mask = patch_mask * block_mask + if block_mask.sum() >= len_keep: + break + h_ = math.ceil(scale * h_) + w_ = math.ceil(scale * w_) + + # truncate ids to exactly len_keep + # flip a coin to remove from top or bottom + # note this is similar to capi, but they only remove from bottom + ids_keep = block_mask.flatten().nonzero(as_tuple=False).squeeze() + if random.randint(0, 1): + ids_keep = ids_keep[:len_keep] + else: + ids_keep = ids_keep[len(ids_keep) - len_keep :] + + visible_mask_patches = torch.zeros_like(mask_patches) + visible_mask_patches[ids_keep] = 1 + visible_mask = rearrange( + visible_mask_patches, + "(h w) (p q) -> (h p) (w q)", + h=grid_h, + w=grid_w, + p=p, + q=q, + ) + return visible_mask + +class InverseBlockMasking(nn.Module): + def __init__( + self, + mask_ratio: float, + img_size: int | tuple[int, int], + patch_size: int | tuple[int, int], + *, + roll: bool = True, + min_aspect: float = 1.0, + max_aspect: float | None = None, + ): + super().__init__() + self.mask_ratio = mask_ratio + self.img_size = to_2tuple(img_size) + self.patch_size = to_2tuple(patch_size) + self.roll = roll + self.min_aspect = min_aspect + self.max_aspect = max_aspect + self.inverse_block_masking = _inverse_block_masking + + def forward(self, img_mask: torch.Tensor) -> torch.Tensor: + return self.inverse_block_masking( + img_mask, + mask_ratio=self.mask_ratio, + patch_size=self.patch_size, + roll=self.roll, + min_aspect=self.min_aspect, + max_aspect=self.max_aspect + ) MASKING_DICT = { "random": RandomMasking, "tube": TubeMasking, + "inverse": InverseBlockMasking, }