
Commit 42db7c0

Refactoring to be easier to follow
Signed-off-by: Lucas Robinet <[email protected]>
1 parent c869d69 commit 42db7c0

2 files changed: +18 additions, -24 deletions


monai/networks/blocks/selfattention.py

Lines changed: 5 additions & 14 deletions
@@ -159,7 +159,7 @@ def forward(self, x, attn_mask: torch.Tensor | None = None):
         Args:
             x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
             attn_mask (torch.Tensor, optional): mask to apply to the attention matrix.
-                Defaults to None. B x N_heads x (s_dim_1 * ... * s_dim_n) x (s_dim_1 * ... * s_dim_n).
+                B x (s_dim_1 * ... * s_dim_n). Defaults to None.
 
         Return:
             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
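
The docstring change above narrows the mask contract: instead of a full per-head attention mask of shape B x N_heads x N x N, callers now pass a key-padding mask of shape B x N, one boolean per token. A minimal shape sketch in plain PyTorch with hypothetical sizes (not MONAI code; the internal expansion is the forward() hunk below):

import torch

B, N_heads, N = 2, 4, 16
# previous contract: full attention mask, one entry per (head, query, key)
old_style_mask = torch.ones(B, N_heads, N, N, dtype=torch.bool)
# new contract: key-padding mask, one flag per token
new_style_mask = torch.ones(B, N, dtype=torch.bool)
new_style_mask[:, -3:] = False  # e.g. the last three tokens are padding
# the block expands it internally before masking (see the next hunk)
expanded = new_style_mask.unsqueeze(1).unsqueeze(2).expand(-1, N_heads, -1, -1)
print(expanded.shape)  # torch.Size([2, 4, 1, 16]), broadcast over queries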
@@ -194,14 +194,15 @@ def forward(self, x, attn_mask: torch.Tensor | None = None):
             att_mat = self.rel_positional_embedding(x, att_mat, q)
 
         if self.causal:
+            assert attn_mask is None, "Causal attention does not support attention masks."
             att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))
 
         if attn_mask is not None:
-            attn_mask = attn_mask[:, None, :, None] * attn_mask[:, None, None, :]
-            att_mat.masked_fill_(~attn_mask, torch.finfo(att_mat.dtype).min)
+            attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)
+            attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1)
+            att_mat = att_mat.masked_fill(attn_mask == 0, float("-inf"))
 
         att_mat = att_mat.softmax(dim=-1)
-
         if self.save_attn:
             # no gradients and new tensor;
             # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
@@ -215,13 +216,3 @@ def forward(self, x, attn_mask: torch.Tensor | None = None):
         x = self.out_proj(x)
         x = self.drop_output(x)
         return x
-
-
-if __name__ == "__main__":
-    sa = SABlock(128, 1)
-    x = torch.randn(1, 6, 128)
-    mask = torch.ones((1, 6), dtype=torch.bool)
-    mask[0][2] = False
-    print(mask)
-    out = sa(x, attn_mask=mask)
-    print(out.shape)
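
The deleted __main__ demo above still maps onto the new API. Here is the equivalent as a hedged usage sketch, assuming only what this diff shows (SABlock imported from monai/networks/blocks/selfattention.py, save_attn=True to expose block.att_mat, as in the tests below):

import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=128, num_heads=1, save_attn=True)
x = torch.randn(1, 6, 128)
mask = torch.ones(1, 6, dtype=torch.bool)
mask[0, 2] = False                     # mask out the third token
out = block(x, attn_mask=mask)         # shape (1, 6, 128), i.e. B x N x C
# masked key columns were filled with -inf before the softmax,
# so the saved attention matrix is exactly zero there
att_mat = block.att_mat                # B x num_heads x N x N
masked_cols = att_mat[..., ~mask[0]]
assert torch.allclose(masked_cols, torch.zeros_like(masked_cols))
print(out.shape)                       # torch.Size([1, 6, 128])

Note the behavioral change versus the old two-sided mask: a masked query row is no longer forced to a uniform distribution; it simply attends over the unmasked keys, while masked key columns go to exact zero.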

tests/test_selfattention.py

Lines changed: 13 additions & 10 deletions
@@ -123,19 +123,22 @@ def test_causal(self):
         assert torch.triu(block.att_mat, diagonal=1).sum() == 0
 
     def test_masked_selfattention(self):
-        n = 4
+        n = 64
         block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, sequence_length=16, save_attn=True)
         input_shape = (1, n, 128)
-        mask = torch.tensor([[1, 1, 1, 0]]).bool()
+        # generate a mask randomly with zeros and ones of shape (1, n)
+        mask = torch.randint(0, 2, (1, n)).bool()
         block(torch.randn(input_shape), attn_mask=mask)
-        att_mat = block.att_mat.squeeze(1)
-        # get the masked row and the remaining ones based on mask 0 values
-        rows_true = att_mat[mask, :]
-        rows_false = att_mat[~mask, :]
-        # check that in false rows every element is equal to 1/4
-        assert torch.allclose(rows_false, torch.ones_like(rows_false) / n)
-        # check that in true rows the mask column is zero
-        assert torch.allclose(rows_true[:, -1], torch.zeros_like(rows_true[:, -1]))
+        att_mat = block.att_mat.squeeze()
+        # ensure all masked columns are zeros
+        assert torch.allclose(att_mat[:, ~mask.squeeze(0)], torch.zeros_like(att_mat[:, ~mask.squeeze(0)]))
+
+    def test_causal_and_mask(self):
+        with self.assertRaises(AssertionError):
+            block = SABlock(hidden_size=128, num_heads=1, causal=True, sequence_length=64)
+            inputs = torch.randn(2, 64, 128)
+            mask = torch.randint(0, 2, (2, 64)).bool()
+            block(inputs, attn_mask=mask)
 
     @skipUnless(has_einops, "Requires einops")
     def test_access_attn_matrix(self):
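
The new test_causal_and_mask pins down the assert added to forward(): causal=True and an explicit attn_mask are mutually exclusive. The same check, sketched interactively with the values from the test:

import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(hidden_size=128, num_heads=1, causal=True, sequence_length=64)
inputs = torch.randn(2, 64, 128)
mask = torch.randint(0, 2, (2, 64)).bool()
try:
    block(inputs, attn_mask=mask)
except AssertionError as err:
    print(err)  # Causal attention does not support attention masks.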
