Skip to content

Commit 447d9f2

Browse files
author
Fabio Ferreira
committed
fix: set seed
1 parent 1aa8e3c commit 447d9f2

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tests/networks/nets/test_checkpointunet.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,12 @@ def test_checkpointing_equivalence_eval(self):
104104
spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16, 32), strides=(2, 2), num_res_units=1
105105
)
106106

107+
torch.manual_seed(0)
107108
x = torch.randn(2, 1, 32, 32, device=device)
108109

109-
net_ckpt = CheckpointUNet(**params).to(device)
110110
net_plain = UNet(**params).to(device)
111+
net_ckpt = CheckpointUNet(**params).to(device)
112+
net_ckpt.load_state_dict(net_plain.state_dict())
111113

112114
with eval_mode(net_ckpt), eval_mode(net_plain):
113115
y_ckpt = net_ckpt(x)

0 commit comments

Comments
 (0)