|
12 | 12 | from __future__ import annotations |
13 | 13 |
|
14 | 14 | import collections.abc |
15 | | -from io import BytesIO |
16 | 15 | import math |
17 | | -from pickle import UnpicklingError |
18 | 16 | import shutil |
19 | 17 | import sys |
20 | 18 | import tempfile |
|
23 | 21 | import warnings |
24 | 22 | from collections.abc import Callable, Sequence |
25 | 23 | from copy import copy, deepcopy |
| 24 | +from io import BytesIO |
26 | 25 | from multiprocessing.managers import ListProxy |
27 | 26 | from multiprocessing.pool import ThreadPool |
28 | 27 | from pathlib import Path |
| 28 | +from pickle import UnpicklingError |
29 | 29 | from typing import IO, TYPE_CHECKING, Any, cast |
30 | 30 |
|
31 | 31 | import numpy as np |
@@ -254,7 +254,7 @@ def __init__( |
254 | 254 | this arg is used by `torch.save`, for more details, please check: |
255 | 255 | https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, |
256 | 256 | and ``monai.data.utils.SUPPORTED_PICKLE_MOD``. |
257 | | - pickle_protocol: specifies pickle protocol when saving, with `torch.save`. |
| 257 | + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. |
258 | 258 | Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: |
259 | 259 | https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. |
260 | 260 | hash_transform: a callable to compute hash from the transform information when caching. |
@@ -461,7 +461,7 @@ def __init__( |
461 | 461 | this arg is used by `torch.save`, for more details, please check: |
462 | 462 | https://pytorch.org/docs/stable/generated/torch.save.html#torch.save, |
463 | 463 | and ``monai.data.utils.SUPPORTED_PICKLE_MOD``. |
464 | | - pickle_protocol: specifies pickle protocol when saving, with `torch.save`. |
| 464 | + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. |
465 | 465 | Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: |
466 | 466 | https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. |
467 | 467 | hash_transform: a callable to compute hash from the transform information when caching. |
@@ -557,7 +557,7 @@ def __init__( |
557 | 557 | defaults to `monai.data.utils.pickle_hashing`. |
558 | 558 | db_name: lmdb database file name. Defaults to "monai_cache". |
559 | 559 | progress: whether to display a progress bar. |
560 | | - pickle_protocol: specifies pickle protocol when saving, with `torch.save`. |
| 560 | + pickle_protocol: specifies pickle protocol when saving, with `torch.save`. |
561 | 561 | Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check: |
562 | 562 | https://pytorch.org/docs/stable/generated/torch.save.html#torch.save. |
563 | 563 | hash_transform: a callable to compute hash from the transform information when caching. |
@@ -601,15 +601,14 @@ def set_data(self, data: Sequence): |
601 | 601 | super().set_data(data=data) |
602 | 602 | self._read_env = self._fill_cache_start_reader(show_progress=self.progress) |
603 | 603 |
|
604 | | - def _safe_serialize(self,val): |
605 | | - out=BytesIO() |
606 | | - torch.save(convert_to_tensor(val), out, pickle_protocol =self.pickle_protocol) |
| 604 | + def _safe_serialize(self, val): |
| 605 | + out = BytesIO() |
| 606 | + torch.save(convert_to_tensor(val), out, pickle_protocol=self.pickle_protocol) |
607 | 607 | out.seek(0) |
608 | 608 | return out.read() |
609 | 609 |
|
610 | | - def _safe_deserialize(self,val): |
611 | | - out=BytesIO(val) |
612 | | - return torch.load(out,weights_only=True) |
| 610 | + def _safe_deserialize(self, val): |
| 611 | + return torch.load(BytesIO(val), map_location="cpu", weights_only=True) |
613 | 612 |
|
614 | 613 | def _fill_cache_start_reader(self, show_progress=True): |
615 | 614 | """ |
@@ -637,7 +636,7 @@ def _fill_cache_start_reader(self, show_progress=True): |
637 | 636 | if val is None: |
638 | 637 | val = self._pre_transform(deepcopy(item)) # keep the original hashed |
639 | 638 | # val = pickle.dumps(val, protocol=self.pickle_protocol) |
640 | | - val=self._safe_serialize(val) |
| 639 | + val = self._safe_serialize(val) |
641 | 640 | with env.begin(write=True) as txn: |
642 | 641 | txn.put(key, val) |
643 | 642 | done = True |
|
0 commit comments