@@ -230,6 +230,8 @@ def __init__(
230230 pickle_protocol : int = DEFAULT_PROTOCOL ,
231231 hash_transform : Callable [..., bytes ] | None = None ,
232232 reset_ops_id : bool = True ,
233+ track_meta : bool = False ,
234+ weights_only : bool = True ,
233235 ) -> None :
234236 """
235237 Args:
@@ -264,7 +266,17 @@ def __init__(
264266 When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
265267 This is useful for skipping the transform instance checks when inverting applied operations
266268 using the cached content and with re-created transform instances.
267-
269+ track_meta: whether to track the meta information. If `True`, the data is converted to `MetaTensor` when it is written to the cache.
270+ Defaults to `False`. Cannot be used with `weights_only=True`.
271+ weights_only: keyword argument passed to `torch.load` when reading cached files.
272+ Defaults to `True`. When set to `True`, `torch.load` restricts loading to tensors and
273+ other safe objects. Setting this to `False` is required for loading `MetaTensor`
274+ objects saved with `track_meta=True`. However, this creates the possibility of remote
275+ code execution through `torch.load`, so be aware of the security implications and only load cache files you trust.
276+
277+ Raises:
278+ ValueError: When both `track_meta=True` and `weights_only=True`, since this combination
279+ prevents cached MetaTensors from being reloaded and causes perpetual cache regeneration.
268280 """
269281 super ().__init__ (data = data , transform = transform )
270282 self .cache_dir = Path (cache_dir ) if cache_dir is not None else None
@@ -280,6 +292,13 @@ def __init__(
280292 if hash_transform is not None :
281293 self .set_transform_hash (hash_transform )
282294 self .reset_ops_id = reset_ops_id
295+ if track_meta and weights_only :
296+ raise ValueError (
297+ "Invalid argument combination: `track_meta=True` cannot be used with `weights_only=True`. "
298+ "To cache and reload MetaTensors, set `track_meta=True` and `weights_only=False`."
299+ )
300+ self .track_meta = track_meta
301+ self .weights_only = weights_only
283302
284303 def set_transform_hash (self , hash_xform_func : Callable [..., bytes ]):
285304 """Get hashable transforms, and then hash them. Hashable transforms
@@ -377,7 +396,7 @@ def _cachecheck(self, item_transformed):
377396
378397 if hashfile is not None and hashfile .is_file (): # cache hit
379398 try :
380- return torch .load (hashfile , weights_only = True )
399+ return torch .load (hashfile , weights_only = self . weights_only )
381400 except PermissionError as e :
382401 if sys .platform != "win32" :
383402 raise e
@@ -398,7 +417,7 @@ def _cachecheck(self, item_transformed):
398417 with tempfile .TemporaryDirectory () as tmpdirname :
399418 temp_hash_file = Path (tmpdirname ) / hashfile .name
400419 torch .save (
401- obj = convert_to_tensor (_item_transformed , convert_numeric = False ),
420+ obj = convert_to_tensor (_item_transformed , convert_numeric = False , track_meta = self . track_meta ),
402421 f = temp_hash_file ,
403422 pickle_module = look_up_option (self .pickle_module , SUPPORTED_PICKLE_MOD ),
404423 pickle_protocol = self .pickle_protocol ,
0 commit comments