From 5fc21895076599853dd6da6deafb8d645c333105 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 15 Sep 2021 12:32:04 +0800 Subject: [PATCH] Create a ragged tensor from a regular tensor. --- .github/workflows/build-doc.yml | 1 + k2/python/csrc/torch/torch_util.cu | 4 +- k2/python/csrc/torch/torch_util.h | 8 +- k2/python/csrc/torch/v2/any.cu | 29 ++-- k2/python/csrc/torch/v2/doc/any.h | 206 +++++++++++++++++++++++--- k2/python/csrc/torch/v2/ragged_any.cu | 44 ++++++ k2/python/csrc/torch/v2/ragged_any.h | 12 ++ 7 files changed, 269 insertions(+), 35 deletions(-) diff --git a/.github/workflows/build-doc.yml b/.github/workflows/build-doc.yml index 5f61d0683..0a8d65607 100644 --- a/.github/workflows/build-doc.yml +++ b/.github/workflows/build-doc.yml @@ -22,6 +22,7 @@ on: push: branches: - master + - doc env: # debug is faster in terms of compilation time diff --git a/k2/python/csrc/torch/torch_util.cu b/k2/python/csrc/torch/torch_util.cu index afe88cb79..1860180be 100644 --- a/k2/python/csrc/torch/torch_util.cu +++ b/k2/python/csrc/torch/torch_util.cu @@ -105,7 +105,7 @@ torch::Tensor ToTorch(Array1 &array) { } template <> -Array1 FromTorch(torch::Tensor &tensor) { +Array1 FromTorch(torch::Tensor tensor) { K2_CHECK_EQ(tensor.dim(), 2) << "Expected dim: 2. Given: " << tensor.dim(); K2_CHECK_EQ(tensor.scalar_type(), ToScalarType::value) << "Expected scalar type: " << ToScalarType::value @@ -124,7 +124,7 @@ Array1 FromTorch(torch::Tensor &tensor) { return ans; } -Tensor FromTorch(torch::Tensor &tensor, TensorTag) { +Tensor FromTorch(torch::Tensor tensor, TensorTag) { Dtype dtype = ScalarTypeToDtype(tensor.scalar_type()); torch::IntArrayRef sizes = tensor.sizes(); torch::IntArrayRef strides = tensor.strides(); diff --git a/k2/python/csrc/torch/torch_util.h b/k2/python/csrc/torch/torch_util.h index 110e46a9a..7808bdc85 100644 --- a/k2/python/csrc/torch/torch_util.h +++ b/k2/python/csrc/torch/torch_util.h @@ -113,7 +113,7 @@ torch::Tensor ToTorch(Array1 &array) { input tensor. */ template -Array1 FromTorch(torch::Tensor &tensor) { +Array1 FromTorch(torch::Tensor tensor) { K2_CHECK_EQ(tensor.dim(), 1) << "Expected dim: 1. Given: " << tensor.dim(); K2_CHECK_EQ(tensor.scalar_type(), ToScalarType::value) << "Expected scalar type: " << ToScalarType::value @@ -158,12 +158,12 @@ torch::Tensor ToTorch(Array1 &array); the input tensor. */ template <> -Array1 FromTorch(torch::Tensor &tensor); +Array1 FromTorch(torch::Tensor tensor); struct Array2Tag {}; template -Array2 FromTorch(torch::Tensor &tensor, Array2Tag) { +Array2 FromTorch(torch::Tensor tensor, Array2Tag) { K2_CHECK_EQ(tensor.dim(), 2) << "Expected dim: 2. Given: " << tensor.dim(); K2_CHECK_EQ(tensor.scalar_type(), ToScalarType::value) << "Expected scalar type: " << ToScalarType::value @@ -199,7 +199,7 @@ torch::Tensor ToTorch(Array2 &array) { struct TensorTag {}; -Tensor FromTorch(torch::Tensor &tensor, TensorTag); +Tensor FromTorch(torch::Tensor tensor, TensorTag); torch::Tensor ToTorch(Tensor &tensor); /* Transfer an object to a specific device. diff --git a/k2/python/csrc/torch/v2/any.cu b/k2/python/csrc/torch/v2/any.cu index 472ea7cea..5e56b6c08 100644 --- a/k2/python/csrc/torch/v2/any.cu +++ b/k2/python/csrc/torch/v2/any.cu @@ -54,10 +54,12 @@ void PybindRaggedAny(py::module &m) { }), py::arg("s"), py::arg("dtype") = py::none(), kRaggedAnyInitStrDoc); - // TODO(fangjun): add documentation for it any.def(py::init(), py::arg("shape"), py::arg("value"), kRaggedInitFromShapeAndTensorDoc); + any.def(py::init(), py::arg("tensor"), + kRaggedAnyInitTensorDoc); + any.def( "__str__", [](const RaggedAny &self) -> std::string { return self.ToString(); }, @@ -78,8 +80,7 @@ void PybindRaggedAny(py::module &m) { K2_CHECK_EQ(self.any.NumAxes(), 2); Array1 row_split = self.any.RowSplits(1).To(GetCpuContext()); const int32_t *row_split_data = row_split.Data(); - int32_t begin = row_split_data[i], - end = row_split_data[i + 1]; + int32_t begin = row_split_data[i], end = row_split_data[i + 1]; Dtype t = self.any.GetDtype(); FOR_REAL_AND_INT32_TYPES(t, T, { Array1 array = @@ -100,19 +101,18 @@ void PybindRaggedAny(py::module &m) { if (!slice.compute(self.any.Dim0(), &start, &stop, &step, &slicelength)) throw py::error_already_set(); int32_t istart = static_cast(start); - int32_t istop = static_cast(stop); - int32_t istep = static_cast(step); - K2_CHECK_EQ(istep, 1) << "Only support slicing with step 1, given : " - << istep; + int32_t istop = static_cast(stop); + int32_t istep = static_cast(step); + K2_CHECK_EQ(istep, 1) + << "Only support slicing with step 1, given : " << istep; return self.Arange(/*axis*/ 0, istart, istop); - }, py::arg("key"), kRaggedAnyGetItemSliceDoc); + }, + py::arg("key"), kRaggedAnyGetItemSliceDoc); any.def("index", - static_cast( - &RaggedAny::Index), - py::arg("indexes"), - kRaggedAnyRaggedIndexDoc); + static_cast(&RaggedAny::Index), + py::arg("indexes"), kRaggedAnyRaggedIndexDoc); any.def("index", static_cast> ( @@ -423,6 +423,11 @@ void PybindRaggedAny(py::module &m) { return RaggedAny(s, dtype); }, py::arg("s"), py::arg("dtype") = py::none(), kCreateRaggedTensorStrDoc); + + m.def( + "create_ragged_tensor", + [](torch::Tensor tensor) -> RaggedAny { return RaggedAny(tensor); }, + py::arg("tensor"), kCreateRaggedTensorTensorDoc); } } // namespace k2 diff --git a/k2/python/csrc/torch/v2/doc/any.h b/k2/python/csrc/torch/v2/doc/any.h index b69ca3093..69e4a0378 100644 --- a/k2/python/csrc/torch/v2/doc/any.h +++ b/k2/python/csrc/torch/v2/doc/any.h @@ -120,6 +120,95 @@ torch.float32 Returns: Return a ragged tensor. )doc"; + +static constexpr const char *kCreateRaggedTensorTensorDoc = R"doc( +Create a ragged tensor from a torch tensor. + +Note: + It turns a regular tensor into a ragged tensor. + +Caution: + The input tensor has to have more than 1 dimension. + That is ``tensor.ndim > 1``. + + Also, if the input tensor is contiguous, ``self`` + will share the underlying memory with it. Otherwise, + memory of the input tensor is copied to create ``self``. + + Supported dtypes of the input tensor are: ``torch.int32``, + ``torch.float32``, and ``torch.float64``. + +**Example 1**: + + >>> import torch + >>> import k2.ragged as k2r + >>> a = torch.arange(6, dtype=torch.int32).reshape(2, 3) + >>> b = k2r.create_ragged_tensor(a) + >>> a + tensor([[0, 1, 2], + [3, 4, 5]], dtype=torch.int32) + >>> b + [ [ 0 1 2 ] [ 3 4 5 ] ] + >>> b.dtype + torch.int32 + >>> a.is_contiguous() + True + >>> a[0, 0] = 10 + >>> b + [ [ 10 1 2 ] [ 3 4 5 ] ] + >>> b.values[1] = -2 + >>> a + tensor([[10, -2, 2], + [ 3, 4, 5]], dtype=torch.int32) + +**Example 2**: + + >>> import k2.ragged as k2r + >>> a = torch.arange(24, dtype=torch.int32).reshape(2, 12)[:, ::4] + >>> a + tensor([[ 0, 4, 8], + [12, 16, 20]], dtype=torch.int32) + >>> a.is_contiguous() + False + >>> b = k2r.create_ragged_tensor(a) + >>> b + [ [ 0 4 8 ] [ 12 16 20 ] ] + >>> b.dtype + torch.int32 + >>> a[0, 0] = 10 + >>> b + [ [ 0 4 8 ] [ 12 16 20 ] ] + >>> a + tensor([[10, 4, 8], + [12, 16, 20]], dtype=torch.int32) + >>> b + [ [ 0 -2 8 ] [ 12 16 20 ] ] + +**Example 3**: + + >>> import torch + >>> import k2.ragged as k2r + >>> a = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4) + >>> a + tensor([[[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]], + [[12., 13., 14., 15.], + [16., 17., 18., 19.], + [20., 21., 22., 23.]]]) + >>> b = k2r.create_ragged_tensor(a) + >>> b + [ [ [ 0 1 2 3 ] [ 4 5 6 7 ] [ 8 9 10 11 ] ] [ [ 12 13 14 15 ] [ 16 17 18 19 ] [ 20 21 22 23 ] ] ] + >>> b.dtype + torch.float32 + +Args: + tensor: + An N-D (N > 1) tensor. +Returns: + Return a ragged tensor. +)doc"; + static constexpr const char *kRaggedInitFromShapeAndTensorDoc = R"doc( Create a ragged tensor from a shape and a value. @@ -245,6 +334,89 @@ torch.float32 ``torch.int32``, ``torch.float32``, and ``torch.float64``. )doc"; +static constexpr const char *kRaggedAnyInitTensorDoc = R"doc( +Create a ragged tensor from a torch tensor. + +Note: + It turns a regular tensor into a ragged tensor. + +Caution: + The input tensor has to have more than 1 dimension. + That is ``tensor.ndim > 1``. + + Also, if the input tensor is contiguous, ``self`` + will share the underlying memory with it. Otherwise, + memory of the input tensor is copied to create ``self``. + + Supported dtypes of the input tensor are: ``torch.int32``, + ``torch.float32``, and ``torch.float64``. + +**Example 1**: + + >>> import torch + >>> import k2.ragged as k2r + >>> a = torch.arange(6, dtype=torch.int32).reshape(2, 3) + >>> b = k2r.RaggedTensor(a) + >>> a + tensor([[0, 1, 2], + [3, 4, 5]], dtype=torch.int32) + >>> b + [ [ 0 1 2 ] [ 3 4 5 ] ] + >>> a.is_contiguous() + True + >>> a[0, 0] = 10 + >>> b + [ [ 10 1 2 ] [ 3 4 5 ] ] + >>> b.values[1] = -2 + >>> a + tensor([[10, -2, 2], + [ 3, 4, 5]], dtype=torch.int32) + +**Example 2**: + + >>> import k2.ragged as k2r + >>> a = torch.arange(24, dtype=torch.int32).reshape(2, 12)[:, ::4] + >>> a + tensor([[ 0, 4, 8], + [12, 16, 20]], dtype=torch.int32) + >>> a.is_contiguous() + False + >>> b = k2r.RaggedTensor(a) + >>> b + [ [ 0 4 8 ] [ 12 16 20 ] ] + >>> a[0, 0] = 10 + >>> b + [ [ 0 4 8 ] [ 12 16 20 ] ] + >>> a + tensor([[10, 4, 8], + [12, 16, 20]], dtype=torch.int32) + >>> b + [ [ 0 -2 8 ] [ 12 16 20 ] ] + +**Example 3**: + + >>> import torch + >>> import k2.ragged as k2r + >>> a = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4) + >>> a + tensor([[[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.]], + [[12., 13., 14., 15.], + [16., 17., 18., 19.], + [20., 21., 22., 23.]]]) + >>> b = k2r.RaggedTensor(a) + >>> b + [ [ [ 0 1 2 3 ] [ 4 5 6 7 ] [ 8 9 10 11 ] ] [ [ 12 13 14 15 ] [ 16 17 18 19 ] [ 20 21 22 23 ] ] ] + >>> b.dtype + torch.float32 + + +Args: + tensor: + An N-D (N > 1) tensor. +)doc"; + static constexpr const char *kRaggedAnyToDeviceDoc = R"doc( Transfer this tensor to a given device. @@ -411,12 +583,12 @@ Return a copy of this tensor. >>> c = a.clone() >>> a [ [ 1 2 ] [ 3 ] ] ->>> b.data[0] = 10 +>>> b.values[0] = 10 >>> a [ [ 10 2 ] [ 3 ] ] >>> c [ [ 1 2 ] [ 3 ] ] ->>> c.data[0] = -1 +>>> c.values[0] = -1 >>> c [ [ -1 2 ] [ 3 ] ] >>> a @@ -577,7 +749,7 @@ tensor(40., grad_fn=) static constexpr const char *kRaggedAnyNumelDoc = R"doc( Returns: Return number of elements in this tensor. It equals to - ``self.data.numel()``. + ``self.values.numel()``. >>> import torch >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[1], [], [3, 4, 5, 6]]) @@ -622,10 +794,10 @@ You are not expected to call it by yourself. Returns: If this tensor has 2 axes, return a tuple containing - (self.row_splits(1), "row_ids1", self.data). + (self.row_splits(1), "row_ids1", self.values). If this tensor has 3 axes, return a tuple containing (self.row_splits(1), "row_ids1", self.row_splits(1), - "row_ids2", self.data) + "row_ids2", self.values) Note: "row_ids1" and "row_ids2" in the returned value is for @@ -876,7 +1048,7 @@ The ``axis`` argument may be confusing; its behavior is equivalent to: >>> b = a.arange(axis=0, begin=1, end=4) >>> b [ [ 1 ] [ 2 ] [ ] ] - >>> b.data[0] = -1 + >>> b.values[0] = -1 >>> a [ [ 0 ] [ -1 ] [ 2 ] [ ] [ 3 ] ] @@ -953,7 +1125,7 @@ tensor([ 3, -1, 7], dtype=torch.int32) >>> d = c.argmax(initial_value=0) >>> d tensor([ 3, -1, 7], dtype=torch.int32) ->>> c.data[3], c.data[7] +>>> c.values[3], c.values[7] (tensor(5, dtype=torch.int32), tensor(8, dtype=torch.int32)) >>> c.argmax(initial_value=6) tensor([-1, -1, 7], dtype=torch.int32) @@ -1300,7 +1472,7 @@ Sort a ragged tensor over the last axis **in-place**. tensor([1, 0, 2, 4, 5, 3, 7, 6, 8], dtype=torch.int32) >>> a [ [ 3 1 0 ] [ 5 3 2 ] [ ] [ 3 1 0 ] ] ->>> a_clone.data[b.long()] +>>> a_clone.values[b.long()] tensor([3., 1., 0., 5., 3., 2., 3., 1., 0.]) >>> a_clone = a.clone() >>> c = a.sort_(descending=False, need_new2old_indexes=True) @@ -1308,7 +1480,7 @@ tensor([3., 1., 0., 5., 3., 2., 3., 1., 0.]) tensor([2, 1, 0, 5, 4, 3, 8, 7, 6], dtype=torch.int32) >>> a [ [ 0 1 3 ] [ 2 3 5 ] [ ] [ 0 1 3 ] ] ->>> a_clone.data[c.long()] +>>> a_clone.values[c.long()] tensor([0., 1., 3., 2., 3., 5., 0., 1., 3.]) Args: @@ -1318,7 +1490,7 @@ tensor([0., 1., 3., 2., 3., 5., 0., 1., 3.]) need_new2old_indexes: If ``True``, also returns a 1-D tensor, containing the indexes mapping from the sorted elements to the unsorted elements. We can use - ``self.clone().data[returned_tensor]`` to get a sorted tensor. + ``self.clone().values[returned_tensor]`` to get a sorted tensor. Returns: If ``need_new2old_indexes`` is False, returns None. Otherwise, returns a 1-D tensor of dtype ``torch.int32``. @@ -1348,7 +1520,7 @@ Index a ragged tensor with a ragged tensor. Args: indexes: - Its values must satisfy ``0 <= data[i] < self.dim0``. + Its values must satisfy ``0 <= values[i] < self.dim0``. Caution: Its dtype has to be ``torch.int32``. @@ -1376,7 +1548,7 @@ the elements of ``indexes`` are interpreted as indexes into axis ``axis`` of [ [ 0 1 2 ] [ 0 2 3 ] [ ] [ 3 -1.25 ] ] >>> value_indexes tensor([3, 4, 5, 0, 1, 2, 6, 7], dtype=torch.int32) - >>> a.data[value_indexes.long()] + >>> a.values[value_indexes.long()] tensor([ 0.0000, 1.0000, 2.0000, 0.0000, 2.0000, 3.0000, 3.0000, -1.2500]) >>> k = torch.tensor([2, -1, 0], dtype=torch.int32) >>> a.index(k, axis=0, need_value_indexes=True) @@ -1394,7 +1566,7 @@ the elements of ``indexes`` are interpreted as indexes into axis ``axis`` of [ [ [ 1 3 ] [ 2 ] [ ] ] [ [ 2 ] [ 5 8 ] [ -1 ] [ ] ] ] >>> value_indexes tensor([0, 1, 2, 6, 3, 4, 5], dtype=torch.int32) - >>> a.data[value_indexes.long()] + >>> a.values[value_indexes.long()] tensor([ 1, 3, 2, 2, 5, 8, -1], dtype=torch.int32) @@ -1414,15 +1586,15 @@ the elements of ``indexes`` are interpreted as indexes into axis ``axis`` of The axis to be indexed. Must satisfy ``0 <= axis < self.num_axes``. need_value_indexes: If ``True``, it will return a torch.Tensor containing the indexes into - ``self.data`` that ``ans.data`` has, as in - ``ans.data = self.data[value_indexes]``. + ``self.values`` that ``ans.values`` has, as in + ``ans.values = self.values[value_indexes]``. Returns: Return a tuple containing: - A ragged tensor, sharing the same dtype and device with ``self`` - ``None`` if ``need_value_indexes`` is False; a 1-D torch.tensor of - dtype ``torch.int32`` containing the indexes into ``self.data`` that - ``ans.data`` has. + dtype ``torch.int32`` containing the indexes into ``self.values`` that + ``ans.values`` has. )doc"; static constexpr const char *kRaggedAnyIndexTensorWithRaggedDoc = R"doc( diff --git a/k2/python/csrc/torch/v2/ragged_any.cu b/k2/python/csrc/torch/v2/ragged_any.cu index 8366a5f45..8a03b369b 100644 --- a/k2/python/csrc/torch/v2/ragged_any.cu +++ b/k2/python/csrc/torch/v2/ragged_any.cu @@ -220,6 +220,50 @@ RaggedAny::RaggedAny(py::list data, py::object dtype /*= py::none()*/) { << "and torch.float64"; } +RaggedAny::RaggedAny(torch::Tensor tensor) { + int32_t ndim = tensor.dim(); + K2_CHECK_GE(ndim, 2) << "Expect a tensor with more than 1-D"; + ContextPtr context = GetContext(tensor); + std::vector shapes; + shapes.reserve(ndim - 1); + int32_t dim0 = tensor.size(0); + for (int32_t i = 1; i != ndim; ++i) { + int32_t dim1 = tensor.size(i); + shapes.push_back(RegularRaggedShape(context, dim0, dim1)); + dim0 *= dim1; + } + while (shapes.size() > 2u) { + RaggedShape c = std::move(shapes.back()); + shapes.pop_back(); + + RaggedShape b = std::move(shapes.back()); + shapes.pop_back(); + + RaggedShape a = std::move(shapes.back()); + shapes.pop_back(); + + RaggedShape abc = ComposeRaggedShapes3(a, b, c); + shapes.push_back(std::move(abc)); + } + + if (shapes.size() > 1u) { + RaggedShape b = std::move(shapes.back()); + shapes.pop_back(); + + RaggedShape a = std::move(shapes.back()); + shapes.pop_back(); + + RaggedShape ab = ComposeRaggedShapes(a, b); + shapes.push_back(std::move(ab)); + } + + Dtype t = ScalarTypeToDtype(tensor.scalar_type()); + FOR_REAL_AND_INT32_TYPES(t, T, { + Array1 values = FromTorch(tensor.contiguous().view({-1})); + any = Ragged(shapes[0], values).Generic(); + }); +} + const torch::Tensor &RaggedAny::Data() const { DeviceGuard guard(any.Context()); if (!data.defined()) { diff --git a/k2/python/csrc/torch/v2/ragged_any.h b/k2/python/csrc/torch/v2/ragged_any.h index c3e171f59..5616b68c3 100644 --- a/k2/python/csrc/torch/v2/ragged_any.h +++ b/k2/python/csrc/torch/v2/ragged_any.h @@ -52,6 +52,18 @@ struct RaggedAny { */ RaggedAny(const RaggedShape &shape, torch::Tensor value); + /* Create a ragged tensor from a torch tensor. + + @note The resulting ragged tensor has a regular structure. + + @params tensor An N-D PyTorch tensor, where N > 1. Supported dtypes are + torch.int32, torch.float32, torch.float64. + + @caution If the input tensor is contiguous, the ragged tensor shares the + underlying memory with the input tensor. Otherwise, memory is copied. + */ + explicit RaggedAny(torch::Tensor tensor); + RaggedAny(const RaggedAny &) = default; RaggedAny &operator=(const RaggedAny &) = default; RaggedAny(RaggedAny &&) = default;