1212from __future__ import annotations
1313
1414import collections .abc
15+ from io import BytesIO
1516import math
1617import pickle
1718import shutil
@@ -599,6 +600,16 @@ def set_data(self, data: Sequence):
599600 super ().set_data (data = data )
600601 self ._read_env = self ._fill_cache_start_reader (show_progress = self .progress )
601602
def _safe_serialize(self, val):
    """
    Serialize ``val`` into bytes for the LMDB cache via ``torch.save``.

    The value is first converted with ``convert_to_tensor`` so that the
    payload can later be restored with ``torch.load(..., weights_only=True)``,
    avoiding arbitrary-code execution from untrusted pickle data.

    Args:
        val: the (pre-transformed) cache value to serialize.

    Returns:
        bytes: the serialized payload to store in LMDB.
    """
    out = BytesIO()
    # torch.save exposes the pickle protocol as ``pickle_protocol`` --
    # the original ``protocol=`` keyword is not accepted by torch.save
    # and raised a TypeError at runtime.
    torch.save(convert_to_tensor(val), out, pickle_protocol=self.pickle_protocol)
    # getvalue() returns the full buffer contents; no seek/read dance needed.
    return out.getvalue()
608+
609+ def _safe_deserialize (self ,val ):
610+ out = BytesIO (val )
611+ return torch .load (out ,weights_only = True )
612+
602613 def _fill_cache_start_reader (self , show_progress = True ):
603614 """
604615 Check the LMDB cache and write the cache if needed. py-lmdb doesn't have a good support for concurrent write.
@@ -624,7 +635,8 @@ def _fill_cache_start_reader(self, show_progress=True):
624635 continue
625636 if val is None :
626637 val = self ._pre_transform (deepcopy (item )) # keep the original hashed
627- val = pickle .dumps (val , protocol = self .pickle_protocol )
638+ # val = pickle.dumps(val, protocol=self.pickle_protocol)
639+ val = self ._safe_serialize (val )
628640 with env .begin (write = True ) as txn :
629641 txn .put (key , val )
630642 done = True
@@ -669,7 +681,8 @@ def _cachecheck(self, item_transformed):
669681 warnings .warn ("LMDBDataset: cache key not found, running fallback caching." )
670682 return super ()._cachecheck (item_transformed )
671683 try :
672- return pickle .loads (data )
684+ # return pickle.loads(data)
685+ return self ._safe_deserialize (data )
673686 except Exception as err :
674687 raise RuntimeError ("Invalid cache value, corrupted lmdb file?" ) from err
675688
0 commit comments