Skip to content

Commit 4685eca

Browse files
committed
Update torch.load usage to eliminate complaint mesages
Signed-off-by: Eric Kerfoot <[email protected]>
1 parent a7b615e commit 4685eca

24 files changed

+47
-41
lines changed

monai/apps/detection/networks/retinanet_detector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def forward(self, images: torch.Tensor):
180180
nesterov=True,
181181
)
182182
torch.save(detector.network.state_dict(), 'model.pt') # save model
183-
detector.network.load_state_dict(torch.load('model.pt')) # load model
183+
detector.network.load_state_dict(torch.load('model.pt', weights_only=True)) # load model
184184
"""
185185

186186
def __init__(

monai/apps/mmars/mmars.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def load_from_mmar(
241241
return torch.jit.load(_model_file, map_location=map_location)
242242

243243
# loading with `torch.load`
244-
model_dict = torch.load(_model_file, map_location=map_location)
244+
model_dict = torch.load(_model_file, map_location=map_location, weights_only=True)
245245
if weights_only:
246246
return model_dict.get(model_key, model_dict) # model_dict[model_key] or model_dict directly
247247

monai/bundle/scripts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,7 @@ def load(
737737
if load_ts_module is True:
738738
return load_net_with_metadata(full_path, map_location=torch.device(device), more_extra_files=config_files)
739739
# loading with `torch.load`
740-
model_dict = torch.load(full_path, map_location=torch.device(device))
740+
model_dict = torch.load(full_path, map_location=torch.device(device), weights_only=True)
741741

742742
if not isinstance(model_dict, Mapping):
743743
warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.")
@@ -1306,7 +1306,7 @@ def _export(
13061306
# here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
13071307
Checkpoint.load_objects(to_load={key_in_ckpt: net}, checkpoint=ckpt_file)
13081308
else:
1309-
ckpt = torch.load(ckpt_file)
1309+
ckpt = torch.load(ckpt_file, weights_only=True)
13101310
copy_model_state(dst=net, src=ckpt if key_in_ckpt == "" else ckpt[key_in_ckpt])
13111311

13121312
# Use the given converter to convert a model and save with metadata, config content

monai/data/dataset.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,7 @@ def _cachecheck(self, item_transformed):
372372

373373
if hashfile is not None and hashfile.is_file(): # cache hit
374374
try:
375-
if "weights_only" in signature(torch.load).parameters:
376-
return torch.load(hashfile, weights_only=False)
377-
else:
378-
return torch.load(hashfile)
375+
return torch.load(hashfile, weights_only=False)
379376
except PermissionError as e:
380377
if sys.platform != "win32":
381378
raise e
@@ -1674,7 +1671,4 @@ def _load_meta_cache(self, meta_hash_file_name):
16741671
if meta_hash_file_name in self._meta_cache:
16751672
return self._meta_cache[meta_hash_file_name]
16761673
else:
1677-
if "weights_only" in signature(torch.load).parameters:
1678-
return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
1679-
else:
1680-
return torch.load(self.cache_dir / meta_hash_file_name)
1674+
return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)

monai/fl/client/monai_algo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ def get_weights(self, extra=None):
574574
model_path = os.path.join(self.bundle_root, cast(str, self.model_filepaths[model_type]))
575575
if not os.path.isfile(model_path):
576576
raise ValueError(f"No best model checkpoint exists at {model_path}")
577-
weights = torch.load(model_path, map_location="cpu")
577+
weights = torch.load(model_path, map_location="cpu", weights_only=True)
578578
# if weights contain several state dicts, use the one defined by `save_dict_key`
579579
if isinstance(weights, dict) and self.save_dict_key in weights:
580580
weights = weights.get(self.save_dict_key)

monai/handlers/checkpoint_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def __call__(self, engine: Engine) -> None:
122122
Args:
123123
engine: Ignite Engine, it can be a trainer, validator or evaluator.
124124
"""
125-
checkpoint = torch.load(self.load_path, map_location=self.map_location)
125+
checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=True)
126126

127127
k, _ = list(self.load_dict.items())[0]
128128
# single object and checkpoint is directly a state_dict

monai/losses/perceptual.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def __init__(
374374
else:
375375
network = torchvision.models.resnet50(weights=None)
376376
if pretrained is True:
377-
state_dict = torch.load(pretrained_path)
377+
state_dict = torch.load(pretrained_path, weights_only=True)
378378
if pretrained_state_dict_key is not None:
379379
state_dict = state_dict[pretrained_state_dict_key]
380380
network.load_state_dict(state_dict)

monai/networks/nets/hovernet.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -633,9 +633,9 @@ def _remap_preact_resnet_model(model_url: str):
633633
# download the pretrained weights into torch hub's default dir
634634
weights_dir = os.path.join(torch.hub.get_dir(), "preact-resnet50.pth")
635635
download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
636-
state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu"))[
637-
"desc"
638-
]
636+
map_location = None if torch.cuda.is_available() else torch.device("cpu")
637+
state_dict = torch.load(weights_dir, map_location=map_location, weights_only=True)["desc"]
638+
639639
for key in list(state_dict.keys()):
640640
new_key = None
641641
if pattern_conv0.match(key):
@@ -668,7 +668,8 @@ def _remap_standard_resnet_model(model_url: str, state_dict_key: str | None = No
668668
# download the pretrained weights into torch hub's default dir
669669
weights_dir = os.path.join(torch.hub.get_dir(), "resnet50.pth")
670670
download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
671-
state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu"))
671+
map_location = None if torch.cuda.is_available() else torch.device("cpu")
672+
state_dict = torch.load(weights_dir, map_location=map_location, weights_only=True)
672673
if state_dict_key is not None:
673674
state_dict = state_dict[state_dict_key]
674675

monai/networks/nets/resnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ def _resnet(
493493
if isinstance(pretrained, str):
494494
if Path(pretrained).exists():
495495
logger.info(f"Loading weights from {pretrained}...")
496-
model_state_dict = torch.load(pretrained, map_location=device)
496+
model_state_dict = torch.load(pretrained, map_location=device, weights_only=True)
497497
else:
498498
# Throw error
499499
raise FileNotFoundError("The pretrained checkpoint file is not found")
@@ -665,7 +665,7 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", dat
665665
raise EntryNotFoundError(
666666
f"{filename} not found on {medicalnet_huggingface_repo_basename}{resnet_depth}"
667667
) from None
668-
checkpoint = torch.load(pretrained_path, map_location=torch.device(device))
668+
checkpoint = torch.load(pretrained_path, map_location=torch.device(device), weights_only=True)
669669
else:
670670
raise NotImplementedError("Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]")
671671
logger.info(f"{filename} downloaded")

monai/networks/nets/senet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def _load_state_dict(model: nn.Module, arch: str, progress: bool):
302302

303303
if isinstance(model_url, dict):
304304
download_url(model_url["url"], filepath=model_url["filename"])
305-
state_dict = torch.load(model_url["filename"], map_location=None)
305+
state_dict = torch.load(model_url["filename"], map_location=None, weights_only=True)
306306
else:
307307
state_dict = load_state_dict_from_url(model_url, progress=progress)
308308
for key in list(state_dict.keys()):

0 commit comments

Comments
 (0)