This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 6acc9e3

committed Nov 13, 2021
Add more doc.
1 parent ef3568b commit 6acc9e3

7 files changed

+257
-18
lines changed


‎docs/source/python_tutorials/ragged/basics.rst

+110-14
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,6 @@ In this tutorial, we describe
1313
- What is ``RaggedShape``?
1414
- What is ``row_splits`` ?
1515
- What is ``row_ids`` ?
16-
- What is ``dim0`` ?
17-
- What is ``tot_size`` ?
1816

1917
What are ragged tensors?
2018
------------------------
@@ -29,7 +27,7 @@ tensors, i.e., regular tensors, look like.
2927
:lines: 8-20
3028

3129
The shape of the 2-D regular tensor ``a`` is ``(3, 4)``, meaning it has 3
32-
rows and 4 columns. Each row has **exactly** 4 elements, no more, no less.
30+
rows and 4 columns. Each row has **exactly** 4 elements.
3331

3432
- 3-D regular tensors
3533

@@ -38,8 +36,8 @@ tensors, i.e., regular tensors, look like.
3836
:lines: 24-45
3937

4038
The shape of the 3-D regular tensor ``b`` is ``(3, 3, 2)``, meaning it has
41-
3 planes. Each plane has **exactly** 3 rows, no more, no less. Each row has
42-
**exactly** two entries, no more, no less.
39+
3 planes. Each plane has **exactly** 3 rows and each row has **exactly** two
40+
entries.
4341

4442
- N-D regular tensors (N >= 4)
4543

@@ -89,7 +87,7 @@ tensors in ``k2``.
8987
A ragged tensor in ``k2`` has ``N`` (``N >= 2``) axes. Unlike regular tensors,
9088
each axis of a ragged tensor can have a different number of elements.
9189

92-
Ragged tensors are **the most important** data structures in ``k2``. FSAs are
90+
Ragged tensors are **the most important** data structure in ``k2``. FSAs are
9391
represented as ragged tensors. There are also various operations defined on ragged
9492
tensors.
9593

@@ -113,7 +111,7 @@ Exercise 1
113111
- Row 1 is empty, i.e., it has no elements.
114112
- Row 2 has two elements: ``-1.5, 2``
115113

116-
(Click ▶ to see it)
114+
(Click ▶ to view the solution)
117115

118116
.. literalinclude:: code/basics/ragged-tensors.py
119117
:language: python
@@ -130,11 +128,34 @@ Exercise 2
130128

131129
How to create a ragged tensor with only 1 axis?
132130

133-
(Click ▶ to see it)
131+
(Click ▶ to view the solution)
134132

135133
You **cannot** create a ragged tensor with only 1 axis. Ragged tensors
136134
in ``k2`` have at least 2 axes.
137135

136+
dtype and device
137+
^^^^^^^^^^^^^^^^
138+
139+
Like tensors in PyTorch, ragged tensors in ``k2`` have the attributes ``dtype`` and
140+
``device``. The following code shows that you can specify the ``dtype`` and
141+
``device`` while constructing ragged tensors.
142+
143+
.. literalinclude:: code/basics/dtype-device.py
144+
:language: python
145+
:lines: 3-23
146+
147+
.. container:: toggle
148+
149+
.. container:: header
150+
151+
.. Note::
152+
153+
(Click ▶ to view the output)
154+
155+
.. literalinclude:: code/basics/dtype-device.py
156+
:language: python
157+
:lines: 25-50
158+
138159
Concepts about ragged tensors
139160
-----------------------------
140161

@@ -144,18 +165,18 @@ A ragged tensor in ``k2`` consists of two parts:
144165

145166
.. Caution::
146167

147-
It is assumed that a shape within a ragged tensor in ``k2`` is a constant.
168+
It is assumed that a shape within a ragged tensor in ``k2`` is a constant.
148169
Once constructed, you are not expected to modify it. Otherwise, unexpected
149170
things can happen; you will be SAD.
150171

151-
- ``data``, which is an **array** of type ``T``
172+
- ``values``, which is an **array** of type ``T``
152173

153174
.. Hint::
154175

155-
``data`` is stored ``contiguously`` in memory, whose entries have to be
176+
``values`` is stored ``contiguously`` in memory, and its entries must all be
156177
of the same type ``T``. ``T`` can be a primitive type, such as
157178
``int``, ``float``, or ``double``, or a user-defined type. For instance,
158-
``data`` in FSAs contains ``arcs``, which is defined in C++
179+
``values`` in FSAs contains ``arcs``, which is defined in C++
159180
`as follows <https://github.com/k2-fsa/k2/blob/master/k2/csrc/fsa.h#L31>`_:
160181

161182
.. code-block:: c++
@@ -167,8 +188,83 @@ A ragged tensor in ``k2`` consists of two parts:
167188
float score;
168189
}
169190

170-
In the following, we describe what is inside a ``shape`` and how to manipulate
171-
``data``.
191+
Before explaining what ``shape`` and ``values`` contain, let us look at an example of
192+
how to use a ragged tensor to represent the following
193+
FSA (see :numref:`ragged_basics_simple_fsa_1`).
194+
195+
.. _ragged_basics_simple_fsa_1:
196+
.. figure:: code/basics/images/simple-fsa.svg
197+
:alt: A simple FSA
198+
:align: center
199+
:figwidth: 600px
200+
201+
A simple FSA that is to be represented by a ragged tensor.
202+
203+
The FSA in :numref:`ragged_basics_simple_fsa_1` has 3 arcs and 3 states.
204+
205+
+---------+--------------------+--------------------+--------------------+--------------------+
206+
| | src_state | dst_state | label | score |
207+
+---------+--------------------+--------------------+--------------------+--------------------+
208+
| Arc 0 | 0 | 1 | 1 | 0.1 |
209+
+---------+--------------------+--------------------+--------------------+--------------------+
210+
| Arc 1 | 0 | 1 | 2 | 0.2 |
211+
+---------+--------------------+--------------------+--------------------+--------------------+
212+
| Arc 2 | 1 | 2 | -1 | 0.3 |
213+
+---------+--------------------+--------------------+--------------------+--------------------+
214+
215+
When the above FSA is saved in a ragged tensor, its arcs are saved in a 1-D contiguous
216+
``values`` array containing ``[Arc0, Arc1, Arc2]``.
217+
At this point, you might ask:
218+
219+
- As we can construct the original FSA by using the ``values`` array,
220+
what's the point of saving it in a ragged tensor?
221+
222+
Using the ``values`` array alone, it is not possible to answer the following questions in ``O(1)``
223+
time:
224+
225+
- How many states does the FSA have ?
226+
- How many arcs does each state have ?
227+
- Where do the arcs belonging to state 0 start in the ``values`` array ?
228+
229+
To handle the above questions, we introduce another 1-D array, called ``row_splits``.
230+
``row_splits[s] = p`` means for state ``s`` its first outgoing arc starts at position
231+
``p`` in the ``values`` array. As a side effect, it also indicates that the last outgoing
232+
arc for state ``s-1`` ends at position ``p`` (exclusive) in the ``values`` array.
233+
234+
In our example, ``row_splits`` would be ``[0, 2, 3, 3]``, meaning:
235+
236+
- The first outgoing arc for state 0 is at position ``row_splits[0] = 0``
237+
in the ``values`` array
238+
- State 0 has ``row_splits[1] - row_splits[0] = 2 - 0 = 2`` arcs
239+
- The first outgoing arc for state 1 is at position ``row_splits[1] = 2``
240+
in the ``values`` array
241+
- State 1 has ``row_splits[2] - row_splits[1] = 3 - 2 = 1`` arc
242+
- State 2 has no arcs since ``row_splits[3] - row_splits[2] = 3 - 3 = 0``
243+
- The FSA has ``len(row_splits) - 1 = 3`` states.
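
The arithmetic in the list above can be checked with a minimal sketch in plain Python. It assumes ordinary Python lists in place of the ``k2`` arrays, and ``arcs_per_state`` is a hypothetical helper name, not part of the ``k2`` API.

```python
# Illustrative only: a plain Python list stands in for the k2 row_splits array.
row_splits = [0, 2, 3, 3]  # taken from the example above


def arcs_per_state(row_splits):
    """Return the number of outgoing arcs of every state.

    Hypothetical helper, not a k2 function. Each entry is
    row_splits[s + 1] - row_splits[s].
    """
    return [end - start for start, end in zip(row_splits[:-1], row_splits[1:])]


num_states = len(row_splits) - 1  # one fewer than the number of entries
num_arcs = row_splits[-1]         # the last entry is the total number of arcs
counts = arcs_per_state(row_splits)

for state in range(num_states):
    print(
        f"state {state}: first arc at position {row_splits[state]}, "
        f"{counts[state]} arc(s)"
    )
```

For a single state ``s``, both the start position ``row_splits[s]`` and the count ``row_splits[s + 1] - row_splits[s]`` are plain index lookups, which is why these questions can be answered in ``O(1)`` time.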
244+
245+
We can construct a ``RaggedShape`` from a ``row_splits`` array:
246+
247+
.. literalinclude:: code/basics/ragged_shape_1.py
248+
:language: python
249+
:lines: 3-14
250+
251+
Pay attention to the string form of the shape, ``[ [ x x ] [ x ] [ ] ]``.
252+
``x`` means we don't care about the actual content inside a ragged tensor.
253+
The above shape has 2 axes, 3 rows, and 3 elements. Row 0 has two elements as there
254+
are two ``x`` inside the 0th ``[]``. Row 1 has only one element, while
255+
row 2 has no elements at all. We can assign names to the axes. In our case,
256+
we say the shape has axes ``[state][arc]``.
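
To make the link between ``row_splits`` and the printed shape concrete, here is a small sketch that rebuilds the string form from ``row_splits`` in plain Python. ``shape_to_str`` is a hypothetical helper written for this illustration; it only imitates the printed form of a 2-axis shape and is not how ``k2`` formats shapes internally.

```python
def shape_to_str(row_splits):
    """Render a 2-axis ragged shape as text, one ``x`` per element.

    Hypothetical helper for illustration; not part of k2.
    """
    rows = []
    for start, end in zip(row_splits[:-1], row_splits[1:]):
        num_elements = end - start
        # An empty row becomes "[ ]"; a row with two elements becomes "[ x x ]".
        rows.append("[ " + "x " * num_elements + "]")
    return "[ " + " ".join(rows) + " ]"


shape_str = shape_to_str([0, 2, 3, 3])
print(shape_str)
```

Each inner pair of brackets corresponds to one state, and the number of ``x`` inside it is the number of arcs leaving that state.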
257+
258+
Combining the ragged shape and the ``values`` array, the above FSA can
259+
be represented using a ragged tensor ``[ [Arc0 Arc1] [Arc2] [ ] ]``.
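
As a rough sketch of what "combining" means here, the following plain-Python code slices a flat ``values`` list into rows using ``row_splits``. The tuples standing in for arcs and the helper name ``to_rows`` are assumptions made for illustration; ``k2`` stores arcs as C++ structs and does not build nested lists like this.

```python
# Each arc is written as (src_state, dst_state, label, score), following the
# table above. Plain tuples are an illustrative stand-in for k2's Arc struct.
values = [
    (0, 1, 1, 0.1),   # Arc0
    (0, 1, 2, 0.2),   # Arc1
    (1, 2, -1, 0.3),  # Arc2
]
row_splits = [0, 2, 3, 3]


def to_rows(row_splits, values):
    """Group a flat values list into one sub-list per row.

    Hypothetical helper, not a k2 function. Row i holds
    values[row_splits[i]:row_splits[i + 1]].
    """
    return [
        values[start:end]
        for start, end in zip(row_splits[:-1], row_splits[1:])
    ]


rows = to_rows(row_splits, values)
for state, arcs in enumerate(rows):
    print(f"state {state}: {arcs}")
```

The flat ``values`` list is never reordered; ``row_splits`` only records where each state's slice begins and ends.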
260+
261+
The following code displays the string form of the above FSA when represented
262+
as a ragged tensor in k2:
263+
264+
.. literalinclude:: code/basics/single-fsa.py
265+
:language: python
266+
:lines: 2-14
267+
172268

173269
Shape
174270
^^^^^
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,51 @@
1+
#!/usr/bin/env python3
2+
3+
import k2
4+
import torch
5+
6+
a = k2.RaggedTensor([[1, 2], [1]])
7+
b = k2.RaggedTensor([[1, 2], [1]], dtype=torch.int32)
8+
c = k2.RaggedTensor([[1, 2], [1.5]])
9+
d = k2.RaggedTensor([[1, 2], [1.5]], dtype=torch.float32)
10+
e = k2.RaggedTensor([[1, 2], [1.5]], dtype=torch.float64)
11+
f = k2.RaggedTensor([[1, 2], [1]], dtype=torch.float32, device=torch.device("cuda", 0))
12+
g = k2.RaggedTensor([[1, 2], [1]], device="cuda:0", dtype=torch.float64)
13+
print(f"a:\n{a}")
14+
print(f"b:\n{b}")
15+
print(f"c:\n{c}")
16+
print(f"d:\n{d}")
17+
print(f"e:\n{e}")
18+
print(f"f:\n{f}")
19+
print(f"g:\n{g}")
20+
print(f"g.to_str_simple():\n{g.to_str_simple()}")
21+
print(f"a.dtype: {a.dtype}, g.device: {g.device}")
22+
print(f"a.to(g.device).device: {a.to(g.device).device}")
23+
print(f"a.to(g.dtype).dtype: {a.to(g.dtype).dtype}")
24+
"""
25+
a:
26+
RaggedTensor([[1, 2],
27+
[1]], dtype=torch.int32)
28+
b:
29+
RaggedTensor([[1, 2],
30+
[1]], dtype=torch.int32)
31+
c:
32+
RaggedTensor([[1, 2],
33+
[1.5]], dtype=torch.float32)
34+
d:
35+
RaggedTensor([[1, 2],
36+
[1.5]], dtype=torch.float32)
37+
e:
38+
RaggedTensor([[1, 2],
39+
[1.5]], dtype=torch.float64)
40+
f:
41+
RaggedTensor([[1, 2],
42+
[1]], device='cuda:0', dtype=torch.float32)
43+
g:
44+
RaggedTensor([[1, 2],
45+
[1]], device='cuda:0', dtype=torch.float64)
46+
g.to_str_simple():
47+
RaggedTensor([[1, 2], [1]], device='cuda:0', dtype=torch.float64)
48+
a.dtype: torch.int32, g.device: cuda:0
49+
a.to(g.device).device: cuda:0
50+
a.to(g.dtype).dtype: torch.float64
51+
"""
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,16 @@
1+
#!/usr/bin/env python3
2+
3+
import k2
4+
import torch
5+
6+
shape = k2.ragged.create_ragged_shape2(
7+
row_splits=torch.tensor([0, 2, 3, 3], dtype=torch.int32),
8+
)
9+
print(type(shape))
10+
print(shape)
11+
"""
12+
<class '_k2.ragged.RaggedShape'>
13+
[ [ x x ] [ x ] [ ] ]
14+
"""
15+
print("num_states:", shape.dim0)
16+
print("num_arcs:", shape.numel())
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,23 @@
1+
#!/usr/bin/env python3
2+
import k2
3+
4+
s = """
5+
0 1 1 0.1
6+
0 1 2 0.2
7+
1 2 -1 0.3
8+
2
9+
"""
10+
fsa = k2.Fsa.from_str(s)
11+
print(fsa.arcs)
12+
"""
13+
[ [ 0 1 1 0.1 0 1 2 0.2 ] [ 1 2 -1 0.3 ] [ ] ]
14+
"""
15+
16+
sym_str = """
17+
a 1
18+
b 2
19+
"""
20+
21+
# fsa.labels_sym = k2.SymbolTable.from_str(sym_str)
22+
# fsa.draw("images/simple-fsa.svg")
23+
# print(k2.to_dot(fsa))

‎k2/python/csrc/torch/v2/ragged_any.cu

+1-1
Original file line number | Diff line number | Diff line change
@@ -335,7 +335,7 @@ std::string RaggedAny::ToString(bool compact /*=false*/,
335335
int32_t device_id /*=-1*/) const {
336336
ContextPtr context = any.Context();
337337
if (context->GetDeviceType() != kCpu) {
338-
return To("cpu").ToString(context->GetDeviceId());
338+
return To("cpu").ToString(compact, context->GetDeviceId());
339339
}
340340

341341
std::ostringstream os;

‎k2/python/csrc/torch/v2/ragged_shape.cu

+3-3
Original file line number | Diff line number | Diff line change
@@ -232,8 +232,8 @@ void PybindRaggedShape(py::module &m) {
232232

233233
m.def(
234234
"create_ragged_shape2",
235-
[](torch::optional<torch::Tensor> row_splits,
236-
torch::optional<torch::Tensor> row_ids,
235+
[](torch::optional<torch::Tensor> row_splits = torch::nullopt,
236+
torch::optional<torch::Tensor> row_ids = torch::nullopt,
237237
int32_t cached_tot_size = -1) -> RaggedShape {
238238
if (!row_splits.has_value() && !row_ids.has_value())
239239
K2_LOG(FATAL) << "Both row_splits and row_ids are None";
@@ -257,7 +257,7 @@ void PybindRaggedShape(py::module &m) {
257257
row_splits.has_value() ? &array_row_splits : nullptr,
258258
row_ids.has_value() ? &array_row_ids : nullptr, cached_tot_size);
259259
},
260-
py::arg("row_splits"), py::arg("row_ids"),
260+
py::arg("row_splits") = py::none(), py::arg("row_ids") = py::none(),
261261
py::arg("cached_tot_size") = -1, kCreateRaggedShape2Doc);
262262

263263
m.def("random_ragged_shape", &RandomRaggedShape, "RandomRaggedShape",
