Skip to content

Commit 52f8694

Browse files
committed
Trying safe torch load save usage in place of pickle
Signed-off-by: Eric Kerfoot <[email protected]>
1 parent 7dc3ad3 commit 52f8694

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

monai/data/dataset.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
import collections.abc
15+
from io import BytesIO
1516
import math
1617
import pickle
1718
import shutil
@@ -599,6 +600,16 @@ def set_data(self, data: Sequence):
599600
super().set_data(data=data)
600601
self._read_env = self._fill_cache_start_reader(show_progress=self.progress)
601602

603+
def _safe_serialize(self,val):
604+
out=BytesIO()
605+
torch.save(convert_to_tensor(val), out, protocol=self.pickle_protocol)
606+
out.seek(0)
607+
return out.read()
608+
609+
def _safe_deserialize(self,val):
610+
out=BytesIO(val)
611+
return torch.load(out,weights_only=True)
612+
602613
def _fill_cache_start_reader(self, show_progress=True):
603614
"""
604615
Check the LMDB cache and write the cache if needed. py-lmdb doesn't have a good support for concurrent write.
@@ -624,7 +635,8 @@ def _fill_cache_start_reader(self, show_progress=True):
624635
continue
625636
if val is None:
626637
val = self._pre_transform(deepcopy(item)) # keep the original hashed
627-
val = pickle.dumps(val, protocol=self.pickle_protocol)
638+
# val = pickle.dumps(val, protocol=self.pickle_protocol)
639+
val=self._safe_serialize(val)
628640
with env.begin(write=True) as txn:
629641
txn.put(key, val)
630642
done = True
@@ -669,7 +681,8 @@ def _cachecheck(self, item_transformed):
669681
warnings.warn("LMDBDataset: cache key not found, running fallback caching.")
670682
return super()._cachecheck(item_transformed)
671683
try:
672-
return pickle.loads(data)
684+
# return pickle.loads(data)
685+
return self._safe_deserialize(data)
673686
except Exception as err:
674687
raise RuntimeError("Invalid cache value, corrupted lmdb file?") from err
675688

0 commit comments

Comments
 (0)