From 19eb6d4b222860b7f7234a8a71ee2feff63dca4a Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 21 Aug 2024 07:11:11 -0700 Subject: [PATCH] 3d version of navit nested tensor --- setup.py | 2 +- vit_pytorch/na_vit_nested_tensor_3d.py | 332 +++++++++++++++++++++++++ 2 files changed, 333 insertions(+), 1 deletion(-) create mode 100644 vit_pytorch/na_vit_nested_tensor_3d.py diff --git a/setup.py b/setup.py index aa47dfe..f3cd734 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name = 'vit-pytorch', packages = find_packages(exclude=['examples']), - version = '1.7.7', + version = '1.7.9', license='MIT', description = 'Vision Transformer (ViT) - Pytorch', long_description=long_description, diff --git a/vit_pytorch/na_vit_nested_tensor_3d.py b/vit_pytorch/na_vit_nested_tensor_3d.py new file mode 100644 index 0000000..ae19ca0 --- /dev/null +++ b/vit_pytorch/na_vit_nested_tensor_3d.py @@ -0,0 +1,332 @@ +from __future__ import annotations + +from typing import List +from functools import partial + +import torch +import packaging.version as pkg_version + +if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'): + print('nested tensor NaViT was tested on pytorch 2.4') + +from torch import nn, Tensor +import torch.nn.functional as F +from torch.nn import Module, ModuleList +from torch.nested import nested_tensor + +from einops import rearrange +from einops.layers.torch import Rearrange + +# helpers + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +def divisible_by(numer, denom): + return (numer % denom) == 0 + +# feedforward + +def FeedForward(dim, hidden_dim, dropout = 0.): + return nn.Sequential( + nn.LayerNorm(dim, bias = False), + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + +class Attention(Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + self.norm = nn.LayerNorm(dim, bias = False) + + dim_inner = heads * dim_head + self.heads = heads + self.dim_head = dim_head + + self.to_queries = nn.Linear(dim, dim_inner, bias = False) + self.to_keys = nn.Linear(dim, dim_inner, bias = False) + self.to_values = nn.Linear(dim, dim_inner, bias = False) + + # in the paper, they employ qk rmsnorm, a way to stabilize attention + # will use layernorm in place of rmsnorm, which has been shown to work in certain papers. requires l2norm on non-ragged dimension to be supported in nested tensors + + self.query_norm = nn.LayerNorm(dim_head, bias = False) + self.key_norm = nn.LayerNorm(dim_head, bias = False) + + self.dropout = dropout + + self.to_out = nn.Linear(dim_inner, dim, bias = False) + + def forward( + self, + x, + context: Tensor | None = None + ): + + x = self.norm(x) + + # for attention pooling, one query pooling to entire sequence + + context = default(context, x) + + # queries, keys, values + + query = self.to_queries(x) + key = self.to_keys(context) + value = self.to_values(context) + + # split heads + + def split_heads(t): + return t.unflatten(-1, (self.heads, self.dim_head)).transpose(1, 2).contiguous() + + query, key, value = map(split_heads, (query, key, value)) + + # qk norm for attention stability + + query = self.query_norm(query) + key = self.key_norm(key) + + # attention + + out = F.scaled_dot_product_attention( + query, key, value, + dropout_p = self.dropout if self.training else 0. + ) + + # merge heads + + out = out.transpose(1, 2).flatten(-2) + + return self.to_out(out) + +class Transformer(Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = ModuleList([]) + + for _ in range(depth): + self.layers.append(ModuleList([ + Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout), + FeedForward(dim, mlp_dim, dropout = dropout) + ])) + + self.norm = nn.LayerNorm(dim, bias = False) + + def forward(self, x): + + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + + return self.norm(x) + +class NaViT(Module): + def __init__( + self, + *, + image_size, + max_frames, + patch_size, + frame_patch_size, + num_classes, + dim, + depth, + heads, + mlp_dim, + channels = 3, + dim_head = 64, + dropout = 0., + emb_dropout = 0., + token_dropout_prob: float | None = None + ): + super().__init__() + image_height, image_width = pair(image_size) + + # what percent of tokens to dropout + # if int or float given, then assume constant dropout prob + # otherwise accept a callback that in turn calculates dropout prob from height and width + + self.token_dropout_prob = token_dropout_prob + + # calculate patching related stuff + + assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.' + assert divisible_by(max_frames, frame_patch_size) + + patch_frame_dim, patch_height_dim, patch_width_dim = (max_frames // frame_patch_size), (image_height // patch_size), (image_width // patch_size) + + patch_dim = channels * (patch_size ** 2) * frame_patch_size + + self.channels = channels + self.patch_size = patch_size + self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size) + + self.to_patch_embedding = nn.Sequential( + nn.LayerNorm(patch_dim), + nn.Linear(patch_dim, dim), + nn.LayerNorm(dim), + ) + + self.pos_embed_frame = nn.Parameter(torch.randn(patch_frame_dim, dim)) + self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim)) + self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim)) + + self.dropout = nn.Dropout(emb_dropout) + + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) + + # final attention pooling queries + + self.attn_pool_queries = nn.Parameter(torch.randn(dim)) + self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads) + + # output to logits + + self.to_latent = nn.Identity() + + self.mlp_head = nn.Sequential( + nn.LayerNorm(dim, bias = False), + nn.Linear(dim, num_classes, bias = False) + ) + + @property + def device(self): + return next(self.parameters()).device + + def forward( + self, + volumes: List[Tensor], # different resolution images / CT scans + ): + batch, device = len(volumes), self.device + arange = partial(torch.arange, device = device) + + assert all([volume.ndim == 4 and volume.shape[0] == self.channels for volume in volumes]), f'all volumes must have {self.channels} channels and number of dimensions of {self.channels} (channels, frame, height, width)' + + all_patches = [self.to_patches(volume) for volume in volumes] + + # prepare factorized positional embedding height width indices + + positions = [] + + for patches in all_patches: + patch_frame, patch_height, patch_width = patches.shape[:3] + fhw_indices = torch.stack(torch.meshgrid((arange(patch_frame), arange(patch_height), arange(patch_width)), indexing = 'ij'), dim = -1) + fhw_indices = rearrange(fhw_indices, 'f h w c -> (f h w) c') + + positions.append(fhw_indices) + + # need the sizes to compute token dropout + positional embedding + + tokens = [rearrange(patches, 'f h w d -> (f h w) d') for patches in all_patches] + + # handle token dropout + + seq_lens = torch.tensor([i.shape[0] for i in tokens], device = device) + + if self.training and self.token_dropout_prob > 0: + + keep_seq_lens = ((1. - self.token_dropout_prob) * seq_lens).int().clamp(min = 1) + + kept_tokens = [] + kept_positions = [] + + for one_image_tokens, one_image_positions, seq_len, num_keep in zip(tokens, positions, seq_lens, keep_seq_lens): + keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices + + one_image_kept_tokens = one_image_tokens[keep_indices] + one_image_kept_positions = one_image_positions[keep_indices] + + kept_tokens.append(one_image_kept_tokens) + kept_positions.append(one_image_kept_positions) + + tokens, positions, seq_lens = kept_tokens, kept_positions, keep_seq_lens + + # add all height and width factorized positions + + + frame_indices, height_indices, width_indices = torch.cat(positions).unbind(dim = -1) + frame_embed, height_embed, width_embed = self.pos_embed_frame[frame_indices], self.pos_embed_height[height_indices], self.pos_embed_width[width_indices] + + pos_embed = frame_embed + height_embed + width_embed + + # use nested tensor for transformers and save on padding computation + + tokens = torch.cat(tokens) + + # linear projection to patch embeddings + + tokens = self.to_patch_embedding(tokens) + + # absolute positions + + tokens = tokens + pos_embed + + tokens = nested_tensor(tokens.split(seq_len.tolist()), layout = torch.jagged, device = device) + + # embedding dropout + + tokens = self.dropout(tokens) + + # transformer + + tokens = self.transformer(tokens) + + # attention pooling + # will use a jagged tensor for queries, as SDPA requires all inputs to be jagged, or not + + attn_pool_queries = [rearrange(self.attn_pool_queries, '... -> 1 ...')] * batch + + attn_pool_queries = nested_tensor(attn_pool_queries, layout = torch.jagged) + + pooled = self.attn_pool(attn_pool_queries, tokens) + + # back to unjagged + + logits = torch.stack(pooled.unbind()) + + logits = rearrange(logits, 'b 1 d -> b d') + + logits = self.to_latent(logits) + + return self.mlp_head(logits) + +# quick test + +if __name__ == '__main__': + + # works for torch 2.4 + + v = NaViT( + image_size = 256, + max_frames = 8, + patch_size = 32, + frame_patch_size = 2, + num_classes = 1000, + dim = 1024, + depth = 6, + heads = 16, + mlp_dim = 2048, + dropout = 0., + emb_dropout = 0., + token_dropout_prob = 0.1 + ) + + # 5 volumetric data (videos or CT scans) of different resolutions - List[Tensor] + + volumes = [ + torch.randn(3, 2, 256, 256), torch.randn(3, 8, 128, 128), + torch.randn(3, 4, 128, 256), torch.randn(3, 2, 256, 128), + torch.randn(3, 4, 64, 256) + ] + + assert v(volumes).shape == (5, 1000)