Skip to content

Commit 0789a2e

Browse files
authored
Merge branch 'dev' into h3rrr-patch-1
2 parents 98e5a0e + fd13c1b commit 0789a2e

File tree

7 files changed

+23
-22
lines changed

7 files changed

+23
-22
lines changed

monai/networks/blocks/fcn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ def __init__(
123123
self.upsample_mode = upsample_mode
124124
self.conv2d_type = conv2d_type
125125
self.out_channels = out_channels
126-
resnet = models.resnet50(pretrained=pretrained, progress=progress)
126+
resnet = models.resnet50(
127+
progress=progress, weights=models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
128+
)
127129

128130
self.conv1 = resnet.conv1
129131
self.bn0 = resnet.bn1

monai/networks/nets/milmodel.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import torch
1717
import torch.nn as nn
1818

19-
from monai.utils.module import optional_import
19+
from monai.utils import optional_import
2020

2121
models, _ = optional_import("torchvision.models")
2222

@@ -48,7 +48,6 @@ class MILModel(nn.Module):
4848
Defaults to ``None`` (necessary only when using a custom backbone)
4949
trans_blocks: number of the blocks in `TransformEncoder` layer.
5050
trans_dropout: dropout rate in `TransformEncoder` layer.
51-
5251
"""
5352

5453
def __init__(
@@ -74,7 +73,7 @@ def __init__(
7473
self.transformer: nn.Module | None = None
7574

7675
if backbone is None:
77-
net = models.resnet50(pretrained=pretrained)
76+
net = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None)
7877
nfc = net.fc.in_features # save the number of final features
7978
net.fc = torch.nn.Identity() # remove final linear layer
8079

@@ -99,7 +98,7 @@ def hook(module, input, output):
9998
torch_model = getattr(models, backbone, None)
10099
if torch_model is None:
101100
raise ValueError("Unknown torch vision model" + str(backbone))
102-
net = torch_model(pretrained=pretrained)
101+
net = torch_model(weights="DEFAULT" if pretrained else None)
103102

104103
if getattr(net, "fc", None) is not None:
105104
nfc = net.fc.in_features # save the number of final features

monai/networks/nets/torchvision_fc.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,11 @@ def __init__(
112112
weights=None,
113113
**kwargs,
114114
):
115-
if weights is not None:
116-
model = getattr(models, model_name)(weights=weights, **kwargs)
117-
elif pretrained:
118-
model = getattr(models, model_name)(weights="DEFAULT", **kwargs)
119-
else:
120-
model = getattr(models, model_name)(weights=None, **kwargs)
115+
# if pretrained is False, weights is a weight tensor or None for no pretrained loading
116+
if pretrained and weights is None:
117+
weights = "DEFAULT"
118+
119+
model = getattr(models, model_name)(weights=weights, **kwargs)
121120

122121
super().__init__(
123122
model=model,

monai/transforms/io/array.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@
4343
from monai.data.utils import is_no_channel
4444
from monai.transforms.transform import Transform
4545
from monai.transforms.utility.array import EnsureChannelFirst
46-
from monai.utils import GridSamplePadMode
47-
from monai.utils import ImageMetaKey as Key
4846
from monai.utils import (
47+
GridSamplePadMode,
48+
ImageMetaKey,
4949
MetaKeys,
5050
OptionalImportError,
5151
convert_to_dst_type,
@@ -293,7 +293,8 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader
293293
# make sure all elements in metadata are little endian
294294
meta_data = switch_endianness(meta_data, "<")
295295

296-
meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader
296+
# Path obj should be strings for data loader
297+
meta_data[ImageMetaKey.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}"
297298
img = MetaTensor.ensure_torch_and_prune_meta(
298299
img_array, meta_data, self.simple_keys, pattern=self.pattern, sep=self.sep
299300
)
@@ -548,7 +549,7 @@ def __call__(self, img: NdarrayOrTensor):
548549
"Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True."
549550
)
550551

551-
input_path = meta_data[Key.FILENAME_OR_OBJ]
552+
input_path = meta_data[ImageMetaKey.FILENAME_OR_OBJ]
552553
output_path = meta_data[MetaKeys.SAVED_TO]
553554
log_data = {"input": input_path, "output": output_path}
554555

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pep8-naming
1818
pycodestyle
1919
pyflakes
2020
black>=25.1.0
21-
isort>=5.1, <6.0
21+
isort>=5.1, !=6.0.0
2222
ruff
2323
pytype>=2020.6.1, <=2024.4.11; platform_system != "Windows"
2424
types-setuptools

tests/networks/nets/test_densenet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def test_pretrain_consistency(self, model, input_param, input_shape):
9696
net = model(**input_param).to(device)
9797
with eval_mode(net):
9898
result = net.features.forward(example)
99-
torchvision_net = torchvision.models.densenet121(pretrained=True).to(device)
99+
torchvision_net = torchvision.models.densenet121(weights="DEFAULT").to(device)
100100
with eval_mode(torchvision_net):
101101
expected_result = torchvision_net.features.forward(example)
102102
self.assertTrue(torch.all(result == expected_result))

tests/networks/nets/test_milmodel.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@
4444
TEST_CASE_MILMODEL.append(test_case)
4545

4646
# torchvision backbone
47-
TEST_CASE_MILMODEL.append(
48-
[{"num_classes": 5, "backbone": "resnet18", "pretrained": False}, (2, 2, 3, 512, 512), (2, 5)]
49-
)
50-
TEST_CASE_MILMODEL.append([{"num_classes": 5, "backbone": "resnet18", "pretrained": True}, (2, 2, 3, 512, 512), (2, 5)])
47+
for pretrained in [True, False]:
48+
TEST_CASE_MILMODEL.append(
49+
[{"num_classes": 5, "backbone": "resnet18", "pretrained": pretrained}, (2, 2, 3, 512, 512), (2, 5)]
50+
)
5151

5252
# custom backbone
53-
backbone = models.densenet121(pretrained=False)
53+
backbone = models.densenet121()
5454
backbone_nfeatures = backbone.classifier.in_features
5555
backbone.classifier = torch.nn.Identity()
5656
TEST_CASE_MILMODEL.append(

0 commit comments

Comments
 (0)