Skip to content

Commit

Permalink
feat: to/from PyTorch Tensor (#3259)
Browse files Browse the repository at this point in the history
* add new to_torch function

* add new from_torch function

* add changes suggested by Jim

* style: pre-commit fixes

* fix style

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
maxymnaumchyk and pre-commit-ci[bot] authored Oct 3, 2024
1 parent cfe58f3 commit ee5865a
Show file tree
Hide file tree
Showing 4 changed files with 213 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/awkward/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from awkward.operations.ak_from_raggedtensor import *
from awkward.operations.ak_from_rdataframe import *
from awkward.operations.ak_from_regular import *
from awkward.operations.ak_from_torch import *
from awkward.operations.ak_full_like import *
from awkward.operations.ak_imag import *
from awkward.operations.ak_is_categorical import *
Expand Down Expand Up @@ -102,6 +103,7 @@
from awkward.operations.ak_to_raggedtensor import *
from awkward.operations.ak_to_rdataframe import *
from awkward.operations.ak_to_regular import *
from awkward.operations.ak_to_torch import *
from awkward.operations.ak_transform import *
from awkward.operations.ak_type import *
from awkward.operations.ak_unflatten import *
Expand Down
65 changes: 65 additions & 0 deletions src/awkward/operations/ak_from_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import awkward as ak
from awkward._dispatch import high_level_function

__all__ = ("from_torch",)


@high_level_function()
def from_torch(array):
"""
Args:
array: (PyTorch Tensor):
Tensor to convert into an Awkward Array.
Converts a PyTorch Tensor into an Awkward Array.
If `array` contains any other data types the function raises an error.
"""

# Dispatch
yield (array,)

# Implementation
return _impl(array)


def _impl(array):
try:
import torch
except ImportError as err:
raise ImportError(
"""to use ak.from_torch, you must install 'torch' package with:
pip install torch
or
conda install pytorch"""
) from err

# check if array is a Tensor
if not isinstance(array, torch.Tensor):
raise TypeError("""only PyTorch Tensor can be converted to Awkward Array""")

# keep the resulting array on the same device as input tensor
device = "cuda" if array.is_cuda else "cpu"

# convert tensors to cupy if they are on cuda
if device == "cuda":
from awkward._nplikes.cupy import Cupy

cp = Cupy.instance()

# zero-copy data exchange through DLPack
cp_array = cp.from_dlpack(array)
ak_array = ak.from_cupy(cp_array)

else:
np_array = array.numpy()
ak_array = ak.from_numpy(np_array)

return ak_array
74 changes: 74 additions & 0 deletions src/awkward/operations/ak_to_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import awkward as ak
from awkward._dispatch import high_level_function
from awkward._nplikes.numpy_like import NumpyMetadata

__all__ = ("to_torch",)

np = NumpyMetadata.instance()


@high_level_function()
def to_torch(array):
"""
Args:
array: Array-like data. May be a high level #ak.Array,
or low-level #ak.contents.ListOffsetArray, #ak.contents.ListArray,
#ak.contents.RegularArray, #ak.contents.NumpyArray
Converts `array` (only ListOffsetArray, ListArray, RegularArray and NumpyArray data types supported)
into a PyTorch Tensor, if possible.
If `array` contains any other data types (RecordArray for example) the function raises a TypeError.
"""

# Dispatch
yield (array,)

# Implementation
return _impl(array)


def _impl(array):
try:
import torch
except ImportError as err:
raise ImportError(
"""to use ak.to_torch, you must install 'torch' package with:
pip install torch
or
conda install pytorch"""
) from err

# useful function that handles all possible input arrays
array = ak.to_layout(array, allow_record=False)

# get the device array is on
device = ak.backend(array)

if device not in ["cuda", "cpu"]:
raise ValueError("Only 'cpu' and 'cuda' backend conversions are allowed")

# convert to numpy or cupy if `array` on gpu
try:
backend_array = array.to_backend_array(allow_missing=False)
except ValueError as err:
raise TypeError(
"Only arrays containing equal-length lists of numbers can be converted into a PyTorch Tensor"
) from err

# check if cupy or numpy
if isinstance(backend_array, np.ndarray):
# convert numpy to a torch tensor
tensor = torch.from_numpy(backend_array)
else:
# cupy -> torch tensor
tensor = torch.utils.dlpack.from_dlpack(backend_array.toDlpack())

return tensor
72 changes: 72 additions & 0 deletions tests/test_3259_to_torch_from_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import numpy as np
import pytest

import awkward as ak

to_torch = ak.operations.to_torch
from_torch = ak.operations.from_torch

torch = pytest.importorskip("torch")

a = np.arange(2 * 2 * 2, dtype=np.float64).reshape(2, 2, 2)
b = np.arange(2 * 2 * 2).reshape(2, 2, 2)

array = np.arange(2 * 3 * 5).reshape(2, 3, 5)
content2 = ak.contents.NumpyArray(array.reshape(-1))
inneroffsets = ak.index.Index64(np.array([0, 5, 10, 15, 20, 25, 30]))
outeroffsets = ak.index.Index64(np.array([0, 3, 6]))


def test_to_torch():
# a basic test for a 4 dimensional array
array1 = ak.Array([a, b])
i = 0
for sub_array in [
[[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]],
[[[0.0, 1.0], [2.0, 3.0]], [[4.0, 5.0], [6.0, 7.0]]],
]:
assert to_torch(array1)[i].tolist() == sub_array
i += 1

# test that the data types are remaining the same (float64 in this case)
assert array1.layout.to_backend_array().dtype.name in str(to_torch(array1).dtype)

# try a listoffset array inside a listoffset array
array2 = ak.contents.ListOffsetArray(
outeroffsets, ak.contents.ListOffsetArray(inneroffsets, content2)
)
assert to_torch(array2)[0].tolist() == [
[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
]
assert to_torch(array2)[1].tolist() == [
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24],
[25, 26, 27, 28, 29],
]

# try just a python list
array3 = [3, 1, 4, 1, 9, 2, 6]
assert to_torch(array3).tolist() == [3, 1, 4, 1, 9, 2, 6]


array1 = torch.tensor([[1.0, -1.0], [1.0, -1.0]], dtype=torch.float32)
array2 = torch.tensor(np.array([[1, 2, 3], [4, 5, 6]]))


def test_from_torch():
# Awkward.to_list() == Tensor.tolist()
assert from_torch(array1).to_list() == array1.tolist()

assert from_torch(array2).to_list() == array2.tolist()

# test that the data types are remaining the same (int64 in this case)
assert from_torch(array1).layout.dtype.name in str(array1.dtype)

# test that the data types are remaining the same (float32 in this case)
assert from_torch(array2).layout.dtype.name in str(array2.dtype)

0 comments on commit ee5865a

Please sign in to comment.