@@ -54,10 +54,12 @@ void PybindRaggedAny(py::module &m) {
54
54
}),
55
55
py::arg (" s" ), py::arg (" dtype" ) = py::none (), kRaggedAnyInitStrDoc );
56
56
57
- // TODO(fangjun): add documentation for it
58
57
any.def (py::init<const RaggedShape &, torch::Tensor>(), py::arg (" shape" ),
59
58
py::arg (" value" ), kRaggedInitFromShapeAndTensorDoc );
60
59
60
+ any.def (py::init<torch::Tensor>(), py::arg (" tensor" ),
61
+ kRaggedAnyInitTensorDoc );
62
+
61
63
any.def (
62
64
" __str__" ,
63
65
[](const RaggedAny &self) -> std::string { return self.ToString (); },
@@ -78,8 +80,7 @@ void PybindRaggedAny(py::module &m) {
78
80
K2_CHECK_EQ (self.any .NumAxes (), 2 );
79
81
Array1<int32_t > row_split = self.any .RowSplits (1 ).To (GetCpuContext ());
80
82
const int32_t *row_split_data = row_split.Data ();
81
- int32_t begin = row_split_data[i],
82
- end = row_split_data[i + 1 ];
83
+ int32_t begin = row_split_data[i], end = row_split_data[i + 1 ];
83
84
Dtype t = self.any .GetDtype ();
84
85
FOR_REAL_AND_INT32_TYPES (t, T, {
85
86
Array1<T> array =
@@ -100,19 +101,18 @@ void PybindRaggedAny(py::module &m) {
100
101
if (!slice.compute (self.any .Dim0 (), &start, &stop, &step, &slicelength))
101
102
throw py::error_already_set ();
102
103
int32_t istart = static_cast <int32_t >(start);
103
- int32_t istop = static_cast <int32_t >(stop);
104
- int32_t istep = static_cast <int32_t >(step);
105
- K2_CHECK_EQ (istep, 1 ) << " Only support slicing with step 1, given : "
106
- << istep;
104
+ int32_t istop = static_cast <int32_t >(stop);
105
+ int32_t istep = static_cast <int32_t >(step);
106
+ K2_CHECK_EQ (istep, 1 )
107
+ << " Only support slicing with step 1, given : " << istep;
107
108
108
109
return self.Arange (/* axis*/ 0 , istart, istop);
109
- }, py::arg (" key" ), kRaggedAnyGetItemSliceDoc );
110
+ },
111
+ py::arg (" key" ), kRaggedAnyGetItemSliceDoc );
110
112
111
113
any.def (" index" ,
112
- static_cast <RaggedAny (RaggedAny::*)(RaggedAny &)>(
113
- &RaggedAny::Index),
114
- py::arg (" indexes" ),
115
- kRaggedAnyRaggedIndexDoc );
114
+ static_cast <RaggedAny (RaggedAny::*)(RaggedAny &)>(&RaggedAny::Index),
115
+ py::arg (" indexes" ), kRaggedAnyRaggedIndexDoc );
116
116
117
117
any.def (" index" ,
118
118
static_cast <std::pair<RaggedAny, torch::optional<torch::Tensor>> (
@@ -423,6 +423,11 @@ void PybindRaggedAny(py::module &m) {
423
423
return RaggedAny (s, dtype);
424
424
},
425
425
py::arg (" s" ), py::arg (" dtype" ) = py::none (), kCreateRaggedTensorStrDoc );
426
+
427
+ m.def (
428
+ " create_ragged_tensor" ,
429
+ [](torch::Tensor tensor) -> RaggedAny { return RaggedAny (tensor); },
430
+ py::arg (" tensor" ), kCreateRaggedTensorTensorDoc );
426
431
}
427
432
428
433
} // namespace k2
0 commit comments