Skip to content

Commit 5fc2189

Browse files
committed
Create a ragged tensor from a regular tensor.
1 parent 210175c commit 5fc2189

File tree

7 files changed

+269
-35
lines changed

7 files changed

+269
-35
lines changed

.github/workflows/build-doc.yml

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ on:
2222
push:
2323
branches:
2424
- master
25+
- doc
2526

2627
env:
2728
# debug is faster in terms of compilation time

k2/python/csrc/torch/torch_util.cu

+2-2
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ torch::Tensor ToTorch(Array1<Arc> &array) {
105105
}
106106

107107
template <>
108-
Array1<Arc> FromTorch<Arc>(torch::Tensor &tensor) {
108+
Array1<Arc> FromTorch<Arc>(torch::Tensor tensor) {
109109
K2_CHECK_EQ(tensor.dim(), 2) << "Expected dim: 2. Given: " << tensor.dim();
110110
K2_CHECK_EQ(tensor.scalar_type(), ToScalarType<int32_t>::value)
111111
<< "Expected scalar type: " << ToScalarType<int32_t>::value
@@ -124,7 +124,7 @@ Array1<Arc> FromTorch<Arc>(torch::Tensor &tensor) {
124124
return ans;
125125
}
126126

127-
Tensor FromTorch(torch::Tensor &tensor, TensorTag) {
127+
Tensor FromTorch(torch::Tensor tensor, TensorTag) {
128128
Dtype dtype = ScalarTypeToDtype(tensor.scalar_type());
129129
torch::IntArrayRef sizes = tensor.sizes();
130130
torch::IntArrayRef strides = tensor.strides();

k2/python/csrc/torch/torch_util.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ torch::Tensor ToTorch(Array1<T> &array) {
113113
input tensor.
114114
*/
115115
template <typename T>
116-
Array1<T> FromTorch(torch::Tensor &tensor) {
116+
Array1<T> FromTorch(torch::Tensor tensor) {
117117
K2_CHECK_EQ(tensor.dim(), 1) << "Expected dim: 1. Given: " << tensor.dim();
118118
K2_CHECK_EQ(tensor.scalar_type(), ToScalarType<T>::value)
119119
<< "Expected scalar type: " << ToScalarType<T>::value
@@ -158,12 +158,12 @@ torch::Tensor ToTorch(Array1<Arc> &array);
158158
the input tensor.
159159
*/
160160
template <>
161-
Array1<Arc> FromTorch<Arc>(torch::Tensor &tensor);
161+
Array1<Arc> FromTorch<Arc>(torch::Tensor tensor);
162162

163163
struct Array2Tag {};
164164

165165
template <typename T>
166-
Array2<T> FromTorch(torch::Tensor &tensor, Array2Tag) {
166+
Array2<T> FromTorch(torch::Tensor tensor, Array2Tag) {
167167
K2_CHECK_EQ(tensor.dim(), 2) << "Expected dim: 2. Given: " << tensor.dim();
168168
K2_CHECK_EQ(tensor.scalar_type(), ToScalarType<T>::value)
169169
<< "Expected scalar type: " << ToScalarType<T>::value
@@ -199,7 +199,7 @@ torch::Tensor ToTorch(Array2<T> &array) {
199199

200200
struct TensorTag {};
201201

202-
Tensor FromTorch(torch::Tensor &tensor, TensorTag);
202+
Tensor FromTorch(torch::Tensor tensor, TensorTag);
203203
torch::Tensor ToTorch(Tensor &tensor);
204204

205205
/* Transfer an object to a specific device.

k2/python/csrc/torch/v2/any.cu

+17-12
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,12 @@ void PybindRaggedAny(py::module &m) {
5454
}),
5555
py::arg("s"), py::arg("dtype") = py::none(), kRaggedAnyInitStrDoc);
5656

57-
// TODO(fangjun): add documentation for it
5857
any.def(py::init<const RaggedShape &, torch::Tensor>(), py::arg("shape"),
5958
py::arg("value"), kRaggedInitFromShapeAndTensorDoc);
6059

60+
any.def(py::init<torch::Tensor>(), py::arg("tensor"),
61+
kRaggedAnyInitTensorDoc);
62+
6163
any.def(
6264
"__str__",
6365
[](const RaggedAny &self) -> std::string { return self.ToString(); },
@@ -78,8 +80,7 @@ void PybindRaggedAny(py::module &m) {
7880
K2_CHECK_EQ(self.any.NumAxes(), 2);
7981
Array1<int32_t> row_split = self.any.RowSplits(1).To(GetCpuContext());
8082
const int32_t *row_split_data = row_split.Data();
81-
int32_t begin = row_split_data[i],
82-
end = row_split_data[i + 1];
83+
int32_t begin = row_split_data[i], end = row_split_data[i + 1];
8384
Dtype t = self.any.GetDtype();
8485
FOR_REAL_AND_INT32_TYPES(t, T, {
8586
Array1<T> array =
@@ -100,19 +101,18 @@ void PybindRaggedAny(py::module &m) {
100101
if (!slice.compute(self.any.Dim0(), &start, &stop, &step, &slicelength))
101102
throw py::error_already_set();
102103
int32_t istart = static_cast<int32_t>(start);
103-
int32_t istop = static_cast<int32_t>(stop);
104-
int32_t istep = static_cast<int32_t>(step);
105-
K2_CHECK_EQ(istep, 1) << "Only support slicing with step 1, given : "
106-
<< istep;
104+
int32_t istop = static_cast<int32_t>(stop);
105+
int32_t istep = static_cast<int32_t>(step);
106+
K2_CHECK_EQ(istep, 1)
107+
<< "Only support slicing with step 1, given : " << istep;
107108

108109
return self.Arange(/*axis*/ 0, istart, istop);
109-
}, py::arg("key"), kRaggedAnyGetItemSliceDoc);
110+
},
111+
py::arg("key"), kRaggedAnyGetItemSliceDoc);
110112

111113
any.def("index",
112-
static_cast<RaggedAny (RaggedAny::*)(RaggedAny &)>(
113-
&RaggedAny::Index),
114-
py::arg("indexes"),
115-
kRaggedAnyRaggedIndexDoc);
114+
static_cast<RaggedAny (RaggedAny::*)(RaggedAny &)>(&RaggedAny::Index),
115+
py::arg("indexes"), kRaggedAnyRaggedIndexDoc);
116116

117117
any.def("index",
118118
static_cast<std::pair<RaggedAny, torch::optional<torch::Tensor>> (
@@ -423,6 +423,11 @@ void PybindRaggedAny(py::module &m) {
423423
return RaggedAny(s, dtype);
424424
},
425425
py::arg("s"), py::arg("dtype") = py::none(), kCreateRaggedTensorStrDoc);
426+
427+
m.def(
428+
"create_ragged_tensor",
429+
[](torch::Tensor tensor) -> RaggedAny { return RaggedAny(tensor); },
430+
py::arg("tensor"), kCreateRaggedTensorTensorDoc);
426431
}
427432

428433
} // namespace k2

0 commit comments

Comments
 (0)