Skip to content

Commit 2b1a37c

Browse files
Update monai/networks/blocks/transformerblock.py
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Yang Zekang <[email protected]>
1 parent 4cd46de commit 2b1a37c

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

monai/networks/blocks/transformerblock.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@ def __init__(
9090
causal=False,
9191
use_flash_attention=use_flash_attention,
9292
)
93+
else:
94+
def _drop_cross_attn_keys(state_dict, prefix, *_args):
95+
for key in list(state_dict.keys()):
96+
if key.startswith(prefix + "cross_attn.") or key.startswith(prefix + "norm_cross_attn."):
97+
state_dict.pop(key)
98+
self._register_load_state_dict_pre_hook(_drop_cross_attn_keys)
9399

94100
def forward(
95101
self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None

0 commit comments

Comments
 (0)