3d version of navit nested tensor
lucidrains committed Aug 21, 2024
1 parent bed48b5 commit 19eb6d4
Showing 2 changed files with 333 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
- version = '1.7.7',
+ version = '1.7.9',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description=long_description,
332 changes: 332 additions & 0 deletions vit_pytorch/na_vit_nested_tensor_3d.py
@@ -0,0 +1,332 @@
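# NaViT (patch n' pack) adapted to 3d volumes (videos / CT scans)
# variable resolution inputs are packed into torch nested tensors (jagged layout), skipping padding entirely
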
from __future__ import annotations

from typing import List
from functools import partial

import torch
import packaging.version as pkg_version

if pkg_version.parse(torch.__version__) < pkg_version.parse('2.4'):
print('nested tensor NaViT was tested on pytorch 2.4')

from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList
from torch.nested import nested_tensor

from einops import rearrange
from einops.layers.torch import Rearrange

# helpers

def exists(val):
return val is not None

def default(val, d):
return val if exists(val) else d

def pair(t):
return t if isinstance(t, tuple) else (t, t)

def divisible_by(numer, denom):
return (numer % denom) == 0

# feedforward

def FeedForward(dim, hidden_dim, dropout = 0.):
return nn.Sequential(
nn.LayerNorm(dim, bias = False),
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)

class Attention(Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
self.norm = nn.LayerNorm(dim, bias = False)

dim_inner = heads * dim_head
self.heads = heads
self.dim_head = dim_head

self.to_queries = nn.Linear(dim, dim_inner, bias = False)
self.to_keys = nn.Linear(dim, dim_inner, bias = False)
self.to_values = nn.Linear(dim, dim_inner, bias = False)

        # in the paper, they employ qk rmsnorm to stabilize attention
        # layernorm is used in its place here, which has been shown to work in certain papers
        # (rmsnorm would require l2norm over the non-ragged dimension to be supported for nested tensors)

self.query_norm = nn.LayerNorm(dim_head, bias = False)
self.key_norm = nn.LayerNorm(dim_head, bias = False)

self.dropout = dropout

self.to_out = nn.Linear(dim_inner, dim, bias = False)

def forward(
self,
x,
context: Tensor | None = None
):

x = self.norm(x)

        # for attention pooling, a single query attends to the entire sequence

context = default(context, x)

# queries, keys, values

query = self.to_queries(x)
key = self.to_keys(context)
value = self.to_values(context)

# split heads

def split_heads(t):
return t.unflatten(-1, (self.heads, self.dim_head)).transpose(1, 2).contiguous()

query, key, value = map(split_heads, (query, key, value))

# qk norm for attention stability

query = self.query_norm(query)
key = self.key_norm(key)

# attention

out = F.scaled_dot_product_attention(
query, key, value,
dropout_p = self.dropout if self.training else 0.
)

# merge heads

out = out.transpose(1, 2).flatten(-2)

return self.to_out(out)

class Transformer(Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
super().__init__()
self.layers = ModuleList([])

for _ in range(depth):
self.layers.append(ModuleList([
Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
FeedForward(dim, mlp_dim, dropout = dropout)
]))

self.norm = nn.LayerNorm(dim, bias = False)

def forward(self, x):

for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x

return self.norm(x)

class NaViT(Module):
def __init__(
self,
*,
image_size,
max_frames,
patch_size,
frame_patch_size,
num_classes,
dim,
depth,
heads,
mlp_dim,
channels = 3,
dim_head = 64,
dropout = 0.,
emb_dropout = 0.,
token_dropout_prob: float | None = None
):
super().__init__()
image_height, image_width = pair(image_size)

        # what fraction of tokens to drop out during training
        # only a constant dropout probability (a float) is supported in this nested tensor version
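        # e.g. with token_dropout_prob = 0.25, a volume that patches to 512 tokens keeps 384 randomly chosen tokens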

self.token_dropout_prob = token_dropout_prob

# calculate patching related stuff

assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'
assert divisible_by(max_frames, frame_patch_size)

patch_frame_dim, patch_height_dim, patch_width_dim = (max_frames // frame_patch_size), (image_height // patch_size), (image_width // patch_size)

patch_dim = channels * (patch_size ** 2) * frame_patch_size

self.channels = channels
self.patch_size = patch_size
self.to_patches = Rearrange('c (f pf) (h p1) (w p2) -> f h w (c p1 p2 pf)', p1 = patch_size, p2 = patch_size, pf = frame_patch_size)

self.to_patch_embedding = nn.Sequential(
nn.LayerNorm(patch_dim),
nn.Linear(patch_dim, dim),
nn.LayerNorm(dim),
)

self.pos_embed_frame = nn.Parameter(torch.randn(patch_frame_dim, dim))
self.pos_embed_height = nn.Parameter(torch.randn(patch_height_dim, dim))
self.pos_embed_width = nn.Parameter(torch.randn(patch_width_dim, dim))

self.dropout = nn.Dropout(emb_dropout)

self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

# final attention pooling queries

self.attn_pool_queries = nn.Parameter(torch.randn(dim))
self.attn_pool = Attention(dim = dim, dim_head = dim_head, heads = heads)

# output to logits

self.to_latent = nn.Identity()

self.mlp_head = nn.Sequential(
nn.LayerNorm(dim, bias = False),
nn.Linear(dim, num_classes, bias = False)
)

@property
def device(self):
return next(self.parameters()).device

def forward(
self,
        volumes: List[Tensor], # volumes (videos / CT scans) of differing resolutions
):
batch, device = len(volumes), self.device
arange = partial(torch.arange, device = device)

        assert all([volume.ndim == 4 and volume.shape[0] == self.channels for volume in volumes]), f'all volumes must have {self.channels} channels and 4 dimensions (channels, frames, height, width)'

all_patches = [self.to_patches(volume) for volume in volumes]

        # prepare factorized positional embedding frame / height / width indices

positions = []

for patches in all_patches:
patch_frame, patch_height, patch_width = patches.shape[:3]
fhw_indices = torch.stack(torch.meshgrid((arange(patch_frame), arange(patch_height), arange(patch_width)), indexing = 'ij'), dim = -1)
fhw_indices = rearrange(fhw_indices, 'f h w c -> (f h w) c')

positions.append(fhw_indices)

# need the sizes to compute token dropout + positional embedding

tokens = [rearrange(patches, 'f h w d -> (f h w) d') for patches in all_patches]

# handle token dropout

seq_lens = torch.tensor([i.shape[0] for i in tokens], device = device)

        if self.training and exists(self.token_dropout_prob) and self.token_dropout_prob > 0:

keep_seq_lens = ((1. - self.token_dropout_prob) * seq_lens).int().clamp(min = 1)

kept_tokens = []
kept_positions = []

for one_image_tokens, one_image_positions, seq_len, num_keep in zip(tokens, positions, seq_lens, keep_seq_lens):
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices

one_image_kept_tokens = one_image_tokens[keep_indices]
one_image_kept_positions = one_image_positions[keep_indices]

kept_tokens.append(one_image_kept_tokens)
kept_positions.append(one_image_kept_positions)

tokens, positions, seq_lens = kept_tokens, kept_positions, keep_seq_lens

        # add the factorized frame, height and width positions

frame_indices, height_indices, width_indices = torch.cat(positions).unbind(dim = -1)
frame_embed, height_embed, width_embed = self.pos_embed_frame[frame_indices], self.pos_embed_height[height_indices], self.pos_embed_width[width_indices]

pos_embed = frame_embed + height_embed + width_embed

# use nested tensor for transformers and save on padding computation

tokens = torch.cat(tokens)

# linear projection to patch embeddings

tokens = self.to_patch_embedding(tokens)

# absolute positions

tokens = tokens + pos_embed

        tokens = nested_tensor(tokens.split(seq_lens.tolist()), layout = torch.jagged, device = device)

# embedding dropout

tokens = self.dropout(tokens)

# transformer

tokens = self.transformer(tokens)

# attention pooling
        # use a jagged tensor for the queries as well, since SDPA requires q, k, v to be either all jagged or all non-jagged

attn_pool_queries = [rearrange(self.attn_pool_queries, '... -> 1 ...')] * batch

attn_pool_queries = nested_tensor(attn_pool_queries, layout = torch.jagged)

pooled = self.attn_pool(attn_pool_queries, tokens)

# back to unjagged

logits = torch.stack(pooled.unbind())

logits = rearrange(logits, 'b 1 d -> b d')

logits = self.to_latent(logits)

return self.mlp_head(logits)

# quick test

if __name__ == '__main__':

# works for torch 2.4

v = NaViT(
image_size = 256,
max_frames = 8,
patch_size = 32,
frame_patch_size = 2,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.,
emb_dropout = 0.,
token_dropout_prob = 0.1
)

    # 5 volumetric inputs (videos or CT scans) of different resolutions - List[Tensor]

volumes = [
torch.randn(3, 2, 256, 256), torch.randn(3, 8, 128, 128),
torch.randn(3, 4, 128, 256), torch.randn(3, 2, 256, 128),
torch.randn(3, 4, 64, 256)
]

assert v(volumes).shape == (5, 1000)
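
    # a minimal sketch (not part of the committed file) of the jagged nested tensor mechanics used above:
    # variable length sequences are packed with no padding, dense modules act on the trailing (non-ragged)
    # dimension, and scaled_dot_product_attention handles the ragged sequence dimension per sample
    # the lengths, 4 heads, and head dimension of 16 below are arbitrary, for illustration only

    lengths = (3, 7, 5)
    tokens = nested_tensor([torch.randn(length, 64) for length in lengths], layout = torch.jagged)

    tokens = nn.LayerNorm(64, bias = False)(tokens)

    qkv = tokens.unflatten(-1, (4, 16)).transpose(1, 2).contiguous() # (batch, 4 heads, ragged seq, 16)

    attended = F.scaled_dot_product_attention(qkv, qkv, qkv)

    assert [t.shape[1] for t in attended.unbind()] == list(lengths)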
