Skip to content
5 changes: 5 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,11 @@ Nets
.. autoclass:: ViTAutoEnc
:members:

`MaskedAutoEncoderViT`
~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: MaskedAutoEncoderViT
:members:

`FullyConnectedNet`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: FullyConnectedNet
Expand Down
1 change: 1 addition & 0 deletions monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
from .masked_autoencoder_vit import MaskedAutoEncoderViT
from .mednext import (
MedNeXt,
MedNext,
Expand Down
211 changes: 211 additions & 0 deletions monai/networks/nets/masked_autoencoder_vit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Sequence

import numpy as np
import torch
import torch.nn as nn

from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
from monai.networks.blocks.transformerblock import TransformerBlock
from monai.networks.layers import trunc_normal_
from monai.utils import ensure_tuple_rep
from monai.utils.module import look_up_option

SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}

__all__ = ["MaskedAutoEncoderViT"]


class MaskedAutoEncoderViT(nn.Module):
"""
Masked Autoencoder (ViT), based on: "Kaiming et al.,
Masked Autoencoders Are Scalable Vision Learners <https://arxiv.org/abs/2111.06377>"
Only a subset of the patches passes through the encoder. The decoder tries to reconstruct
the masked patches, resulting in improved training speed.
"""

def __init__(
self,
in_channels: int,
img_size: Sequence[int] | int,
patch_size: Sequence[int] | int,
hidden_size: int = 768,
mlp_dim: int = 512,
num_layers: int = 12,
num_heads: int = 12,
masking_ratio: float = 0.75,
decoder_hidden_size: int = 384,
decoder_mlp_dim: int = 512,
decoder_num_layers: int = 4,
decoder_num_heads: int = 12,
proj_type: str = "conv",
pos_embed_type: str = "sincos",
decoder_pos_embed_type: str = "sincos",
dropout_rate: float = 0.0,
spatial_dims: int = 3,
qkv_bias: bool = False,
save_attn: bool = False,
) -> None:
"""
Args:
in_channels: dimension of input channels or the number of channels for input.
img_size: dimension of input image.
patch_size: dimension of patch size
hidden_size: dimension of hidden layer. Defaults to 768.
mlp_dim: dimension of feedforward layer. Defaults to 512.
num_layers: number of transformer blocks. Defaults to 12.
num_heads: number of attention heads. Defaults to 12.
masking_ratio: ratio of patches to be masked. Defaults to 0.75.
decoder_hidden_size: dimension of hidden layer for decoder. Defaults to 384.
decoder_mlp_dim: dimension of feedforward layer for decoder. Defaults to 512.
decoder_num_layers: number of transformer blocks for decoder. Defaults to 4.
decoder_num_heads: number of attention heads for decoder. Defaults to 12.
proj_type: position embedding layer type. Defaults to "conv".
pos_embed_type: position embedding layer type. Defaults to "sincos".
decoder_pos_embed_type: position embedding layer type for decoder. Defaults to "sincos".
dropout_rate: fraction of the input units to drop. Defaults to 0.0.
spatial_dims: number of spatial dimensions. Defaults to 3.
qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False.
save_attn: to make accessible the attention in self attention block. Defaults to False.
Examples::
# for single channel input with image size of (96,96,96), and sin-cos positional encoding
>>> net = MaskedAutoEncoderViT(in_channels=1, img_size=(96,96,96), patch_size=(16,16,16),
pos_embed_type='sincos')
# for 3-channel with image size of (128,128,128) and a learnable positional encoding
>>> net = MaskedAutoEncoderViT(in_channels=3, img_size=128, patch_size=16, pos_embed_type='learnable')
# for 3-channel with image size of (224,224) and a masking ratio of 0.25
>>> net = MaskedAutoEncoderViT(in_channels=3, img_size=(224,224), patch_size=(16,16), masking_ratio=0.25,
spatial_dims=2)
"""

super().__init__()

if not (0 <= dropout_rate <= 1):
raise ValueError(f"dropout_rate should be between 0 and 1, got {dropout_rate}.")

if hidden_size % num_heads != 0:
raise ValueError("hidden_size should be divisible by num_heads.")

if decoder_hidden_size % decoder_num_heads != 0:
raise ValueError("decoder_hidden_size should be divisible by decoder_num_heads.")

self.patch_size = ensure_tuple_rep(patch_size, spatial_dims)
self.img_size = ensure_tuple_rep(img_size, spatial_dims)
self.spatial_dims = spatial_dims
for m, p in zip(self.img_size, self.patch_size):
if m % p != 0:
raise ValueError(f"patch_size={patch_size} should be divisible by img_size={img_size}.")

self.decoder_hidden_size = decoder_hidden_size

if masking_ratio <= 0 or masking_ratio >= 1:
raise ValueError(f"masking_ratio should be in the range (0, 1), got {masking_ratio}.")

self.masking_ratio = masking_ratio
self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))

self.patch_embedding = PatchEmbeddingBlock(
in_channels=in_channels,
img_size=img_size,
patch_size=patch_size,
hidden_size=hidden_size,
num_heads=num_heads,
proj_type=proj_type,
pos_embed_type=pos_embed_type,
dropout_rate=dropout_rate,
spatial_dims=self.spatial_dims,
)
blocks = [
TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
for _ in range(num_layers)
]
self.blocks = nn.Sequential(*blocks, nn.LayerNorm(hidden_size))

# decoder
self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size)

self.mask_tokens = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))

self.decoder_pos_embed_type = look_up_option(decoder_pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES)
self.decoder_pos_embedding = nn.Parameter(torch.zeros(1, self.patch_embedding.n_patches, decoder_hidden_size))

decoder_blocks = [
TransformerBlock(decoder_hidden_size, decoder_mlp_dim, decoder_num_heads, dropout_rate, qkv_bias, save_attn)
for _ in range(decoder_num_layers)
]
self.decoder_blocks = nn.Sequential(*decoder_blocks, nn.LayerNorm(decoder_hidden_size))
self.decoder_pred = nn.Linear(decoder_hidden_size, int(np.prod(self.patch_size)) * in_channels)

self._init_weights()

def _init_weights(self):
"""
similar to monai/networks/blocks/patchembedding.py for the decoder positional encoding and for mask and
classification tokens
"""
if self.decoder_pos_embed_type == "none":
pass
elif self.decoder_pos_embed_type == "learnable":
trunc_normal_(self.decoder_pos_embedding, mean=0.0, std=0.02, a=-2.0, b=2.0)
elif self.decoder_pos_embed_type == "sincos":
grid_size = []
for in_size, pa_size in zip(self.img_size, self.patch_size):
grid_size.append(in_size // pa_size)

self.decoder_pos_embedding = build_sincos_position_embedding(
grid_size, self.decoder_hidden_size, self.spatial_dims
)

else:
raise ValueError(f"decoder_pos_embed_type {self.decoder_pos_embed_type} not supported.")

# initialize patch_embedding like nn.Linear (instead of nn.Conv2d)
trunc_normal_(self.mask_tokens, mean=0.0, std=0.02, a=-2.0, b=2.0)
trunc_normal_(self.cls_token, mean=0.0, std=0.02, a=-2.0, b=2.0)

def _masking(self, x, masking_ratio: float | None = None):
batch_size, num_tokens, _ = x.shape
percentage_to_keep = 1 - masking_ratio if masking_ratio is not None else 1 - self.masking_ratio
selected_indices = torch.multinomial(
torch.ones(batch_size, num_tokens), int(percentage_to_keep * num_tokens), replacement=False
)
x_masked = x[torch.arange(batch_size).unsqueeze(1), selected_indices] # gather the selected tokens
mask = torch.ones(batch_size, num_tokens, dtype=torch.int).to(x.device)
mask[torch.arange(batch_size).unsqueeze(-1), selected_indices] = 0

return x_masked, selected_indices, mask

def forward(self, x, masking_ratio: float | None = None):
x = self.patch_embedding(x)
x, selected_indices, mask = self._masking(x, masking_ratio=masking_ratio)

cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)

x = self.blocks(x)

# decoder
x = self.decoder_embed(x)

x_ = self.mask_tokens.repeat(x.shape[0], mask.shape[1], 1)
x_[torch.arange(x.shape[0]).unsqueeze(-1), selected_indices] = x[:, 1:, :] # no cls token
x_ = x_ + self.decoder_pos_embedding
x = torch.cat([x[:, :1, :], x_], dim=1)
x = self.decoder_blocks(x)
x = self.decoder_pred(x)

x = x[:, 1:, :]
return x, mask
160 changes: 160 additions & 0 deletions tests/test_masked_autoencoder_vit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets.masked_autoencoder_vit import MaskedAutoEncoderViT
from tests.utils import skip_if_quick

TEST_CASE_MaskedAutoEncoderViT = []
for masking_ratio in [0.5]:
for dropout_rate in [0.6]:
for in_channels in [4]:
for hidden_size in [768]:
for img_size in [96, 128]:
for patch_size in [16]:
for num_heads in [12]:
for mlp_dim in [3072]:
for num_layers in [4]:
for decoder_hidden_size in [384]:
for decoder_mlp_dim in [512]:
for decoder_num_layers in [4]:
for decoder_num_heads in [16]:
for pos_embed_type in ["sincos", "learnable"]:
for proj_type in ["conv", "perceptron"]:
for nd in (2, 3):
test_case = [
{
"in_channels": in_channels,
"img_size": (img_size,) * nd,
"patch_size": (patch_size,) * nd,
"hidden_size": hidden_size,
"mlp_dim": mlp_dim,
"num_layers": num_layers,
"decoder_hidden_size": decoder_hidden_size,
"decoder_mlp_dim": decoder_mlp_dim,
"decoder_num_layers": decoder_num_layers,
"decoder_num_heads": decoder_num_heads,
"pos_embed_type": pos_embed_type,
"masking_ratio": masking_ratio,
"decoder_pos_embed_type": pos_embed_type,
"num_heads": num_heads,
"proj_type": proj_type,
"dropout_rate": dropout_rate,
},
(2, in_channels, *([img_size] * nd)),
(
2,
(img_size // patch_size) ** nd,
in_channels * (patch_size**nd),
),
]
if nd == 2:
test_case[0]["spatial_dims"] = 2 # type: ignore
TEST_CASE_MaskedAutoEncoderViT.append(test_case)

TEST_CASE_ill_args = [
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (16, 16, 16), "dropout_rate": 5.0}],
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "pos_embed_type": "sin"}],
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "decoder_pos_embed_type": "sin"}],
[{"in_channels": 1, "img_size": (32, 32, 32), "patch_size": (64, 64, 64)}],
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "num_layers": 12, "num_heads": 14}],
[{"in_channels": 1, "img_size": (97, 97, 97), "patch_size": (16, 16, 16)}],
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "masking_ratio": 1.1}],
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "masking_ratio": -0.1}],
]


@skip_if_quick
class TestMaskedAutoencoderViT(unittest.TestCase):

@parameterized.expand(TEST_CASE_MaskedAutoEncoderViT)
def test_shape(self, input_param, input_shape, expected_shape):
net = MaskedAutoEncoderViT(**input_param)
with eval_mode(net):
result, _ = net(torch.randn(input_shape))
self.assertEqual(result.shape, expected_shape)

def test_frozen_pos_embedding(self):
net = MaskedAutoEncoderViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16))

self.assertEqual(net.decoder_pos_embedding.requires_grad, False)

@parameterized.expand(TEST_CASE_ill_args)
def test_ill_arg(self, input_param):
with self.assertRaises(ValueError):
MaskedAutoEncoderViT(**input_param)

def test_access_attn_matrix(self):
# input format
in_channels = 1
img_size = (96, 96, 96)
patch_size = (16, 16, 16)
in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2])

# no data in the matrix
no_matrix_acess_blk = MaskedAutoEncoderViT(in_channels=in_channels, img_size=img_size, patch_size=patch_size)
no_matrix_acess_blk(torch.randn(in_shape))
assert isinstance(no_matrix_acess_blk.blocks[0].attn.att_mat, torch.Tensor)
# no of elements is zero
assert no_matrix_acess_blk.blocks[0].attn.att_mat.nelement() == 0

# be able to acess the attention matrix
matrix_acess_blk = MaskedAutoEncoderViT(
in_channels=in_channels, img_size=img_size, patch_size=patch_size, save_attn=True
)
matrix_acess_blk(torch.randn(in_shape))

assert matrix_acess_blk.blocks[0].attn.att_mat.shape == (in_shape[0], 12, 55, 55)

def test_masking_ratio(self):
# input format
in_channels = 1
img_size = (96, 96, 96)
patch_size = (16, 16, 16)
in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2])

# masking ratio 0.25
masking_ratio_blk = MaskedAutoEncoderViT(
in_channels=in_channels, img_size=img_size, patch_size=patch_size, masking_ratio=0.25, save_attn=True
)
masking_ratio_blk(torch.randn(in_shape))
desired_num_tokens = int(
(img_size[0] // patch_size[0])
* (img_size[1] // patch_size[1])
* (img_size[2] // patch_size[2])
* (1 - 0.25)
)
assert masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1 == desired_num_tokens

# masking ratio 0.33
masking_ratio_blk = MaskedAutoEncoderViT(
in_channels=in_channels, img_size=img_size, patch_size=patch_size, masking_ratio=0.33, save_attn=True
)
masking_ratio_blk(torch.randn(in_shape))
desired_num_tokens = int(
(img_size[0] // patch_size[0])
* (img_size[1] // patch_size[1])
* (img_size[2] // patch_size[2])
* (1 - 0.33)
)

assert masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1 == desired_num_tokens


if __name__ == "__main__":
unittest.main()
Loading