Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: to/from PyTorch JaggedTensor #3246

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements-test-ml.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
fbgemm-gpu-cpu >= 0.8.0
fsspec>=2022.11.0;sys_platform != "win32"
pytest>=6
pytest-cov
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from awkward.operations.ak_from_dlpack import *
from awkward.operations.ak_from_feather import *
from awkward.operations.ak_from_iter import *
from awkward.operations.ak_from_jaggedtensor import *
from awkward.operations.ak_from_jax import *
from awkward.operations.ak_from_json import *
from awkward.operations.ak_from_numpy import *
Expand Down Expand Up @@ -90,6 +91,7 @@
from awkward.operations.ak_to_cupy import *
from awkward.operations.ak_to_dataframe import *
from awkward.operations.ak_to_feather import *
from awkward.operations.ak_to_jaggedtensor import *
from awkward.operations.ak_to_jax import *
from awkward.operations.ak_to_json import *
from awkward.operations.ak_to_layout import *
Expand Down
76 changes: 76 additions & 0 deletions src/awkward/operations/ak_from_jaggedtensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import awkward as ak
from awkward._dispatch import high_level_function

# Public names exported by a star-import of this module.
__all__ = ("from_jaggedtensor",)


@high_level_function()
def from_jaggedtensor(array):
    """
    Args:
        array (PyTorch JaggedTensor): JaggedTensor to convert into an
            Awkward Array. The data type of a PyTorch JaggedTensor is a
            2-tuple of a `torch.Tensor` (used as the flat contents) and a
            list of `torch.Tensor`s (used as the offsets, one per nested
            list level, outermost first).

    Converts a PyTorch JaggedTensor into an Awkward Array.

    The result is an #ak.Array built from nested #ak.contents.ListOffsetArray
    nodes; when the values tensor lives on a CUDA device, the buffers are
    wrapped as CuPy arrays instead of being used directly.

    If `array` contains any other data types the function raises an error.
    """

    # Dispatch
    # (the yielded tuple is presumably consumed by `high_level_function` to
    # select an implementation; see awkward._dispatch to confirm)
    yield (array,)

    # Implementation
    return _impl(array)


def _impl(array):
    """Build an `ak.Array` from a `(values, [offsets, ...])` jagged tensor.

    `array[0]` becomes the flat `NumpyArray` content and every tensor in
    `array[1]` becomes an `Index64` of offsets. When the values tensor is on
    a CUDA device, all tensors are first converted with CuPy's `asarray`.
    """
    values = array[0]

    # keep the resulting array on the same device as the input tensor
    if values.is_cuda:
        device = "cuda"

        # an ImportError from a missing CuPy simply propagates to the caller
        from awkward._nplikes.cupy import Cupy

        cupy_like = Cupy.instance()
        content = ak.contents.NumpyArray(cupy_like.asarray(values))
        offsets_arr = [
            ak.index.Index64(cupy_like.asarray(offset)) for offset in array[1]
        ]
    else:
        device = "cpu"
        content = ak.contents.NumpyArray(values)
        offsets_arr = [ak.index.Index64(offset) for offset in array[1]]

    # a tensor with exactly one *ragged dimension* needs a single wrapper
    if len(offsets_arr) == 1:
        return ak.Array(ak.contents.ListOffsetArray(offsets_arr[0], content))

    # otherwise nest one ListOffsetArray per *ragged dimension*
    return ak.Array(_recursive_call(content, offsets_arr, 0, device))


def _recursive_call(content, offsets_arr, count, device):
    """Wrap `content` in nested ListOffsetArrays, outermost offsets first.

    Args:
        content: innermost layout holding the flat values.
        offsets_arr: list of `ak.index.Index64` offsets; `offsets_arr[0]`
            describes the outermost list level.
        count: index of the offsets level to apply at this depth (callers
            start at 0).
        device: "cpu" or "cuda"; not used by this function itself, only
            threaded through the recursion to keep the signature stable.

    Returns:
        The layout obtained by applying `offsets_arr[count:]` around
        `content`. If there are no offsets left, `content` is returned as is.
    """
    # Base case: every offsets level has been consumed.
    if count >= len(offsets_arr):
        return content

    # The previous implementation recursed with the same `count` and without
    # `device`, which raised a TypeError for three or more offset levels.
    inner = _recursive_call(content, offsets_arr, count + 1, device)
    return ak.contents.ListOffsetArray(offsets_arr[count], inner)
124 changes: 124 additions & 0 deletions src/awkward/operations/ak_to_jaggedtensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import awkward as ak
from awkward._dispatch import high_level_function
from awkward._nplikes.numpy_like import NumpyMetadata

# Public names exported by a star-import of this module.
__all__ = ("to_jaggedtensor",)

# NOTE: `np` is Awkward's NumpyMetadata instance, not the `numpy` module.
np = NumpyMetadata.instance()


@high_level_function()
def to_jaggedtensor(array):
    """
    Args:
        array: Array-like data. May be a high level #ak.Array,
            or low-level #ak.contents.ListOffsetArray, #ak.contents.ListArray,
            #ak.contents.RegularArray, #ak.contents.NumpyArray

    Converts `array` (only ListOffsetArray, ListArray, RegularArray and NumpyArray data types supported)
    into a PyTorch "jagged tensor", if possible. The data type of a PyTorch "jagged tensor" is a 2-tuple of a `torch.Tensor` and a list of `torch.Tensor`s.
    The first `torch.Tensor` is the numerical contents of the array and the list of integer-valued `torch.Tensor`s are offsets indicating where variable-length lists start and end.

    Special cases: if the outermost layout is a NumpyArray, a plain `torch.Tensor` is returned instead of a 2-tuple;
    if the outermost layout is a RegularArray, or a RegularArray contains variable-length lists, a TypeError is raised
    (use #ak.from_regular first). A RegularArray of numbers nested inside variable-length lists is kept as a
    multidimensional contents tensor.

    If `array` contains any other data types (RecordArray for example) the function raises a TypeError.
    """

    # Dispatch
    # (the yielded tuple is presumably consumed by `high_level_function` to
    # select an implementation; see awkward._dispatch to confirm)
    yield (array,)

    # Implementation
    return _impl(array)


def _impl(array):
    """Convert `array` into a `(values, [offsets, ...])` pair of torch tensors.

    Returns a bare `torch.Tensor` when the layout is a NumpyArray, raises
    TypeError when the outermost layout is a RegularArray, and otherwise
    returns a 2-tuple of the flat values tensor and a list of offsets
    tensors, placed on the device reported by `ak.backend`.

    Raises:
        ImportError: if `torch` is not installed.
        TypeError: for a top-level RegularArray (and for unsupported layouts
            found while walking the array).
    """
    try:
        import torch
    except ImportError as err:
        raise ImportError(
            """to use ak.to_jaggedtensor, you must install 'torch' package with:

pip install torch or conda install pytorch"""
        ) from err

    # unwrap a high-level ak.Array (and promote plain Python lists) to a layout
    layout = ak.to_layout(array, allow_record=False)

    # keep the resulting tensor on the same device as the input
    device = ak.backend(layout)

    # purely numeric input: nothing jagged to describe
    if isinstance(layout, ak.contents.numpyarray.NumpyArray):
        return torch.tensor(layout.data)

    # a jagged tensor can't function with an empty offsets list
    if isinstance(layout, ak.contents.regulararray.RegularArray):
        raise TypeError(
            "RegularArrays cannot be converted into a PyTorch JaggedTensor. Try using ak.from_regular() if you still want to use this function."
        )

    flat_values, nested_row_splits = _recursive_call(layout, [])

    # "jagged_to_padded_dense" is not implemented for 64-bit floats, so for
    # float64 values run a tiny fbgemm call up front; any RuntimeError it
    # raises propagates to the caller.
    # NOTE(review): presumably this is meant to fail early when fbgemm lacks
    # float64 support - confirm the intent of the float64 offsets tensor.
    if isinstance(flat_values.dtype, type(np.dtype(np.float64))):
        dense_test = torch.tensor(
            [[[1, 1], [0, 0]], [[2, 2], [3, 3]]], dtype=torch.float64
        )
        offsets_test = torch.tensor([0, 1, 3], dtype=torch.float64)
        torch.ops.fbgemm.dense_to_jagged(dense_test, [offsets_test])

    # pick the buffer -> tensor conversion once: NumPy buffers go through
    # `from_numpy`, anything else (CuPy) through `as_tensor`
    if isinstance(flat_values, np.ndarray):

        def as_tensor(buffer):
            return torch.from_numpy(buffer).to(device)

    else:

        def as_tensor(buffer):
            return torch.as_tensor(buffer, device=device)

    dense = as_tensor(flat_values)
    offsets = [as_tensor(item) for item in nested_row_splits]
    return (dense, offsets)


def _recursive_call(layout, offsets_arr):
    """Walk `layout` from the outside in, collecting one offsets buffer per list level.

    Args:
        layout: the low-level layout to descend into.
        offsets_arr: list that accumulates the offsets buffers, outermost
            first. It is appended to in place and also returned.

    Returns:
        A 2-tuple `(flat_values, offsets_arr)` where `flat_values` is the
        buffer of the innermost numeric content.

    Raises:
        TypeError: if a RegularArray contains variable-length lists, or if a
            layout other than ListArray, ListOffsetArray, RegularArray or
            NumpyArray is encountered.
    """
    # Leaf: the flat values. This is tested explicitly rather than by
    # catching AttributeError on `.offsets`, so unrelated attribute errors
    # are no longer swallowed.
    if isinstance(layout, ak.contents.numpyarray.NumpyArray):
        return layout.data, offsets_arr

    if isinstance(layout, ak.contents.regulararray.RegularArray):
        # a RegularArray of plain numbers collapses into a multidimensional
        # values buffer; anything jagged underneath it cannot be represented
        numpy_arr = layout.maybe_to_NumpyArray()
        if numpy_arr is None:
            raise TypeError(
                "RegularArrays containing ListArray or ListOffsetArray cannot be converted"
                " into a PyTorch JaggedTensor. Try using ak.from_regular() if you still want to use this function."
            )
        return ak.to_numpy(numpy_arr), offsets_arr

    # normalize every list-like layout to ListOffsetArray
    if isinstance(layout, ak.contents.listarray.ListArray):
        layout = layout.to_ListOffsetArray64()
    elif not isinstance(layout, ak.contents.listoffsetarray.ListOffsetArray):
        raise TypeError(
            "Only arrays containing variable-length lists (var *) or"
            " regular-length lists (# *) of numbers can be converted into a PyTorch JaggedTensor"
        )

    # NOTE(review): offsets are recorded as-is; this assumes they start at 0
    # and span the whole content - confirm behavior for sliced arrays.
    offsets_arr.append(layout.offsets.data)
    return _recursive_call(layout.content, offsets_arr)
169 changes: 169 additions & 0 deletions tests/test_3246_to_jaggedtensor_from_jaggedtensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import numpy as np
import pytest

import awkward as ak

# Short aliases for the two functions under test.
to_jaggedtensor = ak.operations.to_jaggedtensor
from_jaggedtensor = ak.operations.from_jaggedtensor

# Skip the whole module when the optional ML dependencies are not installed.
torch = pytest.importorskip("torch")
fbgemm_gpu = pytest.importorskip("fbgemm_gpu")

# Shared fixture: nine float64 values plus starts/stops for two list levels
# (starts1/stops1 split the values into 5 lists, starts2/stops2 group those
# lists into 2 outer lists).
content = ak.contents.NumpyArray(
    np.array([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9])
)
starts1 = ak.index.Index64(np.array([0, 3, 3, 5, 6]))
stops1 = ak.index.Index64(np.array([3, 3, 5, 6, 9]))
starts2 = ak.index.Index64(np.array([0, 3]))
stops2 = ak.index.Index64(np.array([3, 5]))

# Shared fixture: the integers 0..29 with offsets describing a 2 x 3 x 5
# nesting (six inner lists of 5, grouped into two outer lists of 3).
array = np.arange(2 * 3 * 5).reshape(2, 3, 5)
content2 = ak.contents.NumpyArray(array.reshape(-1))
inneroffsets = ak.index.Index64(np.array([0, 5, 10, 15, 20, 25, 30]))
outeroffsets = ak.index.Index64(np.array([0, 3, 6]))


def to_float32(array, highlevel=False):
    """Return `array` with its values cast to float32.

    `highlevel` is forwarded to `ak.values_astype`, so by default the result
    is a low-level layout rather than an `ak.Array`.
    """
    converted = ak.values_astype(array, np.float32, highlevel=highlevel)
    return converted


def test_convert_to_jaggedtensor():
    """Check `to_jaggedtensor` on list, numeric and multiply-nested inputs."""

    def check_jagged(jagged, expected_values, *expected_offsets):
        # values tensor first, then each listed offsets level by position
        assert torch.equal(jagged[0], torch.tensor(expected_values))
        for level, expected in enumerate(expected_offsets):
            assert torch.equal(jagged[1][level], torch.tensor(expected))

    nine_values = [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9]
    digits = [3, 1, 4, 1, 9, 2, 6]

    # ListArray -> JaggedTensor
    single_list = to_float32(ak.contents.ListArray(starts1, stops1, content))
    check_jagged(to_jaggedtensor(single_list), nine_values, [0, 3, 3, 5, 6, 9])

    # NumpyArray -> plain tensor (float64 is preserved)
    assert torch.equal(
        to_jaggedtensor(content),
        torch.tensor(nine_values, dtype=torch.float64),
    )

    # one-dimensional high-level array
    assert torch.equal(to_jaggedtensor(ak.Array(digits)), torch.tensor(digits))

    # array with two ragged dimensions
    doubly_ragged = to_float32(
        ak.Array([[[1.1, 2.2], [3.3]], [], [[4.4, 5.5]]]), highlevel=True
    )
    check_jagged(
        to_jaggedtensor(doubly_ragged),
        [1.1000, 2.2000, 3.3000, 4.4000, 5.5000],
        [0, 2, 2, 3],
        [0, 2, 3, 5],
    )

    # ListOffsetArray inside a ListOffsetArray
    nested_offsets = ak.contents.ListOffsetArray(
        outeroffsets, ak.contents.ListOffsetArray(inneroffsets, content2)
    )
    check_jagged(
        to_jaggedtensor(nested_offsets),
        list(range(30)),
        [0, 3, 6],
        [0, 5, 10, 15, 20, 25, 30],
    )

    # ListArray inside a ListArray
    nested_lists = to_float32(
        ak.contents.ListArray(
            starts2, stops2, ak.contents.ListArray(starts1, stops1, content)
        )
    )
    check_jagged(
        to_jaggedtensor(nested_lists),
        nine_values,
        [0, 3, 5],
        [0, 3, 3, 5, 6, 9],
    )

    # plain Python list
    assert torch.equal(to_jaggedtensor(list(digits)), torch.tensor(digits))

    # array with three ragged dimensions; the whole offsets list looks like
    # [tensor([0, 2, 2, 3]), tensor([0, 2, 3, 5]), tensor([0, 2, 3, 4, 6, 7])]
    triply_ragged = to_float32(
        ak.Array([[[[1.1, 2.2], [3.3]], [[4.4]]], [], [[[5.5, 6.6], [7.7]]]]),
        highlevel=True,
    )
    check_jagged(
        to_jaggedtensor(triply_ragged),
        [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7],
        [0, 2, 2, 3],
        [0, 2, 3, 5],
        [0, 2, 3, 4, 6, 7],
    )


def test_regular_array():
    """A RegularArray of numbers under a ragged level stays a 2-D values tensor."""
    # regular dimensions are kept where possible
    ragged = to_float32(
        ak.Array([[[1.1, 2.2], [3.3, 4.4], [5.5, 6.6]], [[7.7, 8.8], [9.9, 10]]]),
        highlevel=True,
    )
    jagged = to_jaggedtensor(ak.to_regular(ragged, axis=2))

    expected_values = torch.tensor(
        [[1.1, 2.2], [3.3, 4.4], [5.5, 6.6], [7.7, 8.8], [9.9, 10.0]]
    )
    expected_offsets = torch.tensor([0, 3, 5])

    assert torch.equal(jagged[0], expected_values)
    assert torch.equal(jagged[1][0], expected_offsets)

    # otherwise (if RegularArray contains ListArray or ListOffsetArray) a
    # TypeError is raised; that path is not exercised by this test


def test_convert_from_jaggedtensor():
    """Check `from_jaggedtensor` on an fbgemm-built and a hand-built jagged tensor."""
    # a simple jagged tensor created with "dense_to_jagged"
    padded = torch.tensor([[[1, 1], [0, 0], [0, 0]], [[2, 2], [3, 3], [0, 0]]])
    row_offsets = torch.tensor([0, 1, 3])
    from_fbgemm = torch.ops.fbgemm.dense_to_jagged(padded, [row_offsets])

    converted = from_jaggedtensor(from_fbgemm)
    assert converted.to_list() == [[[1, 1]], [[2, 2], [3, 3]]]

    # a manually assembled (values, offsets) tuple
    flat_values = torch.tensor([1.1, 2.2, 3.3, 4.4, 5.5])
    offset_levels = [torch.tensor([0, 2, 2, 3]), torch.tensor([0, 2, 3, 5])]
    handmade = (flat_values, offset_levels)

    expected = ak.values_astype(
        ak.Array([[[1.1, 2.2], [3.3]], [], [[4.4, 5.5]]]), np.float32
    )

    assert ak.all(from_jaggedtensor(handmade) == expected)
Loading