Commit 649c7c8
Allow an arbitrary mask to be used in the self attention (#8235)
### Description
This PR enables the use of an arbitrary mask in the self-attention
module, which is useful for handling missing data or for masked
modeling.
The official torch implementations accept an arbitrary mask, and MONAI
already supports one specific mask through the `causal` argument. This
PR generalizes that by accepting a mask directly in the forward pass.
In the `SABlock` and `TransformerBlock`, it is now possible to input a
boolean mask of size `(BS, Seq_length)`.
Only the columns of masked tokens are set to `-inf`; the rows are left
untouched, since masking rows is rare in common implementations. Masked
tokens don't contribute to the gradient anyway.
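
The column-only masking described above can be sketched in plain Python. This is a dependency-free illustration, not the code from this PR: the function name `masked_attention_weights` is hypothetical, nested lists stand in for torch tensors, and the mask polarity (`True` keeps a token, `False` masks it) is an assumption.

```python
import math


def masked_attention_weights(scores, keep_mask):
    """Row-wise softmax of `scores` after masking key columns.

    scores:    nested list of shape (batch, seq, seq) with raw attention scores
    keep_mask: nested list of shape (batch, seq); True keeps a token and
               False masks it (this polarity is an assumption)

    Assumes every batch element keeps at least one token; otherwise a row
    would be all -inf and the softmax would be undefined.
    """
    weights = []
    for batch_scores, batch_mask in zip(scores, keep_mask):
        batch_weights = []
        for row in batch_scores:
            # Mask columns only: the same (seq,) mask is applied to every
            # query row, and no row is skipped.
            masked = [s if keep else -math.inf for s, keep in zip(row, batch_mask)]
            row_max = max(masked)  # subtract the max for numerical stability
            exps = [math.exp(s - row_max) for s in masked]
            total = sum(exps)
            batch_weights.append([e / total for e in exps])
        weights.append(batch_weights)
    return weights


# One batch element, three tokens, the last token masked.
scores = [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]
keep_mask = [[True, True, False]]
weights = masked_attention_weights(scores, keep_mask)
```

In this sketch no query attends to the masked token (its column gets zero weight), while the masked token's own row is still computed like any other.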
When causal attention is required, passing a mask is not supported, to
avoid overlapping masks.
I haven't implemented an additive mask on the attention matrix, which
would allow values other than `-inf` in certain cases, as described
here:
https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
If you think it's relevant, it could be added.
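
To illustrate the difference, here is a small plain-Python sketch of a boolean mask versus an additive one. It is not part of this PR: the helper names are hypothetical, and the `True` = keep polarity is an assumption. A boolean mask is the special case of an additive mask whose entries are `0` or `-inf`; finite entries down-weight a token instead of removing it.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    row_max = max(row)
    exps = [math.exp(s - row_max) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def boolean_to_additive(keep_mask):
    """Convert a boolean keep-mask into an additive bias (0 or -inf).

    Assumes True means "keep this token".
    """
    return [0.0 if keep else -math.inf for keep in keep_mask]


def apply_additive(row, bias):
    """Add a per-column bias to one row of attention scores."""
    return [s + b for s, b in zip(row, bias)]


row = [0.0, 0.0, 0.0]

# Hard masking: the third token is removed entirely.
hard = softmax(apply_additive(row, boolean_to_additive([True, True, False])))

# Soft masking: the third token is down-weighted but still attended to.
soft = softmax(apply_additive(row, [0.0, 0.0, -math.log(2.0)]))
```

Here `hard` gives the third token zero weight, while `soft` leaves it with roughly half the weight of each of the other two.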
### Types of changes
- [ ] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: Lucas Robinet <[email protected]>
Signed-off-by: Lucas Robinet <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>

1 parent 3ee4cd2 commit 649c7c8
3 files changed (+40, -6 lines) under `monai/networks/blocks` and `tests`.