Skip to content

Commit

Permalink
Create a ragged tensor from a regular tensor.
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Sep 15, 2021
1 parent 210175c commit 5fc2189
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 35 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build-doc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ on:
push:
branches:
- master
- doc

env:
# debug is faster in terms of compilation time
Expand Down
4 changes: 2 additions & 2 deletions k2/python/csrc/torch/torch_util.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ torch::Tensor ToTorch(Array1<Arc> &array) {
}

template <>
Array1<Arc> FromTorch<Arc>(torch::Tensor &tensor) {
Array1<Arc> FromTorch<Arc>(torch::Tensor tensor) {
K2_CHECK_EQ(tensor.dim(), 2) << "Expected dim: 2. Given: " << tensor.dim();
K2_CHECK_EQ(tensor.scalar_type(), ToScalarType<int32_t>::value)
<< "Expected scalar type: " << ToScalarType<int32_t>::value
Expand All @@ -124,7 +124,7 @@ Array1<Arc> FromTorch<Arc>(torch::Tensor &tensor) {
return ans;
}

Tensor FromTorch(torch::Tensor &tensor, TensorTag) {
Tensor FromTorch(torch::Tensor tensor, TensorTag) {
Dtype dtype = ScalarTypeToDtype(tensor.scalar_type());
torch::IntArrayRef sizes = tensor.sizes();
torch::IntArrayRef strides = tensor.strides();
Expand Down
8 changes: 4 additions & 4 deletions k2/python/csrc/torch/torch_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ torch::Tensor ToTorch(Array1<T> &array) {
input tensor.
*/
template <typename T>
Array1<T> FromTorch(torch::Tensor &tensor) {
Array1<T> FromTorch(torch::Tensor tensor) {
K2_CHECK_EQ(tensor.dim(), 1) << "Expected dim: 1. Given: " << tensor.dim();
K2_CHECK_EQ(tensor.scalar_type(), ToScalarType<T>::value)
<< "Expected scalar type: " << ToScalarType<T>::value
Expand Down Expand Up @@ -158,12 +158,12 @@ torch::Tensor ToTorch(Array1<Arc> &array);
the input tensor.
*/
template <>
Array1<Arc> FromTorch<Arc>(torch::Tensor &tensor);
Array1<Arc> FromTorch<Arc>(torch::Tensor tensor);

struct Array2Tag {};

template <typename T>
Array2<T> FromTorch(torch::Tensor &tensor, Array2Tag) {
Array2<T> FromTorch(torch::Tensor tensor, Array2Tag) {
K2_CHECK_EQ(tensor.dim(), 2) << "Expected dim: 2. Given: " << tensor.dim();
K2_CHECK_EQ(tensor.scalar_type(), ToScalarType<T>::value)
<< "Expected scalar type: " << ToScalarType<T>::value
Expand Down Expand Up @@ -199,7 +199,7 @@ torch::Tensor ToTorch(Array2<T> &array) {

struct TensorTag {};

Tensor FromTorch(torch::Tensor &tensor, TensorTag);
Tensor FromTorch(torch::Tensor tensor, TensorTag);
torch::Tensor ToTorch(Tensor &tensor);

/* Transfer an object to a specific device.
Expand Down
29 changes: 17 additions & 12 deletions k2/python/csrc/torch/v2/any.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,12 @@ void PybindRaggedAny(py::module &m) {
}),
py::arg("s"), py::arg("dtype") = py::none(), kRaggedAnyInitStrDoc);

// TODO(fangjun): add documentation for it
any.def(py::init<const RaggedShape &, torch::Tensor>(), py::arg("shape"),
py::arg("value"), kRaggedInitFromShapeAndTensorDoc);

any.def(py::init<torch::Tensor>(), py::arg("tensor"),
kRaggedAnyInitTensorDoc);

any.def(
"__str__",
[](const RaggedAny &self) -> std::string { return self.ToString(); },
Expand All @@ -78,8 +80,7 @@ void PybindRaggedAny(py::module &m) {
K2_CHECK_EQ(self.any.NumAxes(), 2);
Array1<int32_t> row_split = self.any.RowSplits(1).To(GetCpuContext());
const int32_t *row_split_data = row_split.Data();
int32_t begin = row_split_data[i],
end = row_split_data[i + 1];
int32_t begin = row_split_data[i], end = row_split_data[i + 1];
Dtype t = self.any.GetDtype();
FOR_REAL_AND_INT32_TYPES(t, T, {
Array1<T> array =
Expand All @@ -100,19 +101,18 @@ void PybindRaggedAny(py::module &m) {
if (!slice.compute(self.any.Dim0(), &start, &stop, &step, &slicelength))
throw py::error_already_set();
int32_t istart = static_cast<int32_t>(start);
int32_t istop = static_cast<int32_t>(stop);
int32_t istep = static_cast<int32_t>(step);
K2_CHECK_EQ(istep, 1) << "Only support slicing with step 1, given : "
<< istep;
int32_t istop = static_cast<int32_t>(stop);
int32_t istep = static_cast<int32_t>(step);
K2_CHECK_EQ(istep, 1)
<< "Only support slicing with step 1, given : " << istep;

return self.Arange(/*axis*/ 0, istart, istop);
}, py::arg("key"), kRaggedAnyGetItemSliceDoc);
},
py::arg("key"), kRaggedAnyGetItemSliceDoc);

any.def("index",
static_cast<RaggedAny (RaggedAny::*)(RaggedAny &)>(
&RaggedAny::Index),
py::arg("indexes"),
kRaggedAnyRaggedIndexDoc);
static_cast<RaggedAny (RaggedAny::*)(RaggedAny &)>(&RaggedAny::Index),
py::arg("indexes"), kRaggedAnyRaggedIndexDoc);

any.def("index",
static_cast<std::pair<RaggedAny, torch::optional<torch::Tensor>> (
Expand Down Expand Up @@ -423,6 +423,11 @@ void PybindRaggedAny(py::module &m) {
return RaggedAny(s, dtype);
},
py::arg("s"), py::arg("dtype") = py::none(), kCreateRaggedTensorStrDoc);

m.def(
"create_ragged_tensor",
[](torch::Tensor tensor) -> RaggedAny { return RaggedAny(tensor); },
py::arg("tensor"), kCreateRaggedTensorTensorDoc);
}

} // namespace k2
Loading

0 comments on commit 5fc2189

Please sign in to comment.