Skip to content

Commit 9d34bb5

Browse files
committed
Updating Torchvision Model Loading
Signed-off-by: Eric Kerfoot <[email protected]>
1 parent 0968da2 commit 9d34bb5

File tree

5 files changed

+18
-16
lines changed

5 files changed

+18
-16
lines changed

monai/networks/blocks/fcn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ def __init__(
123123
self.upsample_mode = upsample_mode
124124
self.conv2d_type = conv2d_type
125125
self.out_channels = out_channels
126-
resnet = models.resnet50(pretrained=pretrained, progress=progress)
126+
resnet = models.resnet50(
127+
progress=progress, weights=models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
128+
)
127129

128130
self.conv1 = resnet.conv1
129131
self.bn0 = resnet.bn1

monai/networks/nets/milmodel.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import torch
1717
import torch.nn as nn
1818

19-
from monai.utils.module import optional_import
19+
from monai.utils import first, optional_import
2020

2121
models, _ = optional_import("torchvision.models")
2222

@@ -48,6 +48,7 @@ class MILModel(nn.Module):
4848
Defaults to ``None`` (necessary only when using a custom backbone)
4949
trans_blocks: number of the blocks in `TransformEncoder` layer.
5050
trans_dropout: dropout rate in `TransformEncoder` layer.
51+
backbone_weights: name of weight object in torchvision.models to load when `backbone` names a torchvision model
5152
5253
"""
5354

@@ -74,7 +75,7 @@ def __init__(
7475
self.transformer: nn.Module | None = None
7576

7677
if backbone is None:
77-
net = models.resnet50(pretrained=pretrained)
78+
net = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None)
7879
nfc = net.fc.in_features # save the number of final features
7980
net.fc = torch.nn.Identity() # remove final linear layer
8081

@@ -99,7 +100,7 @@ def hook(module, input, output):
99100
torch_model = getattr(models, backbone, None)
100101
if torch_model is None:
101102
raise ValueError("Unknown torch vision model" + str(backbone))
102-
net = torch_model(pretrained=pretrained)
103+
net = torch_model(weights="DEFAULT" if pretrained else None)
103104

104105
if getattr(net, "fc", None) is not None:
105106
nfc = net.fc.in_features # save the number of final features

monai/networks/nets/torchvision_fc.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,11 @@ def __init__(
112112
weights=None,
113113
**kwargs,
114114
):
115-
if weights is not None:
116-
model = getattr(models, model_name)(weights=weights, **kwargs)
117-
elif pretrained:
118-
model = getattr(models, model_name)(weights="DEFAULT", **kwargs)
119-
else:
120-
model = getattr(models, model_name)(weights=None, **kwargs)
115+
# if pretrained is False, weights is a weight tensor or None for no pretrained loading
116+
if pretrained and weights is None:
117+
weights = "DEFAULT"
118+
119+
model = getattr(models, model_name)(weights=weights, **kwargs)
121120

122121
super().__init__(
123122
model=model,

tests/networks/nets/test_densenet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def test_pretrain_consistency(self, model, input_param, input_shape):
9696
net = model(**input_param).to(device)
9797
with eval_mode(net):
9898
result = net.features.forward(example)
99-
torchvision_net = torchvision.models.densenet121(pretrained=True).to(device)
99+
torchvision_net = torchvision.models.densenet121(weights="DEFAULT").to(device)
100100
with eval_mode(torchvision_net):
101101
expected_result = torchvision_net.features.forward(example)
102102
self.assertTrue(torch.all(result == expected_result))

tests/networks/nets/test_milmodel.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@
4444
TEST_CASE_MILMODEL.append(test_case)
4545

4646
# torchvision backbone
47-
TEST_CASE_MILMODEL.append(
48-
[{"num_classes": 5, "backbone": "resnet18", "pretrained": False}, (2, 2, 3, 512, 512), (2, 5)]
49-
)
50-
TEST_CASE_MILMODEL.append([{"num_classes": 5, "backbone": "resnet18", "pretrained": True}, (2, 2, 3, 512, 512), (2, 5)])
47+
for pretrained in [True, False]:
48+
TEST_CASE_MILMODEL.append(
49+
[{"num_classes": 5, "backbone": "resnet18", "pretrained": pretrained}, (2, 2, 3, 512, 512), (2, 5)]
50+
)
5151

5252
# custom backbone
53-
backbone = models.densenet121(pretrained=False)
53+
backbone = models.densenet121()
5454
backbone_nfeatures = backbone.classifier.in_features
5555
backbone.classifier = torch.nn.Identity()
5656
TEST_CASE_MILMODEL.append(

0 commit comments

Comments
 (0)