diff --git a/k2/csrc/ragged_ops_inl.h b/k2/csrc/ragged_ops_inl.h index f2c4b7377..4d44761cd 100644 --- a/k2/csrc/ragged_ops_inl.h +++ b/k2/csrc/ragged_ops_inl.h @@ -412,6 +412,8 @@ std::istream &operator>>(std::istream &is, Ragged &r) { : (row_splits[cur_level + 1].size() - 1)); is.get(); // consume character 'c' if (cur_level == 0) break; + } else if (c == ',') { + is.get(); // consume character 'c' } else { InputFixer t; is >> t; diff --git a/k2/python/csrc/torch/v2/any.cu b/k2/python/csrc/torch/v2/any.cu index 5e56b6c08..8a7a11194 100644 --- a/k2/python/csrc/torch/v2/any.cu +++ b/k2/python/csrc/torch/v2/any.cu @@ -40,19 +40,24 @@ void PybindRaggedAny(py::module &m) { // k2.ragged.Tensor methods //-------------------------------------------------- - any.def( - py::init([](py::list data, - py::object dtype = py::none()) -> std::unique_ptr { - return std::make_unique(data, dtype); - }), - py::arg("data"), py::arg("dtype") = py::none(), kRaggedAnyInitDataDoc); + any.def(py::init(), py::arg("data"), + py::arg("dtype") = py::none(), + py::arg("device") = torch::Device(torch::kCPU), + kRaggedAnyInitDataDeviceDoc); - any.def( - py::init([](const std::string &s, - py::object dtype = py::none()) -> std::unique_ptr { - return std::make_unique(s, dtype); - }), - py::arg("s"), py::arg("dtype") = py::none(), kRaggedAnyInitStrDoc); + any.def(py::init(), + py::arg("data"), py::arg("dtype") = py::none(), + py::arg("device") = "cpu", kRaggedAnyInitDataDeviceDoc); + + any.def(py::init(), + py::arg("s"), py::arg("dtype") = py::none(), + py::arg("device") = torch::Device(torch::kCPU), + kRaggedAnyInitStrDeviceDoc); + + any.def(py::init(), + py::arg("s"), py::arg("dtype") = py::none(), + py::arg("device") = torch::Device(torch::kCPU), + kRaggedAnyInitStrDeviceDoc); any.def(py::init(), py::arg("shape"), py::arg("value"), kRaggedInitFromShapeAndTensorDoc); @@ -408,21 +413,43 @@ void PybindRaggedAny(py::module &m) { // _k2.ragged.functions //-------------------------------------------------- - // TODO: change the function name from "create_tensor" to "tensor" m.def( "create_ragged_tensor", - [](py::list data, py::object dtype = py::none()) -> RaggedAny { - return RaggedAny(data, dtype); + [](py::list data, py::object dtype = py::none(), + torch::Device device = torch::kCPU) -> RaggedAny { + return RaggedAny(data, dtype, device); }, py::arg("data"), py::arg("dtype") = py::none(), + py::arg("device") = torch::Device(torch::kCPU), kCreateRaggedTensorDataDoc); m.def( "create_ragged_tensor", - [](const std::string &s, py::object dtype = py::none()) -> RaggedAny { - return RaggedAny(s, dtype); + [](py::list data, py::object dtype = py::none(), + const std::string &device = "cpu") -> RaggedAny { + return RaggedAny(data, dtype, device); + }, + py::arg("data"), py::arg("dtype") = py::none(), py::arg("device") = "cpu", + kCreateRaggedTensorDataDoc); + + m.def( + "create_ragged_tensor", + [](const std::string &s, py::object dtype = py::none(), + torch::Device device = torch::kCPU) -> RaggedAny { + return RaggedAny(s, dtype, device); + }, + py::arg("s"), py::arg("dtype") = py::none(), + py::arg("device") = torch::Device(torch::kCPU), + kCreateRaggedTensorStrDoc); + + m.def( + "create_ragged_tensor", + [](const std::string &s, py::object dtype = py::none(), + const std::string &device = "cpu") -> RaggedAny { + return RaggedAny(s, dtype, device); }, - py::arg("s"), py::arg("dtype") = py::none(), kCreateRaggedTensorStrDoc); + py::arg("s"), py::arg("dtype") = py::none(), py::arg("device") = "cpu", + kCreateRaggedTensorStrDoc); m.def( "create_ragged_tensor", diff --git a/k2/python/csrc/torch/v2/doc/any.h b/k2/python/csrc/torch/v2/doc/any.h index 69e4a0378..5c995fc9d 100644 --- a/k2/python/csrc/torch/v2/doc/any.h +++ b/k2/python/csrc/torch/v2/doc/any.h @@ -32,15 +32,20 @@ Create a ragged tensor with arbitrary number of axes. Hint: The returned tensor is on CPU. +>>> import torch >>> import k2.ragged as k2r >>> a = k2r.create_ragged_tensor([ [1, 2], [5], [], [9] ]) >>> a -[ [ 1 2 ] [ 5 ] [ ] [ 9 ] ] +RaggedTensor([[1, 2], + [5], + [], + [9]], dtype=torch.int32) >>> a.dtype torch.int32 >>> b = k2r.create_ragged_tensor([ [1, 3.0], [] ]) >>> b -[ [ 1 3 ] [ ] ] +RaggedTensor([[1, 3], + []], dtype=torch.float32) >>> b.dtype torch.float32 >>> c = k2r.create_ragged_tensor([ [1] ], dtype=torch.float64) @@ -48,18 +53,30 @@ torch.float32 torch.float64 >>> d = k2r.create_ragged_tensor([ [[1], [2, 3]], [[4], []] ]) >>> d -[ [ [ 1 ] [ 2 3 ] ] [ [ 4 ] [ ] ] ] +RaggedTensor([[[1], + [2, 3]], + [[4], + []]], dtype=torch.int32) >>> d.num_axes 3 >>> e = k2r.create_ragged_tensor([]) >>> e -[ ] +RaggedTensor([], dtype=torch.int32) >>> e.num_axes 2 >>> e.shape.row_splits(1) tensor([0], dtype=torch.int32) >>> e.shape.row_ids(1) tensor([], dtype=torch.int32) +>>> f = k2r.create_ragged_tensor([ [1, 2], [], [3] ], device=torch.device('cuda', 0)) +>>> f +RaggedTensor([[1, 2], + [], + [3]], device='cuda:0', dtype=torch.int32) +>>> e = k2r.create_ragged_tensor([[1], []], device='cuda:1') +>>> e +RaggedTensor([[1], + []], device='cuda:1', dtype=torch.int32) Args: data: @@ -70,6 +87,12 @@ tensor([], dtype=torch.int32) automatically, which is either ``torch.int32`` or ``torch.float32``. Supported dtypes are: ``torch.int32``, ``torch.float32``, and ``torch.float64``. + device: + It can be either an instance of ``torch.device`` or + a string representing a torch device. Example + values are: ``"cpu"``, ``"cuda:0"``, ``torch.device("cpu")``, + ``torch.device("cuda", 0)``. + Returns: Return a ragged tensor. )doc"; @@ -77,28 +100,28 @@ tensor([], dtype=torch.int32) static constexpr const char *kCreateRaggedTensorStrDoc = R"doc( Create a ragged tensor from its string representation. +Fields are separated by space(s) **or** comma(s). + An example string for a 2-axis ragged tensor is given below:: - [ [1] [2] ] + [ [1] [2] [3, 4], [5 6 7, 8] ] An example string for a 3-axis ragged tensor is given below:: [ [[1]] [[]] ] -Hint: - The returned tensor is on CPU. - >>> import torch >>> import k2.ragged as k2r >>> a = k2r.create_ragged_tensor('[ [1] [] [3 4] ]') >>> a -[ [ 1 ] [ ] [ 3 4 ] ] +RaggedTensor([[1], + [], + [3, 4]], dtype=torch.int32) >>> a.num_axes 2 >>> a.dtype torch.int32 >>> b = k2r.create_ragged_tensor('[ [[] [3]] [[10]] ]', dtype=torch.float32) ->>> b = k2r.create_ragged_tensor('[ [[] [3]] [[10]] ]', dtype=torch.float32) >>> b [ [ [ ] [ 3 ] ] [ [ 10 ] ] ] >>> b.dtype @@ -109,6 +132,10 @@ torch.float32 >>> c.dtype torch.float32 +Note: + Number of spaces or commas in ``s`` does not affect the result. + Of course, numbers have to be separated by at least one space or comma. + Args: s: A string representation of a ragged tensor. @@ -117,6 +144,11 @@ torch.float32 to infer the correct dtype from ``s``, which is assumed to be either ``torch.int32`` or ``torch.float32``. Supported dtypes are: ``torch.int32``, ``torch.float32``, and ``torch.float64``. + device: + It can be either an instance of ``torch.device`` or + a string representing a torch device. Example + values are: ``"cpu"``, ``"cuda:0"``, ``torch.device("cpu")``, + ``torch.device("cuda", 0)``. Returns: Return a ragged tensor. )doc"; @@ -148,14 +180,16 @@ Create a ragged tensor from a torch tensor. tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.int32) >>> b - [ [ 0 1 2 ] [ 3 4 5 ] ] + RaggedTensor([[0, 1, 2], + [3, 4, 5]], dtype=torch.int32) >>> b.dtype torch.int32 >>> a.is_contiguous() True >>> a[0, 0] = 10 >>> b - [ [ 10 1 2 ] [ 3 4 5 ] ] + RaggedTensor([[10, 1, 2], + [3, 4, 5]], dtype=torch.int32) >>> b.values[1] = -2 >>> a tensor([[10, -2, 2], @@ -172,17 +206,17 @@ Create a ragged tensor from a torch tensor. False >>> b = k2r.create_ragged_tensor(a) >>> b - [ [ 0 4 8 ] [ 12 16 20 ] ] + RaggedTensor([[0, 4, 8], + [12, 16, 20]], dtype=torch.int32) >>> b.dtype torch.int32 >>> a[0, 0] = 10 >>> b - [ [ 0 4 8 ] [ 12 16 20 ] ] + RaggedTensor([[0, 4, 8], + [12, 16, 20]], dtype=torch.int32) >>> a tensor([[10, 4, 8], [12, 16, 20]], dtype=torch.int32) - >>> b - [ [ 0 -2 8 ] [ 12 16 20 ] ] **Example 3**: @@ -198,9 +232,12 @@ Create a ragged tensor from a torch tensor. [20., 21., 22., 23.]]]) >>> b = k2r.create_ragged_tensor(a) >>> b - [ [ [ 0 1 2 3 ] [ 4 5 6 7 ] [ 8 9 10 11 ] ] [ [ 12 13 14 15 ] [ 16 17 18 19 ] [ 20 21 22 23 ] ] ] - >>> b.dtype - torch.float32 + RaggedTensor([[[0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11]], + [[12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23]]], dtype=torch.float32) Args: tensor: @@ -218,7 +255,9 @@ Create a ragged tensor from a shape and a value. >>> value = torch.tensor([10, 0, 20, 30, 40], dtype=torch.float32) >>> ragged = k2r.RaggedTensor(shape, value) >>> ragged -[ [ 10 0 ] [ ] [ 20 30 40 ] ] +RaggedTensor([[10, 0], + [], + [20, 30, 40]], dtype=torch.float32) Args: shape: @@ -227,42 +266,46 @@ Create a ragged tensor from a shape and a value. The value of the tensor. )doc"; -static constexpr const char *kRaggedAnyInitDataDoc = R"doc( +static constexpr const char *kRaggedAnyInitDataDeviceDoc = R"doc( Create a ragged tensor with arbitrary number of axes. Note: A ragged tensor has at least two axes. -Hint: - The returned tensor is on CPU. - **Example 1**: >>> import torch >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([ [1, 2], [5], [], [9] ]) >>> a - [ [ 1 2 ] [ 5 ] [ ] [ 9 ] ] + RaggedTensor([[1, 2], + [5], + [], + [9]], dtype=torch.int32) >>> a.dtype torch.int32 >>> b = k2r.RaggedTensor([ [1, 3.0], [] ]) >>> b - [ [ 1 3 ] [ ] ] + RaggedTensor([[1, 3], + []], dtype=torch.float32) >>> b.dtype torch.float32 >>> c = k2r.RaggedTensor([ [1] ], dtype=torch.float64) >>> c - [ [ 1 ] ] + RaggedTensor([[1]], dtype=torch.float64) >>> c.dtype torch.float64 >>> d = k2r.RaggedTensor([ [[1], [2, 3]], [[4], []] ]) >>> d - [ [ [ 1 ] [ 2 3 ] ] [ [ 4 ] [ ] ] ] + RaggedTensor([[[1], + [2, 3]], + [[4], + []]], dtype=torch.int32) >>> d.num_axes 3 >>> e = k2r.RaggedTensor([]) >>> e - [ ] + RaggedTensor([], dtype=torch.int32) >>> e.num_axes 2 >>> e.shape.row_splits(1) @@ -273,7 +316,13 @@ Create a ragged tensor with arbitrary number of axes. **Example 2**: >>> k2r.RaggedTensor([ [[1, 2]], [], [[]] ]) - [ [ [ 1 2 ] ] [ ] [ [ ] ] ] + RaggedTensor([[[1, 2]], + [], + [[]]], dtype=torch.int32) + >>> k2r.RaggedTensor([ [[1, 2]], [], [[]] ], device='cuda:0') + RaggedTensor([[[1, 2]], + [], + [[]]], device='cuda:0', dtype=torch.int32) Args: data: @@ -284,34 +333,42 @@ Create a ragged tensor with arbitrary number of axes. automatically, which is either ``torch.int32`` or ``torch.float32``. Supported dtypes are: ``torch.int32``, ``torch.float32``, and ``torch.float64``. + device: + It can be either an instance of ``torch.device`` or + a string representing a torch device. Example + values are: ``"cpu"``, ``"cuda:0"``, ``torch.device("cpu")``, + ``torch.device("cuda", 0)``. )doc"; -static constexpr const char *kRaggedAnyInitStrDoc = R"doc( +static constexpr const char *kRaggedAnyInitStrDeviceDoc = R"doc( Create a ragged tensor from its string representation. +Fields are separated by space(s) **or** comma(s). + An example string for a 2-axis ragged tensor is given below:: - [ [1] [2] ] + [ [1] [2] [3, 4], [5 6 7, 8] ] An example string for a 3-axis ragged tensor is given below:: [ [[1]] [[]] ] -Hint: - The returned tensor is on CPU. - >>> import torch >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor('[ [1] [] [3 4] ]') >>> a -[ [ 1 ] [ ] [ 3 4 ] ] +RaggedTensor([[1], + [], + [3, 4]], dtype=torch.int32) >>> a.num_axes 2 >>> a.dtype torch.int32 >>> b = k2r.RaggedTensor('[ [[] [3]] [[10]] ]', dtype=torch.float32) >>> b -[ [ [ ] [ 3 ] ] [ [ 10 ] ] ] +RaggedTensor([[[], + [3]], + [[10]]], dtype=torch.float32) >>> b.dtype torch.float32 >>> b.num_axes @@ -319,10 +376,13 @@ torch.float32 >>> c = k2r.RaggedTensor('[[1.]]') >>> c.dtype torch.float32 +>>> d = k2r.RaggedTensor('[[1.]]', device='cuda:0') +>>> d +RaggedTensor([[1]], device='cuda:0', dtype=torch.float32) Note: - Number of spaces in ``s`` does not affect the result. - Of course, numbers have to be separated by at least one space. + Number of spaces or commas in ``s`` does not affect the result. + Of course, numbers have to be separated by at least one space or comma. Args: s: @@ -332,6 +392,11 @@ torch.float32 to infer the correct dtype from ``s``, which is assumed to be either ``torch.int32`` or ``torch.float32``. Supported dtypes are: ``torch.int32``, ``torch.float32``, and ``torch.float64``. + device: + It can be either an instance of ``torch.device`` or + a string representing a torch device. Example + values are: ``"cpu"``, ``"cuda:0"``, ``torch.device("cpu")``, + ``torch.device("cuda", 0)``. )doc"; static constexpr const char *kRaggedAnyInitTensorDoc = R"doc( @@ -361,12 +426,14 @@ Create a ragged tensor from a torch tensor. tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.int32) >>> b - [ [ 0 1 2 ] [ 3 4 5 ] ] + RaggedTensor([[0, 1, 2], + [3, 4, 5]], dtype=torch.int32) >>> a.is_contiguous() True >>> a[0, 0] = 10 >>> b - [ [ 10 1 2 ] [ 3 4 5 ] ] + RaggedTensor([[10, 1, 2], + [3, 4, 5]], dtype=torch.int32) >>> b.values[1] = -2 >>> a tensor([[10, -2, 2], @@ -383,15 +450,15 @@ Create a ragged tensor from a torch tensor. False >>> b = k2r.RaggedTensor(a) >>> b - [ [ 0 4 8 ] [ 12 16 20 ] ] + RaggedTensor([[0, 4, 8], + [12, 16, 20]], dtype=torch.int32) >>> a[0, 0] = 10 >>> b - [ [ 0 4 8 ] [ 12 16 20 ] ] + RaggedTensor([[0, 4, 8], + [12, 16, 20]], dtype=torch.int32) >>> a tensor([[10, 4, 8], [12, 16, 20]], dtype=torch.int32) - >>> b - [ [ 0 -2 8 ] [ 12 16 20 ] ] **Example 3**: @@ -407,10 +474,17 @@ Create a ragged tensor from a torch tensor. [20., 21., 22., 23.]]]) >>> b = k2r.RaggedTensor(a) >>> b - [ [ [ 0 1 2 3 ] [ 4 5 6 7 ] [ 8 9 10 11 ] ] [ [ 12 13 14 15 ] [ 16 17 18 19 ] [ 20 21 22 23 ] ] ] + RaggedTensor([[[0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11]], + [[12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23]]], dtype=torch.float32) >>> b.dtype torch.float32 - + >>> c = torch.tensor([[1, 2]], device='cuda:0', dtype=torch.float32) + >>> k2r.RaggedTensor(c) + RaggedTensor([[1, 2]], device='cuda:0', dtype=torch.float32) Args: tensor: @@ -511,9 +585,14 @@ Return a string representation of this tensor. >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[1], [2, 3], []]) >>> a -[ [ 1 ] [ 2 3 ] [ ] ] +RaggedTensor([[1], + [2, 3], + []], dtype=torch.int32) >>> str(a) -'[ [ 1 ] [ 2 3 ] [ ] ]' +'RaggedTensor([[1],\n [2, 3],\n []], dtype=torch.int32)' +>>> b = k2r.RaggedTensor([[1, 2]], device='cuda:0') +>>> b +RaggedTensor([[1, 2]], device='cuda:0', dtype=torch.int32) )doc"; static constexpr const char *kRaggedAnyGetItemDoc = R"doc( @@ -522,23 +601,34 @@ Select the i-th sublist along axis 0. Caution: Support for autograd is to be implemented. ->>> import torch ->>> import k2.ragged as k2r ->>> a = k2r.RaggedTensor('[ [[1 3] [] [9]] [[8]] ]') ->>> a -[ [ [ 1 3 ] [ ] [ 9 ] ] [ [ 8 ] ] ] ->>> a[0] -[ [ 1 3 ] [ ] [ 9 ] ] ->>> a[1] -[ [ 8 ] ] +**Example 1**: ->>> a = k2r.RaggedTensor('[ [1 3] [9] [8] ]') ->>> a -[ [ 1 3 ] [ 9 ] [ 8 ] ] ->>> a[0] -tensor([1, 3], dtype=torch.int32) ->>> a[1] -tensor([9], dtype=torch.int32) + >>> import torch + >>> import k2.ragged as k2r + >>> a = k2r.RaggedTensor('[ [[1 3] [] [9]] [[8]] ]') + >>> a + RaggedTensor([[[1, 3], + [], + [9]], + [[8]]], dtype=torch.int32) + >>> a[0] + RaggedTensor([[1, 3], + [], + [9]], dtype=torch.int32) + >>> a[1] + RaggedTensor([[8]], dtype=torch.int32) + +**Example 2**: + + >>> a = k2r.RaggedTensor('[ [1 3] [9] [8] ]') + >>> a + RaggedTensor([[1, 3], + [9], + [8]], dtype=torch.int32) + >>> a[0] + tensor([1, 3], dtype=torch.int32) + >>> a[1] + tensor([9], dtype=torch.int32) Args: i: @@ -559,11 +649,18 @@ equals to 1. >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor('[ [[1 3] [] [9]] [[8]] [[10 11]] ]') >>> a -[ [ [ 1 3 ] [ ] [ 9 ] ] [ [ 8 ] ] [ [ 10 11 ] ] ] +RaggedTensor([[[1, 3], + [], + [9]], + [[8]], + [[10, 11]]], dtype=torch.int32) >>> a[0:2] -[ [ [ 1 3 ] [ ] [ 9 ] [ [ 8 ] ] ] ] +RaggedTensor([[[1, 3], + [], + [9]], + [[8]]], dtype=torch.int32) >>> a[1:2] -[ [ [ 8 ] ] [ [ 10 11 ] ] ] +RaggedTensor([[[8]]], dtype=torch.int32) Args: key: @@ -582,19 +679,25 @@ Return a copy of this tensor. >>> b = a >>> c = a.clone() >>> a -[ [ 1 2 ] [ 3 ] ] +RaggedTensor([[1, 2], + [3]], dtype=torch.int32) >>> b.values[0] = 10 >>> a -[ [ 10 2 ] [ 3 ] ] +RaggedTensor([[10, 2], + [3]], dtype=torch.int32) >>> c -[ [ 1 2 ] [ 3 ] ] +RaggedTensor([[1, 2], + [3]], dtype=torch.int32) >>> c.values[0] = -1 >>> c -[ [ -1 2 ] [ 3 ] ] +RaggedTensor([[-1, 2], + [3]], dtype=torch.int32) >>> a -[ [ 10 2 ] [ 3 ] ] +RaggedTensor([[10, 2], + [3]], dtype=torch.int32) >>> b -[ [ 10 2 ] [ 3 ] ] +RaggedTensor([[10, 2], + [3]], dtype=torch.int32) )doc"; static constexpr const char *kRaggedAnyEqDoc = R"doc( @@ -673,7 +776,10 @@ calls to ``backward()`` will accumulate (add) gradients into it. >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[1, 2], [3], [5, 6], []], dtype=torch.float32) >>> a.requires_grad_(True) -[ [ 1 2 ] [ 3 ] [ 5 6 ] [ ] ] +RaggedTensor([[1, 2], + [3], + [5, 6], + []], dtype=torch.float32) >>> b = a.sum() >>> b tensor([ 3., 3., 11., 0.], grad_fn=>) @@ -701,7 +807,7 @@ this tensor's :attr:`requires_grad` attribute **in-place**. >>> a.requires_grad False >>> a.requires_grad_(True) -[ [ 1 ] ] +RaggedTensor([[1]], dtype=torch.float64) >>> a.requires_grad True @@ -726,7 +832,10 @@ Compute the sum of sublists over the last axis of this tensor. >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor('[ [[1 2] [] [5]] [[10]] ]', dtype=torch.float32) >>> a.requires_grad_(True) -[ [ [ 1 2 ] [ ] [ 5 ] ] [ [ 10 ] ] ] +RaggedTensor([[[1, 2], + [], + [5]], + [[10]]], dtype=torch.float32) >>> b = a.sum() >>> c = (b * torch.arange(4)).sum() >>> c.backward() @@ -762,6 +871,7 @@ static constexpr const char *kRaggedAnyNumelDoc = R"doc( >>> c.numel() 5 )doc"; + static constexpr const char *kRaggedAnyTotSizeDoc = R"doc( Return the number of elements of an given axis. If axis is 0, it's equivalent to the property ``dim0``. @@ -859,13 +969,25 @@ tensor([ 1, 2, 5, 8, 9, 10], dtype=torch.int32) True >>> a.values[-2] = -1 >>> a -[ [ -1 2 ] [ ] [ 5 ] [ ] [ 8 9 10 ] ] +RaggedTensor([[1, 2], + [], + [5], + [], + [8, -1, 10]], dtype=torch.int32) >>> a.values[3] = -3 >>> a -[ [ -1 2 ] [ ] [ 5 ] [ ] [ -3 9 10 ] ] +RaggedTensor([[1, 2], + [], + [5], + [], + [-3, -1, 10]], dtype=torch.int32) >>> a.values[2] = -2 >>> a -[ [ -1 2 ] [ ] [ -2 ] [ ] [ -3 9 10 ] ] +RaggedTensor([[1, 2], + [], + [-2], + [], + [-3, -1, 10]], dtype=torch.int32) )doc"; static constexpr const char *kRaggedAnyShapeDoc = R"doc( @@ -938,15 +1060,32 @@ last axis it is just removed and the number of elements may be changed. >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([ [[1], [], [0, -1]], [[], [2, 3], []], [[0]], [[]] ]) >>> a - [ [ [ 1 ] [ ] [ 0 -1 ] ] [ [ ] [ 2 3 ] [ ] ] [ [ 0 ] ] [ [ ] ] ] + RaggedTensor([[[1], + [], + [0, -1]], + [[], + [2, 3], + []], + [[0]], + [[]]], dtype=torch.int32) >>> a.num_axes 3 >>> b = a.remove_axis(0) >>> b - [ [ 1 ] [ ] [ 0 -1 ] [ ] [ 2 3 ] [ ] [ 0 ] [ ] ] + RaggedTensor([[1], + [], + [0, -1], + [], + [2, 3], + [], + [0], + []], dtype=torch.int32) >>> c = a.remove_axis(1) >>> c - [ [ 1 0 -1 ] [ 2 3 ] [ 0 ] [ ] ] + RaggedTensor([[1, 0, -1], + [2, 3], + [0], + []], dtype=torch.int32) **Example 2**: @@ -954,16 +1093,42 @@ last axis it is just removed and the number of elements may be changed. >>> a.num_axes 4 >>> a - [ [ [ [ 1 ] [ ] [ 2 ] ] ] [ [ [ 3 4 ] [ ] [ 5 6 ] [ ] ] ] [ [ [ ] [ 0 ] ] ] ] + RaggedTensor([[[[1], + [], + [2]]], + [[[3, 4], + [], + [5, 6], + []]], + [[[], + [0]]]], dtype=torch.int32) >>> b = a.remove_axis(0) >>> b - [ [ [ 1 ] [ ] [ 2 ] ] [ [ 3 4 ] [ ] [ 5 6 ] [ ] ] [ [ ] [ 0 ] ] ] + RaggedTensor([[[1], + [], + [2]], + [[3, 4], + [], + [5, 6], + []], + [[], + [0]]], dtype=torch.int32) >>> c = a.remove_axis(1) >>> c - [ [ [ 1 ] [ ] [ 2 ] ] [ [ 3 4 ] [ ] [ 5 6 ] [ ] ] [ [ ] [ 0 ] ] ] + RaggedTensor([[[1], + [], + [2]], + [[3, 4], + [], + [5, 6], + []], + [[], + [0]]], dtype=torch.int32) >>> d = a.remove_axis(2) >>> d - [ [ [ 1 2 ] ] [ [ 3 4 5 6 ] ] [ [ 0 ] ] ] + RaggedTensor([[[1, 2]], + [[3, 4, 5, 6]], + [[0]]], dtype=torch.int32) Args: axis: @@ -993,27 +1158,46 @@ The ``axis`` argument may be confusing; its behavior is equivalent to: >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([ [[1], [], [2]], [[], [4, 5], []], [[], [1]], [[]] ]) >>> a - [ [ [ 1 ] [ ] [ 2 ] ] [ [ ] [ 4 5 ] [ ] ] [ [ ] [ 1 ] ] [ [ ] ] ] + RaggedTensor([[[1], + [], + [2]], + [[], + [4, 5], + []], + [[], + [1]], + [[]]], dtype=torch.int32) >>> a.num_axes 3 >>> b = a.arange(axis=0, begin=1, end=3) >>> b - [ [ [ ] [ 4 5 ] [ ] ] [ [ ] [ 1 ] ] ] + RaggedTensor([[[], + [4, 5], + []], + [[], + [1]]], dtype=torch.int32) >>> b.num_axes 3 >>> c = a.arange(axis=0, begin=1, end=2) >>> c - [ [ [ ] [ 4 5 ] [ ] ] ] + RaggedTensor([[[], + [4, 5], + []]], dtype=torch.int32) >>> c.num_axes 3 >>> d = a.arange(axis=1, begin=0, end=4) >>> d - [ [ 1 ] [ ] [ 2 ] [ ] ] + RaggedTensor([[1], + [], + [2], + []], dtype=torch.int32) >>> d.num_axes 2 >>> e = a.arange(axis=1, begin=2, end=5) >>> e - [ [ 2 ] [ ] [ 4 5 ] ] + RaggedTensor([[2], + [], + [4, 5]], dtype=torch.int32) >>> e.num_axes 2 @@ -1024,17 +1208,34 @@ The ``axis`` argument may be confusing; its behavior is equivalent to: 4 >>> b = a.arange(axis=0, begin=0, end=2) >>> b - [ [ [ [ ] [ 1 ] [ 2 3 ] ] [ [ 5 8 ] [ ] [ 9 ] ] ] [ [ [ 10 ] [ 0 ] [ ] ] ] ] + RaggedTensor([[[[], + [1], + [2, 3]], + [[5, 8], + [], + [9]]], + [[[10], + [0], + []]]], dtype=torch.int32) >>> b.num_axes 4 >>> c = a.arange(axis=1, begin=1, end=3) >>> c - [ [ [ 5 8 ] [ ] [ 9 ] ] [ [ 10 ] [ 0 ] [ ] ] ] + RaggedTensor([[[5, 8], + [], + [9]], + [[10], + [0], + []]], dtype=torch.int32) >>> c.num_axes 3 >>> d = a.arange(axis=2, begin=0, end=5) >>> d - [ [ ] [ 1 ] [ 2 3 ] [ 5 8 ] [ ] ] + RaggedTensor([[], + [1], + [2, 3], + [5, 8], + []], dtype=torch.int32) >>> d.num_axes 2 @@ -1042,15 +1243,25 @@ The ``axis`` argument may be confusing; its behavior is equivalent to: >>> a = k2r.RaggedTensor([[0], [1], [2], [], [3]]) >>> a - [ [ 0 ] [ 1 ] [ 2 ] [ ] [ 3 ] ] + RaggedTensor([[0], + [1], + [2], + [], + [3]], dtype=torch.int32) >>> a.num_axes 2 >>> b = a.arange(axis=0, begin=1, end=4) >>> b - [ [ 1 ] [ 2 ] [ ] ] + RaggedTensor([[1], + [2], + []], dtype=torch.int32) >>> b.values[0] = -1 >>> a - [ [ 0 ] [ -1 ] [ 2 ] [ ] [ 3 ] ] + RaggedTensor([[0], + [-1], + [2], + [], + [3]], dtype=torch.int32) Args: axis: @@ -1068,12 +1279,23 @@ target. Leaves all layers of the shape except for the last one unaffected. >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[1, 2, 3, 0, 3, 2], [], [3, 2, 3], [3]]) +>>> a +RaggedTensor([[1, 2, 3, 0, 3, 2], + [], + [3, 2, 3], + [3]], dtype=torch.int32) >>> b = a.remove_values_eq(3) >>> b -[ [ 1 2 0 2 ] [ ] [ 2 ] [ ] ] +RaggedTensor([[1, 2, 0, 2], + [], + [2], + []], dtype=torch.int32) >>> c = a.remove_values_eq(2) >>> c -[ [ 1 3 0 3 ] [ ] [ 3 3 ] [ 3 ] ] +RaggedTensor([[1, 3, 0, 3], + [], + [3, 3], + [3]], dtype=torch.int32) Args: target: @@ -1089,15 +1311,29 @@ Leaves all layers of the shape except for the last one unaffected. >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[1, 2, 3, 0, 3, 2], [], [3, 2, 3], [3]]) +>>> a +RaggedTensor([[1, 2, 3, 0, 3, 2], + [], + [3, 2, 3], + [3]], dtype=torch.int32) >>> b = a.remove_values_leq(3) >>> b -[ [ ] [ ] [ ] [ ] ] +RaggedTensor([[], + [], + [], + []], dtype=torch.int32) >>> c = a.remove_values_leq(2) >>> c -[ [ 3 3 ] [ ] [ 3 3 ] [ 3 ] ] +RaggedTensor([[3, 3], + [], + [3, 3], + [3]], dtype=torch.int32) >>> d = a.remove_values_leq(1) >>> d -[ [ 2 3 3 2 ] [ ] [ 3 2 3 ] [ 3 ] ] +RaggedTensor([[2, 3, 3, 2], + [], + [3, 2, 3], + [3]], dtype=torch.int32) Args: cutoff: @@ -1217,9 +1453,16 @@ Concatenate a list of ragged tensor over a specified axis. >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[1], [], [2, 3]]) >>> k2r.cat([a, a], axis=0) - [ [ 1 ] [ ] [ 2 3 ] [ 1 ] [ ] [ 2 3 ] ] + RaggedTensor([[1], + [], + [2, 3], + [1], + [], + [2, 3]], dtype=torch.int32) >>> k2r.cat((a, a), axis=1) - [ [ 1 1 ] [ ] [ 2 3 2 3 ] ] + RaggedTensor([[1, 1], + [], + [2, 3, 2, 3]], dtype=torch.int32) **Example 2** @@ -1228,18 +1471,44 @@ Concatenate a list of ragged tensor over a specified axis. >>> b = k2r.RaggedTensor([[0], [1, 8], [], [-1], [10]]) >>> c = k2r.cat([a, b], axis=0) >>> c - [ [ 1 3 ] [ ] [ 5 8 ] [ ] [ 9 ] [ 0 ] [ 1 8 ] [ ] [ -1 ] [ 10 ] ] + RaggedTensor([[1, 3], + [], + [5, 8], + [], + [9], + [0], + [1, 8], + [], + [-1], + [10]], dtype=torch.int32) >>> c.num_axes 2 >>> d = k2r.cat([a, b], axis=1) >>> d - [ [ 1 3 0 ] [ 1 8 ] [ 5 8 ] [ -1 ] [ 9 10 ] ] + RaggedTensor([[1, 3, 0], + [1, 8], + [5, 8], + [-1], + [9, 10]], dtype=torch.int32) >>> d.num_axes 2 >>> k2r.RaggedTensor.cat([a, b], axis=1) - [ [ 1 3 0 ] [ 1 8 ] [ 5 8 ] [ -1 ] [ 9 10 ] ] + RaggedTensor([[1, 3, 0], + [1, 8], + [5, 8], + [-1], + [9, 10]], dtype=torch.int32) >>> k2r.cat((b, a), axis=0) - [ [ 0 ] [ 1 8 ] [ ] [ -1 ] [ 10 ] [ 1 3 ] [ ] [ 5 8 ] [ ] [ 9 ] ] + RaggedTensor([[0], + [1, 8], + [], + [-1], + [10], + [1, 3], + [], + [5, 8], + [], + [9]], dtype=torch.int32) Args: srcs: @@ -1275,34 +1544,79 @@ index on axis 0; if more than 3 axes, the earliest axes will be ignored. >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[3, 1], [3], [1], [1], [3, 1], [2]]) >>> a.unique() - ([ [ 1 ] [ 2 ] [ 3 ] [ 3 1 ] ], None, None) + (RaggedTensor([[1], + [2], + [3], + [3, 1]], dtype=torch.int32), None, None) >>> a.unique(need_num_repeats=True, need_new2old_indexes=True) - ([ [ 1 ] [ 2 ] [ 3 ] [ 3 1 ] ], [ [ 2 1 1 2 ] ], tensor([2, 5, 1, 0], dtype=torch.int32)) + (RaggedTensor([[1], + [2], + [3], + [3, 1]], dtype=torch.int32), RaggedTensor([[2, 1, 1, 2]], dtype=torch.int32), tensor([2, 5, 1, 0], dtype=torch.int32)) >>> a.unique(need_num_repeats=True) - ([ [ 1 ] [ 2 ] [ 3 ] [ 3 1 ] ], [ [ 2 1 1 2 ] ], None) + (RaggedTensor([[1], + [2], + [3], + [3, 1]], dtype=torch.int32), RaggedTensor([[2, 1, 1, 2]], dtype=torch.int32), None) >>> a.unique(need_new2old_indexes=True) - ([ [ 1 ] [ 2 ] [ 3 ] [ 3 1 ] ], None, tensor([2, 5, 1, 0], dtype=torch.int32)) + (RaggedTensor([[1], + [2], + [3], + [3, 1]], dtype=torch.int32), None, tensor([2, 5, 1, 0], dtype=torch.int32)) **Example 2** >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[[1, 2], [2, 1], [1, 2], [1, 2]], [[3], [2], [0, 1], [2]], [[], [2, 3], [], [3]] ]) >>> a.unique() - ([ [ [ 1 2 ] [ 2 1 ] ] [ [ 2 ] [ 3 ] [ 0 1 ] ] [ [ ] [ 3 ] [ 2 3 ] ] ], None, None) + (RaggedTensor([[[1, 2], + [2, 1]], + [[2], + [3], + [0, 1]], + [[], + [3], + [2, 3]]], dtype=torch.int32), None, None) >>> a.unique(need_num_repeats=True, need_new2old_indexes=True) - ([ [ [ 1 2 ] [ 2 1 ] ] [ [ 2 ] [ 3 ] [ 0 1 ] ] [ [ ] [ 3 ] [ 2 3 ] ] ], [ [ 3 1 ] [ 2 1 1 ] [ 2 1 1 ] ], tensor([ 0, 1, 5, 4, 6, 8, 11, 9], dtype=torch.int32)) + (RaggedTensor([[[1, 2], + [2, 1]], + [[2], + [3], + [0, 1]], + [[], + [3], + [2, 3]]], dtype=torch.int32), RaggedTensor([[3, 1], + [2, 1, 1], + [2, 1, 1]], dtype=torch.int32), tensor([ 0, 1, 5, 4, 6, 8, 11, 9], dtype=torch.int32)) >>> a.unique(need_num_repeats=True) - ([ [ [ 1 2 ] [ 2 1 ] ] [ [ 2 ] [ 3 ] [ 0 1 ] ] [ [ ] [ 3 ] [ 2 3 ] ] ], [ [ 3 1 ] [ 2 1 1 ] [ 2 1 1 ] ], None) + (RaggedTensor([[[1, 2], + [2, 1]], + [[2], + [3], + [0, 1]], + [[], + [3], + [2, 3]]], dtype=torch.int32), RaggedTensor([[3, 1], + [2, 1, 1], + [2, 1, 1]], dtype=torch.int32), None) >>> a.unique(need_new2old_indexes=True) - ([ [ [ 1 2 ] [ 2 1 ] ] [ [ 2 ] [ 3 ] [ 0 1 ] ] [ [ ] [ 3 ] [ 2 3 ] ] ], None, tensor([ 0, 1, 5, 4, 6, 8, 11, 9], dtype=torch.int32)) + (RaggedTensor([[[1, 2], + [2, 1]], + [[2], + [3], + [0, 1]], + [[], + [3], + [2, 3]]], dtype=torch.int32), None, tensor([ 0, 1, 5, 4, 6, 8, 11, 9], dtype=torch.int32)) **Example 3** >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[1], [3], [2]]) >>> a.unique(True, True) - ([ [ 1 ] [ 2 ] [ 3 ] ], [ [ 1 1 1 ] ], tensor([0, 2, 1], dtype=torch.int32)) - + (RaggedTensor([[1], + [2], + [3]], dtype=torch.int32), RaggedTensor([[1, 1, 1]], dtype=torch.int32), tensor([0, 2, 1], dtype=torch.int32)) Args: need_num_repeats: @@ -1370,14 +1684,26 @@ If ``use_log`` is ``False``, the normalization per sublist is done as follows: >>> import k2.ragged as k2r >>> a = k2r.RaggedTensor([[0.1, 0.3], [], [1], [0.2, 0.8]]) >>> a.normalize(use_log=False) -[ [ 0.25 0.75 ] [ ] [ 1 ] [ 0.2 0.8 ] ] +RaggedTensor([[0.25, 0.75], + [], + [1], + [0.2, 0.8]], dtype=torch.float32) >>> a.normalize(use_log=True) -[ [ -0.798139 -0.598139 ] [ ] [ 0 ] [ -1.03749 -0.437488 ] ] +RaggedTensor([[-0.798139, -0.598139], + [], + [0], + [-1.03749, -0.437488]], dtype=torch.float32) >>> b = k2r.RaggedTensor([ [[0.1, 0.3], []], [[1], [0.2, 0.8]] ]) >>> b.normalize(use_log=False) -[ [ [ 0.25 0.75 ] [ ] ] [ [ 1 ] [ 0.2 0.8 ] ] ] +RaggedTensor([[[0.25, 0.75], + []], + [[1], + [0.2, 0.8]]], dtype=torch.float32) >>> b.normalize(use_log=True) -[ [ [ -0.798139 -0.598139 ] [ ] ] [ [ 0 ] [ -1.03749 -0.437488 ] ] ] +RaggedTensor([[[-0.798139, -0.598139], + []], + [[0], + [-1.03749, -0.437488]]], dtype=torch.float32) >>> a.num_axes 2 >>> b.num_axes @@ -1471,7 +1797,10 @@ Sort a ragged tensor over the last axis **in-place**. >>> b tensor([1, 0, 2, 4, 5, 3, 7, 6, 8], dtype=torch.int32) >>> a -[ [ 3 1 0 ] [ 5 3 2 ] [ ] [ 3 1 0 ] ] +RaggedTensor([[3, 1, 0], + [5, 3, 2], + [], + [3, 1, 0]], dtype=torch.float32) >>> a_clone.values[b.long()] tensor([3., 1., 0., 5., 3., 2., 3., 1., 0.]) >>> a_clone = a.clone() @@ -1479,7 +1808,10 @@ tensor([3., 1., 0., 5., 3., 2., 3., 1., 0.]) >>> c tensor([2, 1, 0, 5, 4, 3, 8, 7, 6], dtype=torch.int32) >>> a -[ [ 0 1 3 ] [ 2 3 5 ] [ ] [ 0 1 3 ] ] +RaggedTensor([[0, 1, 3], + [2, 3, 5], + [], + [0, 1, 3]], dtype=torch.float32) >>> a_clone.values[c.long()] tensor([0., 1., 3., 2., 3., 5., 0., 1., 3.]) @@ -1505,10 +1837,14 @@ Index a ragged tensor with a ragged tensor. >>> src = k2r.RaggedTensor([[10, 11], [12, 13.5]]) >>> indexes = k2r.RaggedTensor([[0, 1]]) >>> src.index(indexes) - [ [ [ 10 11 ] [ 12 13.5 ] ] ] + RaggedTensor([[[10, 11], + [12, 13.5]]], dtype=torch.float32) >>> i = k2r.RaggedTensor([[0], [1], [0, 0]]) >>> src.index(i) - [ [ [ 10 11 ] ] [ [ 12 13.5 ] ] [ [ 10 11 ] [ 10 11 ] ] ] + RaggedTensor([[[10, 11]], + [[12, 13.5]], + [[10, 11], + [10, 11]]], dtype=torch.float32) **Example 2**: @@ -1516,7 +1852,17 @@ Index a ragged tensor with a ragged tensor. >>> src = k2r.RaggedTensor([ [[1, 0], [], [2]], [[], [3], [0, 0, 1]], [[1, 2], [-1]]]) >>> i = k2r.RaggedTensor([[[0, 2], [1]], [[0]]]) >>> src.index(i) - [ [ [ [ [ 1 0 ] [ ] [ 2 ] ] [ [ 1 2 ] [ -1 ] ] ] [ [ [ ] [ 3 ] [ 0 0 1 ] ] ] ] [ [ [ [ 1 0 ] [ ] [ 2 ] ] ] ] ] + RaggedTensor([[[[[1, 0], + [], + [2]], + [[1, 2], + [-1]]], + [[[], + [3], + [0, 0, 1]]]], + [[[[1, 0], + [], + [2]]]]], dtype=torch.int32) Args: indexes: @@ -1545,14 +1891,19 @@ the elements of ``indexes`` are interpreted as indexes into axis ``axis`` of >>> i = torch.tensor([2, 0, 3, 5], dtype=torch.int32) >>> b, value_indexes = a.index(i, axis=0, need_value_indexes=True) >>> b - [ [ 0 1 2 ] [ 0 2 3 ] [ ] [ 3 -1.25 ] ] + RaggedTensor([[0, 1, 2], + [0, 2, 3], + [], + [3, -1.25]], dtype=torch.float32) >>> value_indexes tensor([3, 4, 5, 0, 1, 2, 6, 7], dtype=torch.int32) >>> a.values[value_indexes.long()] tensor([ 0.0000, 1.0000, 2.0000, 0.0000, 2.0000, 3.0000, 3.0000, -1.2500]) >>> k = torch.tensor([2, -1, 0], dtype=torch.int32) >>> a.index(k, axis=0, need_value_indexes=True) - ([ [ 0 1 2 ] [ ] [ 0 2 3 ] ], tensor([3, 4, 5, 0, 1, 2], dtype=torch.int32)) + (RaggedTensor([[0, 1, 2], + [], + [0, 2, 3]], dtype=torch.float32), tensor([3, 4, 5, 0, 1, 2], dtype=torch.int32)) **Example 2**: @@ -1563,13 +1914,18 @@ the elements of ``indexes`` are interpreted as indexes into axis ``axis`` of tensor([0, 0, 0, 1, 1, 1, 1], dtype=torch.int32) >>> b, value_indexes = a.index(i, axis=1, need_value_indexes=True) >>> b - [ [ [ 1 3 ] [ 2 ] [ ] ] [ [ 2 ] [ 5 8 ] [ -1 ] [ ] ] ] + RaggedTensor([[[1, 3], + [2], + []], + [[2], + [5, 8], + [-1], + []]], dtype=torch.int32) >>> value_indexes tensor([0, 1, 2, 6, 3, 4, 5], dtype=torch.int32) >>> a.values[value_indexes.long()] tensor([ 1, 3, 2, 2, 5, 8, -1], dtype=torch.int32) - Args: indexes: Array of indexes, which will be interpreted as indexes into axis ``axis`` @@ -1607,15 +1963,23 @@ Use a ragged tensor to index a 1-d torch tensor. >>> src tensor([ 0, 10, 20, 30, 40, 50], dtype=torch.int32) >>> k2r.index(src, i) -[ [ 10 50 30 ] [ 0 20 ] ] +RaggedTensor([[10, 50, 30], + [0, 20]], dtype=torch.int32) >>> k = k2r.RaggedTensor([ [[1, 5, 3], [0]], [[0, 2], [1, 3]] ]) >>> k2r.index(src, k) -[ [ [ 10 50 30 ] [ 0 ] ] [ [ 0 20 ] [ 10 30 ] ] ] +RaggedTensor([[[10, 50, 30], + [0]], + [[0, 20], + [10, 30]]], dtype=torch.int32) >>> n = k2r.RaggedTensor([ [1, -1], [-1, 0], [-1] ]) >>> k2r.index(src, n) -[ [ 10 0 ] [ 0 0 ] [ 0 ] ] +RaggedTensor([[10, 0], + [0, 0], + [0]], dtype=torch.int32) >>> k2r.index(src, n, default_value=-2) -[ [ 10 -2 ] [ -2 0 ] [ -2 ] ] +RaggedTensor([[10, -2], + [-2, 0], + [-2]], dtype=torch.int32) Args: src: diff --git a/k2/python/csrc/torch/v2/ragged_any.cu b/k2/python/csrc/torch/v2/ragged_any.cu index ae9549b5f..987d3e598 100644 --- a/k2/python/csrc/torch/v2/ragged_any.cu +++ b/k2/python/csrc/torch/v2/ragged_any.cu @@ -35,6 +35,44 @@ namespace k2 { +static void PrintSpaces(std::ostream &os, int32_t num_spaces) { + K2_CHECK_GE(num_spaces, 0); + for (int32_t i = 0; i != num_spaces; ++i) os << " "; +} + +template +void RaggedAnyToStringIter(std::ostream &os, const Ragged ragged, + int32_t axis, int32_t begin_pos, int32_t end_pos, + int32_t num_indent) { + const auto &shape = ragged.shape; + K2_CHECK(axis >= 0 && axis < shape.NumAxes() && begin_pos >= 0 && + begin_pos <= end_pos && end_pos <= shape.TotSize(axis)); + std::string sep = ""; + bool is_first_row = true; + for (int32_t d = begin_pos; d < end_pos; d++) { + if (axis == shape.NumAxes() - 1) { + os << sep << ragged.values[d]; + sep = ", "; + } else { + const int32_t *row_splits = shape.RowSplits(axis + 1).Data(); + K2_DCHECK_LE(d, shape.RowSplits(axis + 1).Dim()); + int32_t row_start = row_splits[d], row_end = row_splits[d + 1]; + + if (!is_first_row) { + PrintSpaces(os, num_indent + 1); + } + is_first_row = false; + + os << "["; + + RaggedAnyToStringIter(os, ragged, axis + 1, row_start, row_end, + num_indent + 1); + os << "]"; + if (d != end_pos - 1) os << ",\n"; + } + } +} + /** One iteration of RaggedAnyFromList. @param data It is a list or a list-of sublist(s). @@ -154,21 +192,24 @@ RaggedAny::RaggedAny(const RaggedShape &shape, torch::Tensor value) K2_LOG(FATAL) << "Unsupported dtype: " << TraitsOf(t).Name(); } -RaggedAny::RaggedAny(const std::string &s, py::object dtype /*=py::none()*/) { +RaggedAny::RaggedAny(const std::string &s, py::object dtype /*=py::none()*/, + torch::Device device /*=torch::kCPU*/) { if (!dtype.is_none() && !THPDtype_Check(dtype.ptr())) { K2_LOG(FATAL) << "Expect an instance of torch.dtype. " << "Given: " << py::str(dtype); } + ContextPtr context = GetContext(device); + if (dtype.is_none()) { try { // We try int first, if it fails, use float - any = Ragged(s, /*throw_on_failure*/ true).Generic(); + any = Ragged(s, /*throw_on_failure*/ true).To(context).Generic(); return; } catch (const std::runtime_error &) { // Use float. If it fails again, another exception // is thrown and it is propagated to the user - any = Ragged(s).Generic(); + any = Ragged(s).To(context).Generic(); return; } } @@ -178,7 +219,7 @@ RaggedAny::RaggedAny(const std::string &s, py::object dtype /*=py::none()*/) { Dtype t = ScalarTypeToDtype(scalar_type); FOR_REAL_AND_INT32_TYPES(t, T, { - any = Ragged(s).Generic(); + any = Ragged(s).To(context).Generic(); return; }); @@ -187,21 +228,24 @@ RaggedAny::RaggedAny(const std::string &s, py::object dtype /*=py::none()*/) { << "and torch.float64"; } -RaggedAny::RaggedAny(py::list data, py::object dtype /*= py::none()*/) { +RaggedAny::RaggedAny(py::list data, py::object dtype /*= py::none()*/, + torch::Device device /*=torch::kCPU*/) { if (!dtype.is_none() && !THPDtype_Check(dtype.ptr())) { K2_LOG(FATAL) << "Expect an instance of torch.dtype. " << "Given: " << py::str(dtype); } + ContextPtr context = GetContext(device); + if (dtype.is_none()) { try { // We try int first; if it fails, use float - any = RaggedAnyFromList(data).Generic(); + any = RaggedAnyFromList(data).To(context).Generic(); return; } catch (const std::exception &) { // Use float. If it fails again, another exception // is thrown and it is propagated to the user - any = RaggedAnyFromList(data).Generic(); + any = RaggedAnyFromList(data).To(context).Generic(); return; } } @@ -211,7 +255,7 @@ RaggedAny::RaggedAny(py::list data, py::object dtype /*= py::none()*/) { Dtype t = ScalarTypeToDtype(scalar_type); FOR_REAL_AND_INT32_TYPES(t, T, { - any = RaggedAnyFromList(data).Generic(); + any = RaggedAnyFromList(data).To(context).Generic(); return; }); @@ -277,10 +321,33 @@ const torch::Tensor &RaggedAny::Data() const { return data; } -std::string RaggedAny::ToString() const { +std::string RaggedAny::ToString(int32_t device_id /*=-1*/) const { + ContextPtr context = any.Context(); + if (context->GetDeviceType() != kCpu) { + return To("cpu").ToString(context->GetDeviceId()); + } + std::ostringstream os; Dtype t = any.GetDtype(); - FOR_REAL_AND_INT32_TYPES(t, T, { os << any.Specialize(); }); + std::string dtype; + if (t == kInt32Dtype) + dtype = "torch.int32"; + else if (t == kFloatDtype) + dtype = "torch.float32"; + else if (t == kDoubleDtype) + dtype = "torch.float64"; + else + K2_LOG(FATAL) << "Unsupported dtype: " << TraitsOf(t).Name(); + + FOR_REAL_AND_INT32_TYPES(t, T, { + os << "RaggedTensor(["; + // 13 is strlen("RaggedTensor(") + RaggedAnyToStringIter(os, any.Specialize(), 0, 0, any.shape.Dim0(), 13); + os << "]"; + if (device_id != -1) os << ", device='cuda:" << device_id << "'"; + os << ", dtype=" << dtype; + os << ")"; + }); return os.str(); } diff --git a/k2/python/csrc/torch/v2/ragged_any.h b/k2/python/csrc/torch/v2/ragged_any.h index 5616b68c3..2ddec79cd 100644 --- a/k2/python/csrc/torch/v2/ragged_any.h +++ b/k2/python/csrc/torch/v2/ragged_any.h @@ -87,7 +87,12 @@ struct RaggedAny { @note We can support other dtypes if needed. */ - explicit RaggedAny(const std::string &s, py::object dtype = py::none()); + explicit RaggedAny(const std::string &s, py::object dtype = py::none(), + torch::Device device = torch::kCPU); + + explicit RaggedAny(const std::string &s, py::object dtype = py::none(), + const std::string &device = "cpu") + : RaggedAny(s, dtype, torch::Device(device)) {} /** Create a ragged tensor from a list of sublist(s). @@ -100,16 +105,22 @@ struct RaggedAny { @note It supports `data` with number of axes >= 2. */ - explicit RaggedAny(py::list data, py::object dtype = py::none()); + explicit RaggedAny(py::list data, py::object dtype = py::none(), + torch::Device device = torch::kCPU); + + explicit RaggedAny(py::list data, py::object dtype = py::none(), + const std::string device = "cpu") + : RaggedAny(data, dtype, torch::Device(device)) {} /// Populate `this->data` and return it const torch::Tensor &Data() const; /** Convert a ragged tensor to a string. + @param device_id -1 for CPU. 0 and above is for CUDA. @return Return a string representation of this tensor. */ - std::string ToString() const; + std::string ToString(int device_id = -1) const; /* Move a ragged tensor to a given device.