Commit d40ec95

Merge branch 'dev' into pythonicworkflow
2 parents 70dc9b5 + 649c7c8 commit d40ec95

File tree: 4 files changed, +64 -18 lines changed


monai/networks/blocks/selfattention.py

Lines changed: 18 additions & 4 deletions
@@ -11,7 +11,7 @@

 from __future__ import annotations

-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -154,10 +154,12 @@ def __init__(
         )
         self.input_size = input_size

-    def forward(self, x):
+    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         """
         Args:
             x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
+            attn_mask (torch.Tensor, optional): mask to apply to the attention matrix.
+                B x (s_dim_1 * ... * s_dim_n). Defaults to None.

         Return:
             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
@@ -176,7 +178,13 @@ def forward(self, x):

         if self.use_flash_attention:
             x = F.scaled_dot_product_attention(
-                query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+                query=q,
+                key=k,
+                value=v,
+                attn_mask=attn_mask,
+                scale=self.scale,
+                dropout_p=self.dropout_rate,
+                is_causal=self.causal,
             )
         else:
             att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
@@ -186,10 +194,16 @@ def forward(self, x):
                 att_mat = self.rel_positional_embedding(x, att_mat, q)

             if self.causal:
+                if attn_mask is not None:
+                    raise ValueError("Causal attention does not support attention masks.")
                 att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))

-            att_mat = att_mat.softmax(dim=-1)
+            if attn_mask is not None:
+                attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)
+                attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1)
+                att_mat = att_mat.masked_fill(attn_mask == 0, float("-inf"))

+            att_mat = att_mat.softmax(dim=-1)
             if self.save_attn:
                 # no gradients and new tensor;
                 # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
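
Usage note: a minimal sketch of the new attn_mask argument on SABlock, based on the diff above and the accompanying test; the import path and tensor sizes below are illustrative assumptions, not taken from this commit. The mask has shape B x L, positions set to 0/False receive no attention weight, and passing a mask together with causal=True raises a ValueError.

import torch

from monai.networks.blocks.selfattention import SABlock

# hidden_size must be divisible by num_heads
block = SABlock(hidden_size=128, num_heads=4, dropout_rate=0.0, save_attn=True)

x = torch.randn(2, 64, 128)                      # B x L x C
attn_mask = torch.randint(0, 2, (2, 64)).bool()  # B x L, False/0 = masked key

out = block(x, attn_mask=attn_mask)              # same shape as x: 2 x 64 x 128

# with save_attn=True the post-softmax weights are kept in block.att_mat
# (B x num_heads x L x L); columns belonging to masked positions are zero
print(out.shape, block.att_mat.shape)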

monai/networks/blocks/transformerblock.py

Lines changed: 4 additions & 2 deletions
@@ -90,8 +90,10 @@ def __init__(
             use_flash_attention=use_flash_attention,
         )

-    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
-        x = x + self.attn(self.norm1(x))
+    def forward(
+        self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        x = x + self.attn(self.norm1(x), attn_mask=attn_mask)
         if self.with_cross_attention:
             x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
         x = x + self.mlp(self.norm2(x))
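
At the block level the new keyword is simply forwarded to self-attention; the cross-attention and MLP paths are unchanged. A short sketch, with illustrative constructor values not taken from this diff:

import torch

from monai.networks.blocks.transformerblock import TransformerBlock

blk = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, dropout_rate=0.0)

x = torch.randn(2, 64, 128)
attn_mask = torch.ones(2, 64, dtype=torch.bool)  # True = attend to this position
attn_mask[:, 32:] = False                        # hide the second half of the sequence

y = blk(x, attn_mask=attn_mask)                  # 2 x 64 x 128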

monai/networks/nets/swin_unetr.py

Lines changed: 24 additions & 12 deletions
@@ -13,7 +13,6 @@

 import itertools
 from collections.abc import Sequence
-from typing import Final

 import numpy as np
 import torch
@@ -51,8 +50,6 @@ class SwinUNETR(nn.Module):
     <https://arxiv.org/abs/2201.01266>"
     """

-    patch_size: Final[int] = 2
-
     @deprecated_arg(
         name="img_size",
         since="1.3",
@@ -65,18 +62,24 @@ def __init__(
         img_size: Sequence[int] | int,
         in_channels: int,
         out_channels: int,
+        patch_size: int = 2,
         depths: Sequence[int] = (2, 2, 2, 2),
         num_heads: Sequence[int] = (3, 6, 12, 24),
+        window_size: Sequence[int] | int = 7,
+        qkv_bias: bool = True,
+        mlp_ratio: float = 4.0,
         feature_size: int = 24,
         norm_name: tuple | str = "instance",
         drop_rate: float = 0.0,
         attn_drop_rate: float = 0.0,
         dropout_path_rate: float = 0.0,
         normalize: bool = True,
+        norm_layer: type[LayerNorm] = nn.LayerNorm,
+        patch_norm: bool = True,
         use_checkpoint: bool = False,
         spatial_dims: int = 3,
-        downsample="merging",
-        use_v2=False,
+        downsample: str | nn.Module = "merging",
+        use_v2: bool = False,
     ) -> None:
         """
         Args:
@@ -86,14 +89,20 @@ def __init__(
                 It will be removed in an upcoming version.
             in_channels: dimension of input channels.
             out_channels: dimension of output channels.
+            patch_size: size of the patch token.
             feature_size: dimension of network feature size.
             depths: number of layers in each stage.
             num_heads: number of attention heads.
+            window_size: local window size.
+            qkv_bias: add a learnable bias to query, key, value.
+            mlp_ratio: ratio of mlp hidden dim to embedding dim.
             norm_name: feature normalization type and arguments.
             drop_rate: dropout rate.
             attn_drop_rate: attention dropout rate.
             dropout_path_rate: drop path rate.
             normalize: normalize output intermediate features in each stage.
+            norm_layer: normalization layer.
+            patch_norm: whether to apply normalization to the patch embedding.
             use_checkpoint: use gradient checkpointing for reduced memory usage.
             spatial_dims: number of spatial dims.
             downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a
@@ -116,13 +125,15 @@ def __init__(

         super().__init__()

-        img_size = ensure_tuple_rep(img_size, spatial_dims)
-        patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
-        window_size = ensure_tuple_rep(7, spatial_dims)
-
         if spatial_dims not in (2, 3):
             raise ValueError("spatial dimension should be 2 or 3.")

+        self.patch_size = patch_size
+
+        img_size = ensure_tuple_rep(img_size, spatial_dims)
+        patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
+        window_size = ensure_tuple_rep(window_size, spatial_dims)
+
         self._check_input_size(img_size)

         if not (0 <= drop_rate <= 1):
@@ -146,12 +157,13 @@ def __init__(
             patch_size=patch_sizes,
             depths=depths,
             num_heads=num_heads,
-            mlp_ratio=4.0,
-            qkv_bias=True,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias,
             drop_rate=drop_rate,
             attn_drop_rate=attn_drop_rate,
             drop_path_rate=dropout_path_rate,
-            norm_layer=nn.LayerNorm,
+            norm_layer=norm_layer,
+            patch_norm=patch_norm,
             use_checkpoint=use_checkpoint,
             spatial_dims=spatial_dims,
             downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample,
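
For reference, a construction sketch showing the newly exposed SwinUNETR arguments. The values are the defaults from the new signature (patch_size=2, window_size=7, qkv_bias=True, mlp_ratio=4.0, norm_layer=nn.LayerNorm, patch_norm=True); the spatial size is only illustrative, and img_size is still accepted but deprecated per the decorator above.

import torch
import torch.nn as nn

from monai.networks.nets import SwinUNETR

model = SwinUNETR(
    img_size=(64, 64, 64),   # deprecated since 1.3, kept here to match the current signature
    in_channels=1,
    out_channels=2,
    patch_size=2,
    window_size=7,
    qkv_bias=True,
    mlp_ratio=4.0,
    norm_layer=nn.LayerNorm,
    patch_norm=True,
    feature_size=24,
    spatial_dims=3,
)

y = model(torch.randn(1, 1, 64, 64, 64))  # segmentation output: 1 x 2 x 64 x 64 x 64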

tests/test_selfattention.py

Lines changed: 18 additions & 0 deletions
@@ -122,6 +122,24 @@ def test_causal(self):
         # check upper triangular part of the attention matrix is zero
         assert torch.triu(block.att_mat, diagonal=1).sum() == 0

+    def test_masked_selfattention(self):
+        n = 64
+        block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, sequence_length=16, save_attn=True)
+        input_shape = (1, n, 128)
+        # generate a mask randomly with zeros and ones of shape (1, n)
+        mask = torch.randint(0, 2, (1, n)).bool()
+        block(torch.randn(input_shape), attn_mask=mask)
+        att_mat = block.att_mat.squeeze()
+        # ensure all masked columns are zeros
+        assert torch.allclose(att_mat[:, ~mask.squeeze(0)], torch.zeros_like(att_mat[:, ~mask.squeeze(0)]))
+
+    def test_causal_and_mask(self):
+        with self.assertRaises(ValueError):
+            block = SABlock(hidden_size=128, num_heads=1, causal=True, sequence_length=64)
+            inputs = torch.randn(2, 64, 128)
+            mask = torch.randint(0, 2, (2, 64)).bool()
+            block(inputs, attn_mask=mask)
+
     @skipUnless(has_einops, "Requires einops")
     def test_access_attn_matrix(self):
         # input format
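
The assertion in test_masked_selfattention relies on a basic property of softmax: entries set to -inf before normalisation get exactly zero weight. A tiny standalone illustration of that step (not part of the test suite):

import torch

scores = torch.tensor([[0.5, 1.0, -0.5]])
mask = torch.tensor([[True, False, True]])  # False = masked key
weights = scores.masked_fill(~mask, float("-inf")).softmax(dim=-1)
print(weights)  # the middle weight is exactly 0.0 and the row still sums to 1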
