Commit fc5ffdf

Merge branch 'dev' into fix_gdown_fails

2 parents d8a7c24 + 946cfdf commit fc5ffdf

26 files changed: +402 −109 lines changed

SECURITY.md

Lines changed: 18 additions & 0 deletions

```diff
@@ -0,0 +1,18 @@
+# Security Policy
+
+## Reporting a Vulnerability
+MONAI takes security seriously and appreciates your efforts to responsibly disclose vulnerabilities. If you discover a security issue, please report it as soon as possible.
+
+To report a security issue:
+* Please use the GitHub Security Advisories tab to "[Open a draft security advisory](https://github.com/Project-MONAI/MONAI/security/advisories/new)".
+* Include a detailed description of the issue, steps to reproduce, potential impact, and any possible mitigations.
+* If applicable, please also attach proof-of-concept code or screenshots.
+* We aim to acknowledge your report within 72 hours and provide a status update as we investigate.
+* Please do not create public issues for security-related reports.
+
+## Disclosure Policy
+* We follow a coordinated disclosure approach.
+* We will not publicly disclose vulnerabilities until a fix has been developed and released.
+* Credit will be given to researchers who responsibly disclose vulnerabilities, if requested.
+## Acknowledgements
+We greatly appreciate contributions from the security community and strive to recognize all researchers who help keep MONAI safe.
```

monai/apps/nnunet/nnunet_bundle.py

Lines changed: 17 additions & 9 deletions

```diff
@@ -133,7 +133,7 @@ def get_nnunet_trainer(
         cudnn.benchmark = True
 
     if pretrained_model is not None:
-        state_dict = torch.load(pretrained_model)
+        state_dict = torch.load(pretrained_model, weights_only=True)
         if "network_weights" in state_dict:
             nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"])
     return nnunet_trainer
@@ -182,7 +182,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
         parameters = []
 
         checkpoint = torch.load(
-            join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu")
+            join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"),
+            map_location=torch.device("cpu"),
+            weights_only=True,
         )
         trainer_name = checkpoint["trainer_name"]
         configuration_name = checkpoint["init_args"]["configuration"]
@@ -192,7 +194,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name
             else None
         )
         if Path(model_training_output_dir).joinpath(model_name).is_file():
-            monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu"))
+            monai_checkpoint = torch.load(
+                join(model_training_output_dir, model_name), map_location=torch.device("cpu"), weights_only=True
+            )
             if "network_weights" in monai_checkpoint.keys():
                 parameters.append(monai_checkpoint["network_weights"])
         else:
@@ -383,8 +387,12 @@ def convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str,
         dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}"
     )
 
-    nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"))
-    nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"))
+    nnunet_checkpoint_final = torch.load(
+        Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"), weights_only=True
+    )
+    nnunet_checkpoint_best = torch.load(
+        Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"), weights_only=True
+    )
 
     nnunet_checkpoint = {}
     nnunet_checkpoint["inference_allowed_mirroring_axes"] = nnunet_checkpoint_final["inference_allowed_mirroring_axes"]
@@ -470,7 +478,7 @@ def get_network_from_nnunet_plans(
     if model_ckpt is None:
         return network
     else:
-        state_dict = torch.load(model_ckpt)
+        state_dict = torch.load(model_ckpt, weights_only=True)
        network.load_state_dict(state_dict[model_key_in_ckpt])
         return network
 
@@ -534,7 +542,7 @@ def subfiles(
 
     Path(nnunet_model_folder).joinpath(f"fold_{fold}").mkdir(parents=True, exist_ok=True)
 
-    nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth")
+    nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth", weights_only=True)
     latest_checkpoints: list[str] = subfiles(
         Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True
     )
@@ -545,7 +553,7 @@
     epochs.sort()
     final_epoch: int = epochs[-1]
     monai_last_checkpoint: dict = torch.load(
-        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt"
+        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt", weights_only=True
     )
 
     best_checkpoints: list[str] = subfiles(
@@ -558,7 +566,7 @@
     key_metrics.sort()
     best_key_metric: str = key_metrics[-1]
     monai_best_checkpoint: dict = torch.load(
-        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt"
+        f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt", weights_only=True
     )
 
     nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"]
```
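Every hunk above makes the same change: `torch.load` is called with `weights_only=True`, which restricts unpickling to tensors and plain containers. A minimal standalone sketch of the difference, with illustrative file names (requires a PyTorch version that supports the `weights_only` argument):

```python
import pickle

import torch

# A dict of tensors round-trips fine under weights_only=True.
torch.save({"network_weights": {"w": torch.zeros(3)}}, "ok.pth")
print(torch.load("ok.pth", weights_only=True)["network_weights"]["w"].shape)

class Payload:
    """Stand-in for a malicious object embedded in a checkpoint."""

    def __reduce__(self):  # runs on ordinary (unrestricted) unpickling
        return (print, ("arbitrary code would execute here",))

# The same call now rejects the object instead of executing its payload.
torch.save({"network_weights": Payload()}, "bad.pth")
try:
    torch.load("bad.pth", weights_only=True)
except pickle.UnpicklingError as err:
    print(f"rejected: {err}")
```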

monai/apps/utils.py

Lines changed: 60 additions & 6 deletions

```diff
@@ -122,6 +122,38 @@ def update_to(self, b: int = 1, bsize: int = 1, tsize: int | None = None) -> None:
             raise e
 
 
+def safe_extract_member(member, extract_to):
+    """Securely verify compressed package member paths to prevent path traversal attacks"""
+    # Get member path (handle different compression formats)
+    if hasattr(member, "filename"):
+        member_path = member.filename  # zipfile
+    elif hasattr(member, "name"):
+        member_path = member.name  # tarfile
+    else:
+        member_path = str(member)
+
+    if hasattr(member, "issym") and member.issym():
+        raise ValueError(f"Symbolic link detected in archive: {member_path}")
+    if hasattr(member, "islnk") and member.islnk():
+        raise ValueError(f"Hard link detected in archive: {member_path}")
+
+    member_path = os.path.normpath(member_path)
+
+    if os.path.isabs(member_path) or ".." in member_path.split(os.sep):
+        raise ValueError(f"Unsafe path detected in archive: {member_path}")
+
+    full_path = os.path.join(extract_to, member_path)
+    full_path = os.path.normpath(full_path)
+
+    extract_root = os.path.realpath(extract_to)
+    target_real = os.path.realpath(full_path)
+    # Ensure the resolved path stays within the extraction root
+    if os.path.commonpath([extract_root, target_real]) != extract_root:
+        raise ValueError(f"Unsafe path: path traversal {member_path}")
+
+    return full_path
+
+
 def check_hash(filepath: PathLike, val: str | None = None, hash_type: str = "md5") -> bool:
     """
     Verify hash signature of specified file.
```
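A usage sketch for the helper above, with an illustrative archive name: a member whose path climbs out of the extraction root raises before any bytes are written, while a normal member resolves to a path inside it.

```python
import zipfile

from monai.apps.utils import safe_extract_member

# Build a zip containing one traversal member and one normal member.
with zipfile.ZipFile("evil.zip", "w") as zf:
    zf.writestr("../escape.txt", "payload")
    zf.writestr("data/ok.txt", "fine")

with zipfile.ZipFile("evil.zip") as zf:
    for member in zf.infolist():
        try:
            print("safe:", safe_extract_member(member, "output"))
        except ValueError as err:
            print("blocked:", err)
# blocked: Unsafe path detected in archive: ../escape.txt
# safe: output/data/ok.txt
```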
```diff
@@ -242,6 +274,32 @@ def download_url(
     )
 
 
+def _extract_zip(filepath, output_dir):
+    with zipfile.ZipFile(filepath, "r") as zip_file:
+        for member in zip_file.infolist():
+            safe_path = safe_extract_member(member, output_dir)
+            if member.is_dir():
+                continue
+            os.makedirs(os.path.dirname(safe_path), exist_ok=True)
+            with zip_file.open(member) as source:
+                with open(safe_path, "wb") as target:
+                    shutil.copyfileobj(source, target)
+
+
+def _extract_tar(filepath, output_dir):
+    with tarfile.open(filepath, "r") as tar_file:
+        for member in tar_file.getmembers():
+            safe_path = safe_extract_member(member, output_dir)
+            if not member.isfile():
+                continue
+            os.makedirs(os.path.dirname(safe_path), exist_ok=True)
+            source = tar_file.extractfile(member)
+            if source is not None:
+                with source:
+                    with open(safe_path, "wb") as target:
+                        shutil.copyfileobj(source, target)
+
+
 def extractall(
     filepath: PathLike,
     output_dir: PathLike = ".",
```
```diff
@@ -287,14 +345,10 @@
         logger.info(f"Writing into directory: {output_dir}.")
     _file_type = file_type.lower().strip()
     if filepath.name.endswith("zip") or _file_type == "zip":
-        zip_file = zipfile.ZipFile(filepath)
-        zip_file.extractall(output_dir)
-        zip_file.close()
+        _extract_zip(filepath, output_dir)
         return
     if filepath.name.endswith("tar") or filepath.name.endswith("tar.gz") or "tar" in _file_type:
-        tar_file = tarfile.open(filepath)
-        tar_file.extractall(output_dir)
-        tar_file.close()
+        _extract_tar(filepath, output_dir)
         return
     raise NotImplementedError(
         f'Unsupported file type, available options are: ["zip", "tar.gz", "tar"]. name={filepath} type={file_type}.'
```
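End to end, the rewiring above means a crafted archive now fails fast (a sketch; file and directory names are illustrative, and `extractall`'s other keyword arguments are left at their defaults):

```python
import zipfile

from monai.apps.utils import extractall

with zipfile.ZipFile("evil.zip", "w") as zf:
    zf.writestr("../escape.txt", "payload")

try:
    # extractall now routes zips through _extract_zip, so each member is
    # validated by safe_extract_member before anything is written.
    extractall("evil.zip", output_dir="extracted", file_type="zip")
except ValueError as err:
    print(err)  # Unsafe path detected in archive: ../escape.txt
```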

monai/data/__init__.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -78,7 +78,6 @@
 from .thread_buffer import ThreadBuffer, ThreadDataLoader
 from .torchscript_utils import load_net_with_metadata, save_net_with_metadata
 from .utils import (
-    PICKLE_KEY_SUFFIX,
     affine_to_spacing,
     compute_importance_map,
     compute_shape_offset,
```

monai/data/dataset.py

Lines changed: 34 additions & 16 deletions

```diff
@@ -13,7 +13,6 @@
 
 import collections.abc
 import math
-import pickle
 import shutil
 import sys
 import tempfile
@@ -22,9 +21,11 @@
 import warnings
 from collections.abc import Callable, Sequence
 from copy import copy, deepcopy
+from io import BytesIO
 from multiprocessing.managers import ListProxy
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
+from pickle import UnpicklingError
 from typing import IO, TYPE_CHECKING, Any, cast
 
 import numpy as np
@@ -207,6 +208,11 @@ class PersistentDataset(Dataset):
         not guaranteed, so caution should be used when modifying transforms to avoid unexpected
         errors. If in doubt, it is advisable to clear the cache directory.
 
+        Cached data is expected to be tensors, primitives, or dictionaries keying to these values. Numpy arrays will
+        be converted to tensors, however any other object type returned by transforms will not be loadable since
+        `torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects.
+        Legacy cache files may not be loadable and may need to be recomputed.
+
     Lazy Resampling:
         If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to
         its documentation to familiarize yourself with the interaction between `PersistentDataset` and
@@ -248,8 +254,8 @@ def __init__(
                 this arg is used by `torch.save`, for more details, please check:
                 https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
                 and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
-            pickle_protocol: can be specified to override the default protocol, default to `2`.
-                this arg is used by `torch.save`, for more details, please check:
+            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
+                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
                 https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
             hash_transform: a callable to compute hash from the transform information when caching.
                 This may reduce errors due to transforms changing during experiments. Default to None (no hash).
@@ -371,12 +377,12 @@ def _cachecheck(self, item_transformed):
 
         if hashfile is not None and hashfile.is_file():  # cache hit
             try:
-                return torch.load(hashfile, weights_only=False)
+                return torch.load(hashfile, weights_only=True)
             except PermissionError as e:
                 if sys.platform != "win32":
                     raise e
-            except RuntimeError as e:
-                if "Invalid magic number; corrupt file" in str(e):
+            except (UnpicklingError, RuntimeError) as e:  # corrupt or unloadable cached files are recomputed
+                if "Invalid magic number; corrupt file" in str(e) or isinstance(e, UnpicklingError):
                     warnings.warn(f"Corrupt cache file detected: {hashfile}. Deleting and recomputing.")
                     hashfile.unlink()
                 else:
@@ -392,7 +398,7 @@
         with tempfile.TemporaryDirectory() as tmpdirname:
             temp_hash_file = Path(tmpdirname) / hashfile.name
             torch.save(
-                obj=_item_transformed,
+                obj=convert_to_tensor(_item_transformed, convert_numeric=False),
                 f=temp_hash_file,
                 pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                 pickle_protocol=self.pickle_protocol,
@@ -455,8 +461,8 @@
                 this arg is used by `torch.save`, for more details, please check:
                 https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
                 and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
-            pickle_protocol: can be specified to override the default protocol, default to `2`.
-                this arg is used by `torch.save`, for more details, please check:
+            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
+                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
                 https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
             hash_transform: a callable to compute hash from the transform information when caching.
                 This may reduce errors due to transforms changing during experiments. Default to None (no hash).
@@ -531,7 +537,7 @@
         hash_func: Callable[..., bytes] = pickle_hashing,
         db_name: str = "monai_cache",
         progress: bool = True,
-        pickle_protocol=pickle.HIGHEST_PROTOCOL,
+        pickle_protocol=DEFAULT_PROTOCOL,
         hash_transform: Callable[..., bytes] | None = None,
         reset_ops_id: bool = True,
         lmdb_kwargs: dict | None = None,
@@ -551,8 +557,9 @@
                 defaults to `monai.data.utils.pickle_hashing`.
             db_name: lmdb database file name. Defaults to "monai_cache".
             progress: whether to display a progress bar.
-            pickle_protocol: pickle protocol version. Defaults to pickle.HIGHEST_PROTOCOL.
-                https://docs.python.org/3/library/pickle.html#pickle-protocols
+            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
+                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
+                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
             hash_transform: a callable to compute hash from the transform information when caching.
                 This may reduce errors due to transforms changing during experiments. Default to None (no hash).
                 Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
@@ -594,6 +601,15 @@ def set_data(self, data: Sequence):
         super().set_data(data=data)
         self._read_env = self._fill_cache_start_reader(show_progress=self.progress)
 
+    def _safe_serialize(self, val):
+        out = BytesIO()
+        torch.save(convert_to_tensor(val), out, pickle_protocol=self.pickle_protocol)
+        out.seek(0)
+        return out.read()
+
+    def _safe_deserialize(self, val):
+        return torch.load(BytesIO(val), map_location="cpu", weights_only=True)
+
     def _fill_cache_start_reader(self, show_progress=True):
         """
         Check the LMDB cache and write the cache if needed. py-lmdb doesn't have a good support for concurrent write.
```
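The round trip these two helpers implement can be sketched in isolation (the `convert_to_tensor` import location is an assumption):

```python
from io import BytesIO

import numpy as np
import torch
from monai.utils import convert_to_tensor  # assumed import location

item = {"image": np.ones((2, 2), dtype=np.float32), "label": 1}

# _safe_serialize: convert values to tensors, torch.save into a buffer.
buf = BytesIO()
torch.save(convert_to_tensor(item), buf)
raw = buf.getvalue()  # the bytes that get stored in LMDB

# _safe_deserialize: torch.load from the bytes with weights_only=True.
restored = torch.load(BytesIO(raw), map_location="cpu", weights_only=True)
print(type(restored["image"]))  # numpy input comes back as torch.Tensor
```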
```diff
@@ -619,7 +635,8 @@
                     continue
                 if val is None:
                     val = self._pre_transform(deepcopy(item))  # keep the original hashed
-                    val = pickle.dumps(val, protocol=self.pickle_protocol)
+                    # val = pickle.dumps(val, protocol=self.pickle_protocol)
+                    val = self._safe_serialize(val)
                 with env.begin(write=True) as txn:
                     txn.put(key, val)
                 done = True
@@ -664,7 +681,8 @@
             warnings.warn("LMDBDataset: cache key not found, running fallback caching.")
             return super()._cachecheck(item_transformed)
         try:
-            return pickle.loads(data)
+            # return pickle.loads(data)
+            return self._safe_deserialize(data)
         except Exception as err:
             raise RuntimeError("Invalid cache value, corrupted lmdb file?") from err
 
@@ -1650,7 +1668,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
             meta_hash_file = self.cache_dir / meta_hash_file_name
             temp_hash_file = Path(tmpdirname) / meta_hash_file_name
             torch.save(
-                obj=self._meta_cache[meta_hash_file_name],
+                obj=convert_to_tensor(self._meta_cache[meta_hash_file_name], convert_numeric=False),
                 f=temp_hash_file,
                 pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                 pickle_protocol=self.pickle_protocol,
@@ -1670,4 +1688,4 @@
         if meta_hash_file_name in self._meta_cache:
             return self._meta_cache[meta_hash_file_name]
         else:
-            return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
+            return torch.load(self.cache_dir / meta_hash_file_name, weights_only=True)
```

monai/data/meta_tensor.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -611,4 +611,4 @@ def print_verbose(self) -> None:
 
 # needed in later versions of Pytorch to indicate the class is safe for serialisation
 if hasattr(torch.serialization, "add_safe_globals"):
-    torch.serialization.add_safe_globals([MetaTensor])
+    torch.serialization.add_safe_globals([MetaObj, MetaTensor, MetaKeys, SpaceKeys])
```
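With the supporting types registered, a `MetaTensor` checkpoint can pass the `weights_only=True` gate that the rest of this commit turns on. A sketch (file name illustrative; requires a PyTorch version providing `add_safe_globals`):

```python
import torch

from monai.data import MetaTensor  # importing also runs the registration above

t = MetaTensor(torch.rand(2, 3), meta={"original_channel_dim": 0})
torch.save(t, "meta.pt")

# Without MetaObj/MetaKeys/SpaceKeys in the allow-list, this load could
# fail with UnpicklingError even though MetaTensor itself was registered.
loaded = torch.load("meta.pt", weights_only=True)
print(type(loaded).__name__, loaded.meta.get("original_channel_dim"))
```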
