Skip to content

Commit 1411564

Browse files

Committed: Fix old state dict loading and add tests
Signed-off-by: John Zielke <[email protected]>
1 parent 8e6d04e commit 1411564

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

monai/networks/nets/diffusion_model_unet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1847,9 +1847,9 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
18471847
new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")
18481848

18491849
# projection
1850-
new_state_dict[f"{block}.attn.out_proj.weight"] = old_state_dict.pop(f"{block}.proj_attn.weight")
1851-
new_state_dict[f"{block}.attn.out_proj.bias"] = old_state_dict.pop(f"{block}.proj_attn.bias")
1852-
1850+
if f"{block}.attn.out_proj.weight" in new_state_dict and f"{block}.attn.out_proj.bias" in new_state_dict:
1851+
new_state_dict[f"{block}.attn.out_proj.weight"] = old_state_dict.pop(f"{block}.proj_attn.weight")
1852+
new_state_dict[f"{block}.attn.out_proj.bias"] = old_state_dict.pop(f"{block}.proj_attn.bias")
18531853
# fix the cross attention blocks
18541854
cross_attention_blocks = [
18551855
k.replace(".out_proj.weight", "")

tests/networks/blocks/test_selfattention.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,27 @@ def test_flash_attention(self):
227227
out_2 = block_wo_flash_attention(test_data)
228228
assert_allclose(out_1, out_2, atol=1e-4)
229229

230+
@parameterized.expand([[True], [False]])
231+
def test_no_extra_weights_if_no_fc(self, include_fc):
232+
input_param = {
233+
"hidden_size": 360,
234+
"num_heads": 4,
235+
"dropout_rate": 0.0,
236+
"rel_pos_embedding": None,
237+
"input_size": (16, 32),
238+
"include_fc": include_fc,
239+
"use_combined_linear": use_combined_linear,
240+
}
241+
net = SABlock(**input_param)
242+
if not include_fc:
243+
self.assertNotIn("out_proj.weight", net.state_dict())
244+
self.assertNotIn("out_proj.bias", net.state_dict())
245+
self.assertIsInstance(net.out_proj, torch.nn.Identity)
246+
else:
247+
self.assertIn("out_proj.weight", net.state_dict())
248+
self.assertIn("out_proj.bias", net.state_dict())
249+
self.assertIsInstance(net.out_proj, torch.nn.Linear)
250+
230251

231252
if __name__ == "__main__":
232253
unittest.main()

0 commit comments

Comments (0)