From 502950b00b91a7ea5db00bc972006078792f140e Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Thu, 11 Sep 2025 14:20:05 +0100 Subject: [PATCH 01/20] Updating torch.load To Load Weights Only Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/apps/nnunet/nnunet_bundle.py | 18 +++++++++--------- monai/data/dataset.py | 5 +++++ monai/data/meta_tensor.py | 2 +- monai/handlers/checkpoint_loader.py | 2 +- monai/utils/state_cacher.py | 2 +- tests/data/meta_tensor/test_meta_tensor.py | 2 +- 6 files changed, 18 insertions(+), 13 deletions(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index e358cd4b99..b018aed57a 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -133,7 +133,7 @@ def get_nnunet_trainer( cudnn.benchmark = True if pretrained_model is not None: - state_dict = torch.load(pretrained_model) + state_dict = torch.load(pretrained_model, weights_only=True) if "network_weights" in state_dict: nnunet_trainer.network._orig_mod.load_state_dict(state_dict["network_weights"]) return nnunet_trainer @@ -182,7 +182,7 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name parameters = [] checkpoint = torch.load( - join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu") + join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu"), weights_only=True ) trainer_name = checkpoint["trainer_name"] configuration_name = checkpoint["init_args"]["configuration"] @@ -192,7 +192,7 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name else None ) if Path(model_training_output_dir).joinpath(model_name).is_file(): - monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu")) + monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu"), weights_only=True) if "network_weights" in monai_checkpoint.keys(): parameters.append(monai_checkpoint["network_weights"]) else: @@ -383,8 +383,8 @@ def convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str, dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}" ) - nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth")) - nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth")) + nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"), weights_only=True) + nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"), weights_only=True) nnunet_checkpoint = {} nnunet_checkpoint["inference_allowed_mirroring_axes"] = nnunet_checkpoint_final["inference_allowed_mirroring_axes"] @@ -470,7 +470,7 @@ def get_network_from_nnunet_plans( if model_ckpt is None: return network else: - state_dict = torch.load(model_ckpt) + state_dict = torch.load(model_ckpt, weights_only=True) network.load_state_dict(state_dict[model_key_in_ckpt]) return network @@ -534,7 +534,7 @@ def subfiles( Path(nnunet_model_folder).joinpath(f"fold_{fold}").mkdir(parents=True, exist_ok=True) - nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth") + nnunet_checkpoint: dict = torch.load(f"{bundle_root_folder}/models/nnunet_checkpoint.pth", 
weights_only=True) latest_checkpoints: list[str] = subfiles( Path(bundle_root_folder).joinpath("models", f"fold_{fold}"), prefix="checkpoint_epoch", sort=True ) @@ -545,7 +545,7 @@ def subfiles( epochs.sort() final_epoch: int = epochs[-1] monai_last_checkpoint: dict = torch.load( - f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt" + f"{bundle_root_folder}/models/fold_{fold}/checkpoint_epoch={final_epoch}.pt", weights_only=True ) best_checkpoints: list[str] = subfiles( @@ -558,7 +558,7 @@ def subfiles( key_metrics.sort() best_key_metric: str = key_metrics[-1] monai_best_checkpoint: dict = torch.load( - f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt" + f"{bundle_root_folder}/models/fold_{fold}/checkpoint_key_metric={best_key_metric}.pt", weights_only=True ) nnunet_checkpoint["optimizer_state"] = monai_last_checkpoint["optimizer_state"] diff --git a/monai/data/dataset.py b/monai/data/dataset.py index e5842bfa7a..1149411118 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -207,6 +207,10 @@ class PersistentDataset(Dataset): not guaranteed, so caution should be used when modifying transforms to avoid unexpected errors. If in doubt, it is advisable to clear the cache directory. + Loading is done using `torch.load` with `weights_only=False`, thus the user must ensure the data + being loaded is safe. Typically this will be cached data the user has created themselves but if + data from external sources is used this should be validated for safety independently. + Lazy Resampling: If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to its documentation to familiarize yourself with the interaction between `PersistentDataset` and @@ -371,6 +375,7 @@ def _cachecheck(self, item_transformed): if hashfile is not None and hashfile.is_file(): # cache hit try: + # Loading with weights_only=False is expected to be safe as these should be the user's own cached data return torch.load(hashfile, weights_only=False) except PermissionError as e: if sys.platform != "win32": raise e diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 6425bc0a4f..12bd76ba60 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -611,4 +611,4 @@ def print_verbose(self) -> None: # needed in later versions of Pytorch to indicate the class is safe for serialisation if hasattr(torch.serialization, "add_safe_globals"): - torch.serialization.add_safe_globals([MetaTensor]) + torch.serialization.add_safe_globals([MetaObj, MetaTensor, MetaKeys, SpaceKeys]) diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py index 16cb875d03..105b4f3a79 100644 --- a/monai/handlers/checkpoint_loader.py +++ b/monai/handlers/checkpoint_loader.py @@ -122,7 +122,7 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator.
""" - checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=False) + checkpoint = torch.load(self.load_path, map_location=self.map_location, weights_only=True) k, _ = list(self.load_dict.items())[0] # single object and checkpoint is directly a state_dict diff --git a/monai/utils/state_cacher.py b/monai/utils/state_cacher.py index c59436525c..726d59273b 100644 --- a/monai/utils/state_cacher.py +++ b/monai/utils/state_cacher.py @@ -124,7 +124,7 @@ def retrieve(self, key: Hashable) -> Any: fn = self.cached[key]["obj"] # pytype: disable=attribute-error if not os.path.exists(fn): # pytype: disable=wrong-arg-types raise RuntimeError(f"Failed to load state in {fn}. File doesn't exist anymore.") - data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=False) + data_obj = torch.load(fn, map_location=lambda storage, location: storage, weights_only=True) # copy back to device if necessary if "device" in self.cached[key]: data_obj = data_obj.to(self.cached[key]["device"]) diff --git a/tests/data/meta_tensor/test_meta_tensor.py b/tests/data/meta_tensor/test_meta_tensor.py index c0e53fd24c..ea516cdffb 100644 --- a/tests/data/meta_tensor/test_meta_tensor.py +++ b/tests/data/meta_tensor/test_meta_tensor.py @@ -245,7 +245,7 @@ def test_pickling(self): with tempfile.TemporaryDirectory() as tmp_dir: fname = os.path.join(tmp_dir, "im.pt") torch.save(m, fname) - m2 = torch.load(fname, weights_only=False) + m2 = torch.load(fname, weights_only=True) self.check(m2, m, ids=False) @skip_if_no_cuda From 77d2d827a8c065a91977a0c8a70e5347799a6078 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Thu, 11 Sep 2025 14:23:08 +0100 Subject: [PATCH 02/20] Autofix Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/apps/nnunet/nnunet_bundle.py | 16 ++++++++++++---- tests/data/meta_tensor/test_meta_tensor.py | 2 +- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/monai/apps/nnunet/nnunet_bundle.py b/monai/apps/nnunet/nnunet_bundle.py index b018aed57a..df8f09bf4b 100644 --- a/monai/apps/nnunet/nnunet_bundle.py +++ b/monai/apps/nnunet/nnunet_bundle.py @@ -182,7 +182,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name parameters = [] checkpoint = torch.load( - join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), map_location=torch.device("cpu"), weights_only=True + join(Path(model_training_output_dir).parent, "nnunet_checkpoint.pth"), + map_location=torch.device("cpu"), + weights_only=True, ) trainer_name = checkpoint["trainer_name"] configuration_name = checkpoint["init_args"]["configuration"] @@ -192,7 +194,9 @@ def __init__(self, predictor: object, model_folder: Union[str, Path], model_name else None ) if Path(model_training_output_dir).joinpath(model_name).is_file(): - monai_checkpoint = torch.load(join(model_training_output_dir, model_name), map_location=torch.device("cpu"), weights_only=True) + monai_checkpoint = torch.load( + join(model_training_output_dir, model_name), map_location=torch.device("cpu"), weights_only=True + ) if "network_weights" in monai_checkpoint.keys(): parameters.append(monai_checkpoint["network_weights"]) else: @@ -383,8 +387,12 @@ def convert_nnunet_to_monai_bundle(nnunet_config: dict, bundle_root_folder: str, dataset_name, f"{nnunet_trainer}__{nnunet_plans}__{nnunet_configuration}" ) - nnunet_checkpoint_final = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"), 
weights_only=True) - nnunet_checkpoint_best = torch.load(Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"), weights_only=True) + nnunet_checkpoint_final = torch.load( + Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_final.pth"), weights_only=True + ) + nnunet_checkpoint_best = torch.load( + Path(nnunet_model_folder).joinpath(f"fold_{fold}", "checkpoint_best.pth"), weights_only=True + ) nnunet_checkpoint = {} nnunet_checkpoint["inference_allowed_mirroring_axes"] = nnunet_checkpoint_final["inference_allowed_mirroring_axes"] diff --git a/tests/data/meta_tensor/test_meta_tensor.py b/tests/data/meta_tensor/test_meta_tensor.py index ea516cdffb..427902f784 100644 --- a/tests/data/meta_tensor/test_meta_tensor.py +++ b/tests/data/meta_tensor/test_meta_tensor.py @@ -245,7 +245,7 @@ def test_pickling(self): with tempfile.TemporaryDirectory() as tmp_dir: fname = os.path.join(tmp_dir, "im.pt") torch.save(m, fname) - m2 = torch.load(fname, weights_only=True) + m2 = torch.load(fname, weights_only=True) self.check(m2, m, ids=False) @skip_if_no_cuda From 79c2cf8196cb727b8a22021ce509a22cad5c929e Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Thu, 11 Sep 2025 14:43:49 +0100 Subject: [PATCH 03/20] StateCacher should be fine with default pickle protocol Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- tests/utils/test_state_cacher.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_state_cacher.py b/tests/utils/test_state_cacher.py index 22c2836239..6e6eabf03d 100644 --- a/tests/utils/test_state_cacher.py +++ b/tests/utils/test_state_cacher.py @@ -27,7 +27,13 @@ TEST_CASE_0 = [torch.Tensor([1]).to(DEVICE), {"in_memory": True}] TEST_CASE_1 = [ torch.Tensor([1]).to(DEVICE), - {"in_memory": False, "cache_dir": gettempdir(), "pickle_module": None, "pickle_protocol": pickle.HIGHEST_PROTOCOL}, + { + "in_memory": False, + "cache_dir": gettempdir(), + "pickle_module": None, + # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility + "pickle_protocol": torch.serialization.DEFAULT_PROTOCOL, + }, ] TEST_CASE_2 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "allow_overwrite": False}] TEST_CASE_3 = [torch.Tensor([1]).to(DEVICE), {"in_memory": False, "cache_dir": Path(gettempdir())}] From f6f9867f8c2c538db7a3052deea2f720a7494b12 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:02:44 +0100 Subject: [PATCH 04/20] Docstring Update Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/data/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 1149411118..bbc31086ad 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -208,8 +208,8 @@ class PersistentDataset(Dataset): errors. If in doubt, it is advisable to clear the cache directory. Loading is done using `torch.load` with `weights_only=False`, thus the user must ensure the data - being loaded is safe. Typically this will be cached data the user has created themselves but if - data from external sources is used this should be validated for safety independently. + being loaded is safe. Typically this will be cached data the user created themselves; if data + from external sources is used this should be validated for safety independently.
Lazy Resampling: If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to From 93a5dd159191da8fdbec67885acad99da25e75f7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Sep 2025 14:04:58 +0000 Subject: [PATCH 05/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index bbc31086ad..ad99c4e216 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -208,7 +208,7 @@ class PersistentDataset(Dataset): errors. If in doubt, it is advisable to clear the cache directory. Loading is done using `torch.load` with `weights_only=False`, thus the user must ensure the data - being loaded is safe. Typically this will be cached data the user created themselves; if data + being loaded is safe. Typically this will be cached data the user created themselves; if data from external sources is used this should be validated for safety independently. Lazy Resampling: From 10b5de3588d013e7d710fcfb073f663a77e4722f Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Fri, 12 Sep 2025 15:34:36 +0100 Subject: [PATCH 06/20] Removing pickle_operations Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/data/utils.py | 60 +++++++++++++++++++++++---------------------- 1 file changed, 31 insertions(+), 29 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 14217e9103..ede1e6ff7b 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -421,27 +421,27 @@ def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"): PICKLE_KEY_SUFFIX = TraceKeys.KEY_SUFFIX -def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True): - """ - Applied_operations are dictionaries with varying sizes, this method converts them to bytes so that we can (de-)collate. - - Args: - data: a list or dictionary with substructures to be pickled/unpickled. - key: the key suffix for the target substructures, defaults to "_transforms" (`data.utils.PICKLE_KEY_SUFFIX`). - is_encode: whether it's encoding using pickle.dumps (True) or decoding using pickle.loads (False). - """ - if isinstance(data, Mapping): - data = dict(data) - for k in data: - if f"{k}".endswith(key): - if is_encode and not isinstance(data[k], bytes): - data[k] = pickle.dumps(data[k], 0) - if not is_encode and isinstance(data[k], bytes): - data[k] = pickle.loads(data[k]) - return {k: pickle_operations(v, key=key, is_encode=is_encode) for k, v in data.items()} - elif isinstance(data, (list, tuple)): - return [pickle_operations(item, key=key, is_encode=is_encode) for item in data] - return data +# def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True): +# """ +# Applied_operations are dictionaries with varying sizes, this method converts them to bytes so that we can (de-)collate. + +# Args: +# data: a list or dictionary with substructures to be pickled/unpickled. +# key: the key suffix for the target substructures, defaults to "_transforms" (`data.utils.PICKLE_KEY_SUFFIX`). +# is_encode: whether it's encoding using pickle.dumps (True) or decoding using pickle.loads (False).
+# """ +# if isinstance(data, Mapping): +# data = dict(data) +# for k in data: +# if f"{k}".endswith(key): +# if is_encode and not isinstance(data[k], bytes): +# data[k] = pickle.dumps(data[k], 0) +# if not is_encode and isinstance(data[k], bytes): +# data[k] = pickle.loads(data[k]) +# return {k: pickle_operations(v, key=key, is_encode=is_encode) for k, v in data.items()} +# elif isinstance(data, (list, tuple)): +# return [pickle_operations(item, key=key, is_encode=is_encode) for item in data] +# return data def collate_meta_tensor_fn(batch, *, collate_fn_map=None): @@ -500,8 +500,8 @@ def list_data_collate(batch: Sequence): key = None collate_fn = default_collate try: - if config.USE_META_DICT: - data = pickle_operations(data) # bc 0.9.0 + # if config.USE_META_DICT: + # data = pickle_operations(data) # bc 0.9.0 if isinstance(elem, Mapping): ret = {} for k in elem: @@ -654,15 +654,17 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): if isinstance(deco, Mapping): _gen = zip_longest(*deco.values(), fillvalue=fill_value) if pad else zip(*deco.values()) ret = [dict(zip(deco, item)) for item in _gen] - if not config.USE_META_DICT: - return ret - return pickle_operations(ret, is_encode=False) # bc 0.9.0 + # if not config.USE_META_DICT: + # return ret + # return pickle_operations(ret, is_encode=False) # bc 0.9.0 + return ret if isinstance(deco, Iterable): _gen = zip_longest(*deco, fillvalue=fill_value) if pad else zip(*deco) ret_list = [list(item) for item in _gen] - if not config.USE_META_DICT: - return ret_list - return pickle_operations(ret_list, is_encode=False) # bc 0.9.0 + # if not config.USE_META_DICT: + # return ret_list + # return pickle_operations(ret_list, is_encode=False) # bc 0.9.0 + return ret_list raise NotImplementedError(f"Unable to de-collate: {batch}, type: {type(batch)}.") From f6f9867f8c2c538db7a3052deea2f720a7494b12 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Fri, 12 Sep 2025 23:46:04 +0100 Subject: [PATCH 07/20] Fixes loading with weights_only for PersistentDataset by force converting to tensors before saving. Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/data/dataset.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index ad99c4e216..9c8a08a753 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -207,9 +207,9 @@ class PersistentDataset(Dataset): not guaranteed, so caution should be used when modifying transforms to avoid unexpected errors. If in doubt, it is advisable to clear the cache directory. - Loading is done using `torch.load` with `weights_only=False`, thus the user must ensure the data - being loaded is safe. Typically this will be cached data the user created themselves; if data - from external sources is used this should be validated for safety independently. + Cached data is expected to be tensors, primitives, or dictionaries keying to these values. Numpy arrays will + be converted to tensors, however any other object type returned by transforms will not be loadable since + `torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects.
Lazy Resampling: If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to @@ -376,7 +376,7 @@ def _cachecheck(self, item_transformed): if hashfile is not None and hashfile.is_file(): # cache hit try: # Loading with weights_only=False is expected to be safe as these should be the user's own cached data - return torch.load(hashfile, weights_only=False) + return torch.load(hashfile, weights_only=True) except PermissionError as e: if sys.platform != "win32": raise e @@ -397,7 +397,7 @@ def _cachecheck(self, item_transformed): with tempfile.TemporaryDirectory() as tmpdirname: temp_hash_file = Path(tmpdirname) / hashfile.name torch.save( - obj=_item_transformed, + obj=convert_to_tensor(_item_transformed), f=temp_hash_file, pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), pickle_protocol=self.pickle_protocol, @@ -1655,7 +1655,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name): meta_hash_file = self.cache_dir / meta_hash_file_name temp_hash_file = Path(tmpdirname) / meta_hash_file_name torch.save( - obj=self._meta_cache[meta_hash_file_name], + obj=convert_to_tensor(self._meta_cache[meta_hash_file_name]), f=temp_hash_file, pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), pickle_protocol=self.pickle_protocol, @@ -1675,4 +1675,4 @@ def _load_meta_cache(self, meta_hash_file_name): if meta_hash_file_name in self._meta_cache: return self._meta_cache[meta_hash_file_name] else: - return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False) + return torch.load(self.cache_dir / meta_hash_file_name, weights_only=True) From 8a7579545d9f66ab9ee6480066edec36d2bdcbe0 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:06:01 +0100 Subject: [PATCH 08/20] Tweak Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/data/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 9c8a08a753..a2c559fa90 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -210,6 +210,7 @@ class PersistentDataset(Dataset): Cached data is expected to be tensors, primitives, or dictionaries keying to these values. Numpy arrays will be converted to tensors, however any other object type returned by transforms will not be loadable since `torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects. + Legacy cache files may not be loadable and may need to be recomputed. Lazy Resampling: If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to @@ -375,13 +376,12 @@ def _cachecheck(self, item_transformed): if hashfile is not None and hashfile.is_file(): # cache hit try: - # Loading with weights_only=False is expected to be safe as these should be the user's own cached data return torch.load(hashfile, weights_only=True) except PermissionError as e: if sys.platform != "win32": raise e - except RuntimeError as e: - if "Invalid magic number; corrupt file" in str(e): + except (pickle.UnpicklingError, RuntimeError) as e: # corrupt or unloadable cached files are recomputed + if "Invalid magic number; corrupt file" in str(e) or isinstance(e, pickle.UnpicklingError): warnings.warn(f"Corrupt cache file detected: {hashfile}. 
Deleting and recomputing.") hashfile.unlink() else: From a60569cdc5869ca4bf1026460fe78cffebad912d Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Sun, 14 Sep 2025 00:49:25 +0100 Subject: [PATCH 09/20] Comment unneeded components Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/data/__init__.py | 2 +- monai/data/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 340c5eb8fa..9e2ce3fb6c 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -78,7 +78,7 @@ from .thread_buffer import ThreadBuffer, ThreadDataLoader from .torchscript_utils import load_net_with_metadata, save_net_with_metadata from .utils import ( - PICKLE_KEY_SUFFIX, + # PICKLE_KEY_SUFFIX, affine_to_spacing, compute_importance_map, compute_shape_offset, diff --git a/monai/data/utils.py b/monai/data/utils.py index ede1e6ff7b..361cbe8aab 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -93,7 +93,7 @@ "remove_keys", "remove_extra_metadata", "get_extra_metadata_keys", - "PICKLE_KEY_SUFFIX", + # "PICKLE_KEY_SUFFIX", "is_no_channel", ] @@ -418,7 +418,7 @@ def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"): return -PICKLE_KEY_SUFFIX = TraceKeys.KEY_SUFFIX +# PICKLE_KEY_SUFFIX = TraceKeys.KEY_SUFFIX # def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True): From b54e55d5fcce8c1390e3b0bb902f32ffaa0a6124 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Sun, 14 Sep 2025 17:27:24 +0100 Subject: [PATCH 10/20] Modify convert_to_tensor to skip converting primitives Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/data/dataset.py | 4 ++-- monai/utils/type_conversion.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index a2c559fa90..75ae51dd80 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -397,7 +397,7 @@ def _cachecheck(self, item_transformed): with tempfile.TemporaryDirectory() as tmpdirname: temp_hash_file = Path(tmpdirname) / hashfile.name torch.save( - obj=convert_to_tensor(_item_transformed), + obj=convert_to_tensor(_item_transformed, convert_numeric=False), f=temp_hash_file, pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), pickle_protocol=self.pickle_protocol, @@ -1655,7 +1655,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name): meta_hash_file = self.cache_dir / meta_hash_file_name temp_hash_file = Path(tmpdirname) / meta_hash_file_name torch.save( - obj=convert_to_tensor(self._meta_cache[meta_hash_file_name]), + obj=convert_to_tensor(self._meta_cache[meta_hash_file_name], convert_numeric=False), f=temp_hash_file, pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), pickle_protocol=self.pickle_protocol, diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 420e935b33..dd99f39d0a 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -117,6 +117,7 @@ def convert_to_tensor( wrap_sequence: bool = False, track_meta: bool = False, safe: bool = False, + convert_numeric: bool = True ) -> Any: """ Utility to convert the input data to a PyTorch Tensor, if `track_meta` is True, the output will be a `MetaTensor`, @@ -136,6 +137,7 @@ def convert_to_tensor( safe: if `True`, then do safe dtype convert when intensity overflow. 
default to `False`. E.g., `[256, -12]` -> `[tensor(0), tensor(244)]`. If `True`, then `[256, -12]` -> `[tensor(255), tensor(0)]`. + convert_numeric: if `True`, convert numeric Python values to tensors. """ @@ -167,7 +169,7 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any: if data.ndim > 0: data = np.ascontiguousarray(data) return _convert_tensor(data, dtype=dtype, device=device) - elif (has_cp and isinstance(data, cp_ndarray)) or isinstance(data, (float, int, bool)): + elif (has_cp and isinstance(data, cp_ndarray)) or (convert_numeric and isinstance(data, (float, int, bool))): return _convert_tensor(data, dtype=dtype, device=device) elif isinstance(data, list): list_ret = [convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta) for i in data] From 52f8694e45111fac0f30f0cafa7f465e14c4cf90 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Sun, 14 Sep 2025 17:42:38 +0100 Subject: [PATCH 11/20] Trying safe torch load save usage in place of pickle Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/data/dataset.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 75ae51dd80..86273b6a07 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -12,6 +12,7 @@ from __future__ import annotations import collections.abc +from io import BytesIO import math import pickle import shutil @@ -599,6 +600,16 @@ def set_data(self, data: Sequence): super().set_data(data=data) self._read_env = self._fill_cache_start_reader(show_progress=self.progress) + def _safe_serialize(self,val): + out=BytesIO() + torch.save(convert_to_tensor(val), out, protocol=self.pickle_protocol) + out.seek(0) + return out.read() + + def _safe_deserialize(self,val): + out=BytesIO(val) + return torch.load(out,weights_only=True) + def _fill_cache_start_reader(self, show_progress=True): """ Check the LMDB cache and write the cache if needed. py-lmdb doesn't have a good support for concurrent write. 
@@ -624,7 +635,8 @@ def _fill_cache_start_reader(self, show_progress=True): continue if val is None: val = self._pre_transform(deepcopy(item)) # keep the original hashed - val = pickle.dumps(val, protocol=self.pickle_protocol) + # val = pickle.dumps(val, protocol=self.pickle_protocol) + val=self._safe_serialize(val) with env.begin(write=True) as txn: txn.put(key, val) done = True @@ -669,7 +681,8 @@ def _cachecheck(self, item_transformed): warnings.warn("LMDBDataset: cache key not found, running fallback caching.") return super()._cachecheck(item_transformed) try: - return pickle.loads(data) + # return pickle.loads(data) + return self._safe_deserialize(data) except Exception as err: raise RuntimeError("Invalid cache value, corrupted lmdb file?") from err From 14e5e6b27b519c5149502734d6ec89684e4cc9f1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 14 Sep 2025 16:44:08 +0000 Subject: [PATCH 12/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/dataset.py | 4 ++-- monai/data/utils.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 86273b6a07..409d7b411f 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -605,10 +605,10 @@ def _safe_serialize(self,val): torch.save(convert_to_tensor(val), out, protocol=self.pickle_protocol) out.seek(0) return out.read() - + def _safe_deserialize(self,val): out=BytesIO(val) - return torch.load(out,weights_only=True) + return torch.load(out,weights_only=True) def _fill_cache_start_reader(self, show_progress=True): """ diff --git a/monai/data/utils.py b/monai/data/utils.py index 361cbe8aab..61d723cc19 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -30,7 +30,6 @@ import torch from torch.utils.data._utils.collate import default_collate -from monai import config from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike from monai.data.meta_obj import MetaObj from monai.utils import ( From 38a618b1d9cdb90343fad1e718fed17f2bba2ff7 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Sun, 14 Sep 2025 18:06:19 +0100 Subject: [PATCH 13/20] Updates to further remove pickle usage Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/data/dataset.py | 23 ++++++++++++----------- monai/utils/state_cacher.py | 4 ++-- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 86273b6a07..83f78abb74 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -14,7 +14,7 @@ import collections.abc from io import BytesIO import math -import pickle +from pickle import UnpicklingError import shutil import sys import tempfile @@ -254,8 +254,8 @@ def __init__( this arg is used by `torch.save`, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, and ``monai.data.utils.SUPPORTED_PICKLE_MOD``. - pickle_protocol: can be specified to override the default protocol, default to `2`. - this arg is used by `torch.save`, for more details, please check: + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. 
hash_transform: a callable to compute hash from the transform information when caching. This may reduce errors due to transforms changing during experiments. Default to None (no hash). @@ -381,8 +381,8 @@ def _cachecheck(self, item_transformed): except PermissionError as e: if sys.platform != "win32": raise e - except (pickle.UnpicklingError, RuntimeError) as e: # corrupt or unloadable cached files are recomputed - if "Invalid magic number; corrupt file" in str(e) or isinstance(e, pickle.UnpicklingError): + except (UnpicklingError, RuntimeError) as e: # corrupt or unloadable cached files are recomputed + if "Invalid magic number; corrupt file" in str(e) or isinstance(e, UnpicklingError): warnings.warn(f"Corrupt cache file detected: {hashfile}. Deleting and recomputing.") hashfile.unlink() else: @@ -461,8 +461,8 @@ def __init__( this arg is used by `torch.save`, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, and ``monai.data.utils.SUPPORTED_PICKLE_MOD``. - pickle_protocol: can be specified to override the default protocol, default to `2`. - this arg is used by `torch.save`, for more details, please check: + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. hash_transform: a callable to compute hash from the transform information when caching. This may reduce errors due to transforms changing during experiments. Default to None (no hash). @@ -537,7 +537,7 @@ def __init__( hash_func: Callable[..., bytes] = pickle_hashing, db_name: str = "monai_cache", progress: bool = True, - pickle_protocol=pickle.HIGHEST_PROTOCOL, + pickle_protocol=DEFAULT_PROTOCOL, hash_transform: Callable[..., bytes] | None = None, reset_ops_id: bool = True, lmdb_kwargs: dict | None = None, @@ -557,8 +557,9 @@ def __init__( defaults to `monai.data.utils.pickle_hashing`. db_name: lmdb database file name. Defaults to "monai_cache". progress: whether to display a progress bar. - pickle_protocol: pickle protocol version. Defaults to pickle.HIGHEST_PROTOCOL. - https://docs.python.org/3/library/pickle.html#pickle-protocols + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: + https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. hash_transform: a callable to compute hash from the transform information when caching. This may reduce errors due to transforms changing during experiments. Default to None (no hash). Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`. @@ -602,7 +603,7 @@ def set_data(self, data: Sequence): def _safe_serialize(self,val): out=BytesIO() - torch.save(convert_to_tensor(val), out, protocol=self.pickle_protocol) + torch.save(convert_to_tensor(val), out, pickle_protocol =self.pickle_protocol) out.seek(0) return out.read() diff --git a/monai/utils/state_cacher.py b/monai/utils/state_cacher.py index 726d59273b..f6d0ff84f4 100644 --- a/monai/utils/state_cacher.py +++ b/monai/utils/state_cacher.py @@ -64,8 +64,8 @@ def __init__( pickle_module: module used for pickling metadata and objects, default to `pickle`. this arg is used by `torch.save`, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. 
- pickle_protocol: can be specified to override the default protocol, default to `2`. - this arg is used by `torch.save`, for more details, please check: + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. """ From 77d699204da0bbb1d18d34103d8913e579e81353 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 14 Sep 2025 17:06:47 +0000 Subject: [PATCH 14/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/dataset.py | 6 +++--- monai/utils/state_cacher.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 911859ca8a..d04ca79c4b 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -254,7 +254,7 @@ def __init__( this arg is used by `torch.save`, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, and ``monai.data.utils.SUPPORTED_PICKLE_MOD``. - pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. hash_transform: a callable to compute hash from the transform information when caching. @@ -461,7 +461,7 @@ def __init__( this arg is used by `torch.save`, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, and ``monai.data.utils.SUPPORTED_PICKLE_MOD``. - pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. hash_transform: a callable to compute hash from the transform information when caching. @@ -557,7 +557,7 @@ def __init__( defaults to `monai.data.utils.pickle_hashing`. db_name: lmdb database file name. Defaults to "monai_cache". progress: whether to display a progress bar. - pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. hash_transform: a callable to compute hash from the transform information when caching. diff --git a/monai/utils/state_cacher.py b/monai/utils/state_cacher.py index f6d0ff84f4..e0eeef6001 100644 --- a/monai/utils/state_cacher.py +++ b/monai/utils/state_cacher.py @@ -64,7 +64,7 @@ def __init__( pickle_module: module used for pickling metadata and objects, default to `pickle`. this arg is used by `torch.save`, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. - pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. 
From 79e09669e9e0c21714cb6cd230c17674fed046a4 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Sun, 14 Sep 2025 18:41:52 +0100 Subject: [PATCH 15/20] Autofix Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/data/__init__.py | 3 +-- monai/data/dataset.py | 23 +++++++++++------------ monai/data/utils.py | 6 +++--- monai/utils/state_cacher.py | 2 +- monai/utils/type_conversion.py | 2 +- 5 files changed, 17 insertions(+), 19 deletions(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 9e2ce3fb6c..8709eae153 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -77,8 +77,7 @@ from .test_time_augmentation import TestTimeAugmentation from .thread_buffer import ThreadBuffer, ThreadDataLoader from .torchscript_utils import load_net_with_metadata, save_net_with_metadata -from .utils import ( - # PICKLE_KEY_SUFFIX, +from .utils import ( # PICKLE_KEY_SUFFIX, affine_to_spacing, compute_importance_map, compute_shape_offset, diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 911859ca8a..d63ff32293 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -12,9 +12,7 @@ from __future__ import annotations import collections.abc -from io import BytesIO import math -from pickle import UnpicklingError import shutil import sys import tempfile @@ -23,9 +21,11 @@ import warnings from collections.abc import Callable, Sequence from copy import copy, deepcopy +from io import BytesIO from multiprocessing.managers import ListProxy from multiprocessing.pool import ThreadPool from pathlib import Path +from pickle import UnpicklingError from typing import IO, TYPE_CHECKING, Any, cast import numpy as np @@ -254,7 +254,7 @@ def __init__( this arg is used by `torch.save`, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, and ``monai.data.utils.SUPPORTED_PICKLE_MOD``. - pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. hash_transform: a callable to compute hash from the transform information when caching. @@ -461,7 +461,7 @@ def __init__( this arg is used by `torch.save`, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, and ``monai.data.utils.SUPPORTED_PICKLE_MOD``. - pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. hash_transform: a callable to compute hash from the transform information when caching. @@ -557,7 +557,7 @@ def __init__( defaults to `monai.data.utils.pickle_hashing`. db_name: lmdb database file name. Defaults to "monai_cache". progress: whether to display a progress bar. - pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. hash_transform: a callable to compute hash from the transform information when caching. 
@@ -601,15 +601,14 @@ def set_data(self, data: Sequence): super().set_data(data=data) self._read_env = self._fill_cache_start_reader(show_progress=self.progress) - def _safe_serialize(self,val): - out=BytesIO() - torch.save(convert_to_tensor(val), out, pickle_protocol =self.pickle_protocol) + def _safe_serialize(self, val): + out = BytesIO() + torch.save(convert_to_tensor(val), out, pickle_protocol=self.pickle_protocol) out.seek(0) return out.read() - def _safe_deserialize(self,val): - out=BytesIO(val) - return torch.load(out,weights_only=True) + def _safe_deserialize(self, val): + return torch.load(BytesIO(val), map_location="cpu", weights_only=True) def _fill_cache_start_reader(self, show_progress=True): """ @@ -637,7 +636,7 @@ def _fill_cache_start_reader(self, show_progress=True): if val is None: val = self._pre_transform(deepcopy(item)) # keep the original hashed # val = pickle.dumps(val, protocol=self.pickle_protocol) - val=self._safe_serialize(val) + val = self._safe_serialize(val) with env.begin(write=True) as txn: txn.put(key, val) done = True diff --git a/monai/data/utils.py b/monai/data/utils.py index 61d723cc19..dde873707d 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -500,7 +500,7 @@ def list_data_collate(batch: Sequence): collate_fn = default_collate try: # if config.USE_META_DICT: - # data = pickle_operations(data) # bc 0.9.0 + # data = pickle_operations(data) # bc 0.9.0 if isinstance(elem, Mapping): ret = {} for k in elem: @@ -654,14 +654,14 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): _gen = zip_longest(*deco.values(), fillvalue=fill_value) if pad else zip(*deco.values()) ret = [dict(zip(deco, item)) for item in _gen] # if not config.USE_META_DICT: - # return ret + # return ret # return pickle_operations(ret, is_encode=False) # bc 0.9.0 return ret if isinstance(deco, Iterable): _gen = zip_longest(*deco, fillvalue=fill_value) if pad else zip(*deco) ret_list = [list(item) for item in _gen] # if not config.USE_META_DICT: - # return ret_list + # return ret_list # return pickle_operations(ret_list, is_encode=False) # bc 0.9.0 return ret_list raise NotImplementedError(f"Unable to de-collate: {batch}, type: {type(batch)}.") diff --git a/monai/utils/state_cacher.py b/monai/utils/state_cacher.py index f6d0ff84f4..e0eeef6001 100644 --- a/monai/utils/state_cacher.py +++ b/monai/utils/state_cacher.py @@ -64,7 +64,7 @@ def __init__( pickle_module: module used for pickling metadata and objects, default to `pickle`. this arg is used by `torch.save`, for more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. - pickle_protocol: specifies pickle protocol when saving, with `torch.save`. + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. 
diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index dd99f39d0a..ce36b51046 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -117,7 +117,7 @@ def convert_to_tensor( wrap_sequence: bool = False, track_meta: bool = False, safe: bool = False, - convert_numeric: bool = True + convert_numeric: bool = True, ) -> Any: """ Utility to convert the input data to a PyTorch Tensor, if `track_meta` is True, the output will be a `MetaTensor`, From 11c0ee5ba39f404ee485cb49026b5f49ec7f3cb0 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Sun, 14 Sep 2025 18:42:59 +0100 Subject: [PATCH 16/20] Removing commented code Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/data/__init__.py | 2 +- monai/data/utils.py | 27 --------------------------- 2 files changed, 1 insertion(+), 28 deletions(-) diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 8709eae153..5e367cc297 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -77,7 +77,7 @@ from .test_time_augmentation import TestTimeAugmentation from .thread_buffer import ThreadBuffer, ThreadDataLoader from .torchscript_utils import load_net_with_metadata, save_net_with_metadata -from .utils import ( # PICKLE_KEY_SUFFIX, +from .utils import ( affine_to_spacing, compute_importance_map, compute_shape_offset, diff --git a/monai/data/utils.py b/monai/data/utils.py index dde873707d..ca7d5c9d9e 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -92,7 +92,6 @@ "remove_keys", "remove_extra_metadata", "get_extra_metadata_keys", - # "PICKLE_KEY_SUFFIX", "is_no_channel", ] @@ -417,32 +416,6 @@ def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"): return -# PICKLE_KEY_SUFFIX = TraceKeys.KEY_SUFFIX - - -# def pickle_operations(data, key=PICKLE_KEY_SUFFIX, is_encode: bool = True): -# """ -# Applied_operations are dictionaries with varying sizes, this method converts them to bytes so that we can (de-)collate. - -# Args: -# data: a list or dictionary with substructures to be pickled/unpickled. -# key: the key suffix for the target substructures, defaults to "_transforms" (`data.utils.PICKLE_KEY_SUFFIX`). -# is_encode: whether it's encoding using pickle.dumps (True) or decoding using pickle.loads (False). -# """ -# if isinstance(data, Mapping): -# data = dict(data) -# for k in data: -# if f"{k}".endswith(key): -# if is_encode and not isinstance(data[k], bytes): -# data[k] = pickle.dumps(data[k], 0) -# if not is_encode and isinstance(data[k], bytes): -# data[k] = pickle.loads(data[k]) -# return {k: pickle_operations(v, key=key, is_encode=is_encode) for k, v in data.items()} -# elif isinstance(data, (list, tuple)): -# return [pickle_operations(item, key=key, is_encode=is_encode) for item in data] -# return data - - def collate_meta_tensor_fn(batch, *, collate_fn_map=None): """ Collate a sequence of meta tensor into a single batched metatensor. 
This is called by `collage_meta_tensor` From 2edf46c36aaa7e5d77de3aba0a7e2701a99957be Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Sun, 14 Sep 2025 19:03:48 +0100 Subject: [PATCH 17/20] Pass argument in recursive call of convert_to_tensor Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/utils/type_conversion.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index ce36b51046..262b15101a 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -158,6 +158,10 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any: if safe: data = safe_dtype_range(data, dtype) dtype = get_equivalent_dtype(dtype, torch.Tensor) + + # common keyword arguments for recursive calls + conv_kwargs = dict(dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric) + if isinstance(data, torch.Tensor): return _convert_tensor(data).to(dtype=dtype, device=device, memory_format=torch.contiguous_format) if isinstance(data, np.ndarray): @@ -172,13 +176,13 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any: elif (has_cp and isinstance(data, cp_ndarray)) or (convert_numeric and isinstance(data, (float, int, bool))): return _convert_tensor(data, dtype=dtype, device=device) elif isinstance(data, list): - list_ret = [convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta) for i in data] + list_ret = [convert_to_tensor(i, **conv_kwargs) for i in data] return _convert_tensor(list_ret, dtype=dtype, device=device) if wrap_sequence else list_ret elif isinstance(data, tuple): - tuple_ret = tuple(convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta) for i in data) + tuple_ret = tuple(convert_to_tensor(i, **conv_kwargs) for i in data) return _convert_tensor(tuple_ret, dtype=dtype, device=device) if wrap_sequence else tuple_ret elif isinstance(data, dict): - return {k: convert_to_tensor(v, dtype=dtype, device=device, track_meta=track_meta) for k, v in data.items()} + return {k: convert_to_tensor(v, **conv_kwargs) for k, v in data.items()} return data From 9b171d44543c184b56770b290385aa4d9f23259c Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Sun, 14 Sep 2025 19:33:00 +0100 Subject: [PATCH 18/20] Type fix Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/utils/type_conversion.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 262b15101a..b5dfb580c5 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -159,9 +159,6 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any: data = safe_dtype_range(data, dtype) dtype = get_equivalent_dtype(dtype, torch.Tensor) - # common keyword arguments for recursive calls - conv_kwargs = dict(dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric) - if isinstance(data, torch.Tensor): return _convert_tensor(data).to(dtype=dtype, device=device, memory_format=torch.contiguous_format) if isinstance(data, np.ndarray): @@ -176,13 +173,22 @@ def _convert_tensor(tensor: Any, **kwargs: Any) -> Any: elif (has_cp and isinstance(data, cp_ndarray)) or (convert_numeric and isinstance(data, (float, int, bool))): return _convert_tensor(data, dtype=dtype, device=device) elif isinstance(data, list): - list_ret = 
[convert_to_tensor(i, **conv_kwargs) for i in data] + list_ret = [ + convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric) + for i in data + ] return _convert_tensor(list_ret, dtype=dtype, device=device) if wrap_sequence else list_ret elif isinstance(data, tuple): - tuple_ret = tuple(convert_to_tensor(i, **conv_kwargs) for i in data) + tuple_ret = tuple( + convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric) + for i in data + ) return _convert_tensor(tuple_ret, dtype=dtype, device=device) if wrap_sequence else tuple_ret elif isinstance(data, dict): - return {k: convert_to_tensor(v, **conv_kwargs) for k, v in data.items()} + return { + k: convert_to_tensor(v, dtype=dtype, device=device, track_meta=track_meta, convert_numeric=convert_numeric) + for k, v in data.items() + } return data From 149a5bbf2e2fa793a6761b89efad711b44343569 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Tue, 16 Sep 2025 14:44:21 +0100 Subject: [PATCH 19/20] Fixing pickle protocol issue Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- tests/data/test_gdsdataset.py | 3 ++- tests/data/test_persistentdataset.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/data/test_gdsdataset.py b/tests/data/test_gdsdataset.py index b4acb3bf55..fb7b78e41c 100644 --- a/tests/data/test_gdsdataset.py +++ b/tests/data/test_gdsdataset.py @@ -86,7 +86,8 @@ def test_cache(self): cache_dir=tempdir, device=0, pickle_module="pickle", - pickle_protocol=pickle.HIGHEST_PROTOCOL, + # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility + pickle_protocol=torch.serialization.DEFAULT_PROTOCOL, ) assert_allclose(items[0], p(np.arange(0, np.prod(shape)).reshape(shape))) ds1 = GDSDataset(items, transform=_InplaceXform(), cache_dir=tempdir, device=0) diff --git a/tests/data/test_persistentdataset.py b/tests/data/test_persistentdataset.py index 7c4969e283..7bf1245592 100644 --- a/tests/data/test_persistentdataset.py +++ b/tests/data/test_persistentdataset.py @@ -12,12 +12,12 @@ from __future__ import annotations import os -import pickle import tempfile import unittest import nibabel as nib import numpy as np +import torch from parameterized import parameterized from monai.data import PersistentDataset, json_hashing @@ -66,7 +66,8 @@ def test_cache(self): transform=_InplaceXform(), cache_dir=tempdir, pickle_module="pickle", - pickle_protocol=pickle.HIGHEST_PROTOCOL, + # TODO: was pickle.HIGHEST_PROTOCOL but this wasn't compatible with torch.load, need to improve compatibility + pickle_protocol=torch.serialization.DEFAULT_PROTOCOL, ) self.assertEqual(items, [[[]], [[0]], [[0, 1]], [[0, 1, 2]], [[0, 1, 2, 3]]]) ds1 = PersistentDataset(items, transform=_InplaceXform(), cache_dir=tempdir) From dd1de4a67bd30c3c6b8cd63f6dcf4e2d88aa5ca7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Sep 2025 13:44:51 +0000 Subject: [PATCH 20/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/data/test_gdsdataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/data/test_gdsdataset.py b/tests/data/test_gdsdataset.py index fb7b78e41c..aa802249bc 100644 --- a/tests/data/test_gdsdataset.py +++ b/tests/data/test_gdsdataset.py @@ -12,7 +12,6 @@ from __future__ import annotations import 
os -import pickle import tempfile import unittest
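As an illustration of the behaviour the series converges on, the minimal sketch below round-trips a `MetaTensor` through `torch.save` and a `weights_only=True` load, as exercised by the updated `test_pickling` above. It assumes a MONAI build containing these patches and a PyTorch version that provides `torch.serialization.add_safe_globals`; the tensor shape and metadata key are arbitrary.

    from io import BytesIO

    import torch

    from monai.data import MetaTensor

    # Importing monai.data registers MetaObj, MetaTensor, MetaKeys and SpaceKeys with
    # torch.serialization.add_safe_globals (patch 01), so the restricted unpickler used by
    # weights_only=True can rebuild MetaTensor instances without running arbitrary pickle code.
    m = MetaTensor(torch.rand(1, 8, 8), meta={"some_key": 1})

    buffer = BytesIO()  # stands in for a checkpoint or cache file on disk
    torch.save(m, buffer)

    buffer.seek(0)
    restored = torch.load(buffer, weights_only=True)
    print(type(restored).__name__, restored.meta["some_key"])  # MetaTensor 1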