Skip to content

Commit

Permalink
[syft/serde] - move all code related to torch tensor serde into a try…
Browse files Browse the repository at this point in the history
…-exception block in `third_party.py`

- move torch back to under data_science marker in setup.cfg
  • Loading branch information
khoaguin committed May 22, 2024
1 parent 404af04 commit 4b8cc4c
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 40 deletions.
4 changes: 2 additions & 2 deletions packages/syft/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ syft =
rich==13.7.1
jinja2==3.1.4
tenacity==8.3.0
# backend.dockerfile installs torch separately, so update the version over there as well!
torch==2.3.0

install_requires =
%(syft)s
Expand All @@ -88,6 +86,8 @@ data_science =
opendp==0.9.2
evaluate==0.4.1
recordlinkage==0.16
# backend.dockerfile installs torch separately, so update the version over there as well!
torch==2.3.0

dev =
%(test_plugins)s
Expand Down
18 changes: 0 additions & 18 deletions packages/syft/src/syft/serde/array.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# third party
import numpy as np
from numpy import frombuffer
import torch

# relative
from .arrow import numpy_deserialize
Expand Down Expand Up @@ -151,22 +150,5 @@
# deserialize=lambda buffer: frombuffer(buffer, dtype=numpy_scalar_type),
# )


# Add support for torch tensors
def torch_serialize(tensor: torch.Tensor) -> bytes:
return numpy_serialize(tensor.numpy())


def torch_deserialize(buffer: bytes) -> torch.tensor:
np_array = numpy_deserialize(buffer)
return torch.from_numpy(np_array)


recursive_serde_register(
torch.Tensor,
serialize=torch_serialize,
deserialize=lambda data: torch_deserialize(data),
)

# how else do you import a relative file to execute it?
NOTHING = None
58 changes: 38 additions & 20 deletions packages/syft/src/syft/serde/third_party.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@
from result import Err
from result import Ok
from result import Result
import torch
from torch._C import _TensorMeta

# relative
from ..types.dicttuple import DictTuple
from ..types.dicttuple import _Meta as _DictTupleMetaClass
from ..types.syft_metaclass import EmptyType
from ..types.syft_metaclass import PartialModelMetaclass
from .array import numpy_deserialize
from .array import numpy_serialize
from .deserialize import _deserialize as deserialize
from .recursive_primitives import _serialize_kv_pairs
from .recursive_primitives import deserialize_kv
Expand Down Expand Up @@ -107,24 +107,6 @@ def deserialize_series(blob: bytes) -> Series:
deserialize=deserialize_series,
)


def serialize_torch_tensor_meta(t: _TensorMeta) -> bytes:
buffer = BytesIO()
torch.save(t, buffer)
return buffer.getvalue()


def deserialize_torch_tensor_meta(buf: bytes) -> _TensorMeta:
buffer = BytesIO(buf)
return torch.load(buffer)


recursive_serde_register(
_TensorMeta,
serialize=serialize_torch_tensor_meta,
deserialize=deserialize_torch_tensor_meta,
)

recursive_serde_register(
datetime,
serialize=lambda x: serialize(x.isoformat(), to_bytes=True),
Expand Down Expand Up @@ -198,6 +180,42 @@ def serialize_bytes_io(io: BytesIO) -> bytes:
pass


try:
# third party
import torch
from torch._C import _TensorMeta

def serialize_torch_tensor_meta(t: _TensorMeta) -> bytes:
buffer = BytesIO()
torch.save(t, buffer)
return buffer.getvalue()

def deserialize_torch_tensor_meta(buf: bytes) -> _TensorMeta:
buffer = BytesIO(buf)
return torch.load(buffer)

recursive_serde_register(
_TensorMeta,
serialize=serialize_torch_tensor_meta,
deserialize=deserialize_torch_tensor_meta,
)

def torch_serialize(tensor: torch.Tensor) -> bytes:
return numpy_serialize(tensor.numpy())

def torch_deserialize(buffer: bytes) -> torch.tensor:
np_array = numpy_deserialize(buffer)
return torch.from_numpy(np_array)

recursive_serde_register(
torch.Tensor,
serialize=torch_serialize,
deserialize=lambda data: torch_deserialize(data),
)

except Exception: # nosec
pass

# unsure why we have to register the object not the type but this works
recursive_serde_register(np.core._ufunc_config._unspecified())

Expand Down

0 comments on commit 4b8cc4c

Please sign in to comment.