
Commit 37917e0

Make ViT and UNETR TorchScript compatible (#7937)
Fixes #7936

### Description

- Pre-define `self.causal_mask = torch.Tensor()` before registering the buffer
- Move `norm_cross_attn` and `cross_attn` out of the `if` block

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: YunLiu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 7a8680e commit 37917e0
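
For orientation, a minimal sketch (not part of the commit) of what the change enables: scripting both networks with `torch.jit.script`. The constructor arguments below are illustrative defaults, not values taken from the MONAI tests, and the calls are only expected to succeed with this commit applied.

```python
# Hedged sketch: checking TorchScript compatibility of ViT and UNETR.
# Argument values are illustrative; only the class names come from MONAI.
import torch
from monai.networks.nets import UNETR, ViT

vit = ViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16))
unetr = UNETR(in_channels=1, out_channels=2, img_size=(96, 96, 96))

# Expected to succeed once this fix is in place.
scripted_vit = torch.jit.script(vit)
scripted_unetr = torch.jit.script(unetr)

x = torch.randn(1, 1, 96, 96, 96)
with torch.no_grad():
    vit_out, hidden_states = scripted_vit(x)  # ViT returns (tokens, hidden states)
    seg = scripted_unetr(x)
print(vit_out.shape, seg.shape)
```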

4 files changed (+17, -12 lines)

monai/networks/blocks/crossattention.py

Lines changed: 3 additions & 1 deletion
@@ -109,6 +109,8 @@ def __init__(
                 torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
             )
             self.causal_mask: torch.Tensor
+        else:
+            self.causal_mask = torch.Tensor()
 
         self.att_mat = torch.Tensor()
         self.rel_positional_embedding = (
@@ -118,7 +120,7 @@ def __init__(
         )
         self.input_size = input_size
 
-    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None):
+    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
         """
         Args:
             x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
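
The added `else` branch matters because `torch.jit.script` compiles `forward` for both the causal and non-causal configurations and requires every attribute it touches to exist with a known type. A toy module illustrating the same pattern follows; the names are invented for illustration and this is not MONAI code.

```python
# Toy illustration (hypothetical module): TorchScript needs `causal_mask` to exist
# on every construction path, so the non-causal branch assigns a placeholder tensor.
import torch
import torch.nn as nn


class CausalToy(nn.Module):
    def __init__(self, causal: bool = False, sequence_length: int = 4):
        super().__init__()
        self.causal = causal
        if causal:
            self.register_buffer(
                "causal_mask",
                torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
            )
            self.causal_mask: torch.Tensor
        else:
            self.causal_mask = torch.Tensor()  # placeholder so the attribute always exists

    def forward(self, att: torch.Tensor) -> torch.Tensor:
        if self.causal:
            att = att.masked_fill(self.causal_mask[:, :, : att.size(-2), : att.size(-1)] == 0, float("-inf"))
        return att


# Both configurations should script once the else branch defines the attribute.
torch.jit.script(CausalToy(causal=False))
torch.jit.script(CausalToy(causal=True))
```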

monai/networks/blocks/selfattention.py

Lines changed: 2 additions & 0 deletions
@@ -105,6 +105,8 @@ def __init__(
                 torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
             )
             self.causal_mask: torch.Tensor
+        else:
+            self.causal_mask = torch.Tensor()
 
         self.rel_positional_embedding = (
             get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.dim_head, self.num_heads)

monai/networks/blocks/transformerblock.py

Lines changed: 7 additions & 6 deletions
@@ -11,6 +11,8 @@
 
 from __future__ import annotations
 
+from typing import Optional
+
 import torch
 import torch.nn as nn
 
@@ -68,13 +70,12 @@ def __init__(
         self.norm2 = nn.LayerNorm(hidden_size)
         self.with_cross_attention = with_cross_attention
 
-        if self.with_cross_attention:
-            self.norm_cross_attn = nn.LayerNorm(hidden_size)
-            self.cross_attn = CrossAttentionBlock(
-                hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
-            )
+        self.norm_cross_attn = nn.LayerNorm(hidden_size)
+        self.cross_attn = CrossAttentionBlock(
+            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
+        )
 
-    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
         x = x + self.attn(self.norm1(x))
         if self.with_cross_attention:
             x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
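
Moving the cross-attention submodules out of the `if` block addresses a related TorchScript constraint: `forward` references `self.cross_attn` and `self.norm_cross_attn` inside a branch, and the compiler resolves those attributes even when `with_cross_attention` is `False`. A hedged toy sketch of the same idea, with invented names and `nn.Linear` standing in for `CrossAttentionBlock`:

```python
# Toy illustration (hypothetical module): torch.jit.script resolves every attribute
# referenced in forward, including those used only inside a conditional branch,
# so the submodules must always exist.
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self, hidden_size: int = 8, with_cross_attention: bool = False):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.with_cross_attention = with_cross_attention
        # Built unconditionally, mirroring the change above; unused when the flag is False.
        self.norm_cross_attn = nn.LayerNorm(hidden_size)
        self.cross_attn = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.norm1(x)
        if self.with_cross_attention:
            x = x + self.cross_attn(self.norm_cross_attn(x))
        return x


# Scripts cleanly; if the submodules were created only under the flag, scripting
# would fail with roughly "module has no attribute 'cross_attn'".
scripted = torch.jit.script(ToyBlock(with_cross_attention=False))
print(scripted(torch.randn(2, 4, 8)).shape)
```

The trade-off is a few parameters that go unused when cross-attention is disabled, in exchange for a module that scripts in every configuration.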

tests/test_vit.py

Lines changed: 5 additions & 5 deletions
@@ -30,7 +30,7 @@
                         for mlp_dim in [3072]:
                             for num_layers in [4]:
                                 for num_classes in [8]:
-                                    for pos_embed in ["conv", "perceptron"]:
+                                    for proj_type in ["conv", "perceptron"]:
                                         for classification in [False, True]:
                                             for nd in (2, 3):
                                                 test_case = [
@@ -42,7 +42,7 @@
                                                         "mlp_dim": mlp_dim,
                                                         "num_layers": num_layers,
                                                         "num_heads": num_heads,
-                                                        "pos_embed": pos_embed,
+                                                        "proj_type": proj_type,
                                                         "classification": classification,
                                                         "num_classes": num_classes,
                                                         "dropout_rate": dropout_rate,
@@ -87,7 +87,7 @@ def test_ill_arg(
         mlp_dim,
         num_layers,
         num_heads,
-        pos_embed,
+        proj_type,
         classification,
         dropout_rate,
     ):
@@ -100,12 +100,12 @@ def test_ill_arg(
                 mlp_dim=mlp_dim,
                 num_layers=num_layers,
                 num_heads=num_heads,
-                pos_embed=pos_embed,
+                proj_type=proj_type,
                 classification=classification,
                 dropout_rate=dropout_rate,
             )
 
-    @parameterized.expand(TEST_CASE_Vit)
+    @parameterized.expand(TEST_CASE_Vit[:1])
     @SkipIfBeforePyTorchVersion((1, 9))
     def test_script(self, input_param, input_shape, _):
         net = ViT(**(input_param))