
Commit ff1f94c

Add tests for creating ragged tensors from regular tensors.
1 parent 5fc2189 commit ff1f94c

2 files changed: +68 -25 lines

k2/python/csrc/torch/v2/ragged_any.cu (+1)

@@ -224,6 +224,7 @@ RaggedAny::RaggedAny(torch::Tensor tensor) {
   int32_t ndim = tensor.dim();
   K2_CHECK_GE(ndim, 2) << "Expect a tensor with more than 1-D";
   ContextPtr context = GetContext(tensor);
+  DeviceGuard guard(context);
   std::vector<RaggedShape> shapes;
   shapes.reserve(ndim - 1);
   int32_t dim0 = tensor.size(0);
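
Note on the C++ change: the added DeviceGuard makes the constructor switch to the input tensor's device before the per-axis RaggedShapes are built, so a tensor living on a CUDA device other than the current one is handled on its own device. A minimal sketch of the Python-level call this protects, assuming k2 is built with CUDA support and that the k2r alias used in the tests below refers to k2.ragged:

import torch
import k2.ragged as k2r  # assumed import; the test file below refers to it as k2r

if torch.cuda.is_available():
    device = torch.device("cuda", 0)
    # A regular 3-D CUDA tensor; RaggedAny::RaggedAny(torch::Tensor) runs underneath,
    # and the new DeviceGuard ensures the construction work happens on this tensor's device.
    a = torch.arange(24, dtype=torch.float32, device=device).reshape(2, 3, 4)
    b = k2r.RaggedTensor(a)
    assert b.device == device
    assert b.values.is_cuda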

k2/python/tests/ragged_tensor_test.py (+67 -25)

@@ -31,7 +31,6 @@


 class TestRaggedTensor(unittest.TestCase):
-
     @classmethod
     def setUpClass(cls):
         cls.devices = [torch.device("cpu")]
@@ -65,6 +64,46 @@ def test_create_ragged_tensor_from_string(self):
             assert b.num_axes == 3
             assert b.dim0 == 2

+    def test_create_ragged_tensor_from_torch_tensor(self):
+        for device in self.devices:
+            for func in [k2r.create_ragged_tensor, k2r.RaggedTensor]:
+                for dtype in self.dtypes:
+                    a = torch.arange(24, dtype=dtype, device=device).reshape(
+                        2, 3, 4
+                    )
+                    b = func(a)
+
+                    # a is contiguous, so memory is shared
+                    c = a.reshape(-1)
+                    c[0] = 10
+                    assert b.values[0] == 10
+                    b.values[1] = 100
+                    assert c[1] == 100
+
+                    assert b.dtype == dtype
+                    assert b.device == device
+
+                    assert torch.all(torch.eq(c, b.values))
+
+        for device in self.devices:
+            for func in [k2r.create_ragged_tensor, k2r.RaggedTensor]:
+                for dtype in self.dtypes:
+                    a = torch.arange(100, dtype=dtype, device=device).reshape(
+                        10, 10
+                    )[:, ::2]
+                    b = func(a)
+                    assert b.dtype == dtype
+                    assert b.device == device
+
+                    c = a.reshape(-1)
+                    assert torch.all(torch.eq(c, b.values))
+
+                    # a is not contiguous, so memory is copied
+                    c[0] = -10
+                    assert b.values[0] != -10
+                    b.values[1] = -100
+                    assert c[1] != -100
+
     def test_property_values(self):
         a = k2r.RaggedTensor([[1], [2], [], [3, 4]])
         assert torch.all(torch.eq(a.values, torch.tensor([1, 2, 3, 4])))
@@ -128,17 +167,17 @@ def test_sum_with_grad(self):
                 a = a.to(device)
                 a.requires_grad_(True)
                 b = a.sum()
-                expected_sum = torch.tensor([3, 0, 5],
-                                            dtype=dtype,
-                                            device=device)
+                expected_sum = torch.tensor(
+                    [3, 0, 5], dtype=dtype, device=device
+                )

                 assert torch.all(torch.eq(b, expected_sum))

                 c = b[0] * 10 + b[1] * 20 + b[2] * 30
                 c.backward()
-                expected_grad = torch.tensor([10, 10, 30],
-                                             device=device,
-                                             dtype=dtype)
+                expected_grad = torch.tensor(
+                    [10, 10, 30], device=device, dtype=dtype
+                )
                 assert torch.all(torch.eq(a.grad, expected_grad))

     def test_sum_no_grad(self):
@@ -147,26 +186,27 @@ def test_sum_no_grad(self):
                 a = k2r.RaggedTensor([[1, 2], [], [5]], dtype=dtype)
                 a = a.to(device)
                 b = a.sum()
-                expected_sum = torch.tensor([3, 0, 5],
-                                            dtype=dtype,
-                                            device=device)
+                expected_sum = torch.tensor(
+                    [3, 0, 5], dtype=dtype, device=device
+                )

                 assert torch.all(torch.eq(b, expected_sum))

     def test_getitem(self):
         for device in self.devices:
             for dtype in self.dtypes:
-                a = k2r.RaggedTensor("[ [[1 2] [] [10]] [[3] [5]] ]",
-                                     dtype=dtype)
+                a = k2r.RaggedTensor(
+                    "[ [[1 2] [] [10]] [[3] [5]] ]", dtype=dtype
+                )
                 a = a.to(device)
                 b = a[0]
-                expected = k2r.RaggedTensor("[[1 2] [] [10]]",
-                                            dtype=dtype).to(device)
+                expected = k2r.RaggedTensor("[[1 2] [] [10]]", dtype=dtype).to(
+                    device
+                )
                 assert b == expected

                 b = a[1]
-                expected = k2r.RaggedTensor("[[3] [5]]",
-                                            dtype=dtype).to(device)
+                expected = k2r.RaggedTensor("[[3] [5]]", dtype=dtype).to(device)
                 assert b == expected

     def test_getstate_2axes(self):
@@ -177,9 +217,9 @@ def test_getstate_2axes(self):
                 assert isinstance(b, tuple)
                 assert len(b) == 3
                 # b contains (row_splits, "row_ids1", values)
-                b_0 = torch.tensor([0, 2, 3, 3],
-                                   dtype=torch.int32,
-                                   device=device)
+                b_0 = torch.tensor(
+                    [0, 2, 3, 3], dtype=torch.int32, device=device
+                )
                 b_1 = "row_ids1"
                 b_2 = a.values

@@ -190,18 +230,19 @@ def test_getstate_3axes(self):
     def test_getstate_3axes(self):
         for device in self.devices:
             for dtype in self.dtypes:
-                a = k2r.RaggedTensor("[[[1 2] [3] []] [[4] [5 6]]]",
-                                     dtype=dtype).to(device)
+                a = k2r.RaggedTensor(
+                    "[[[1 2] [3] []] [[4] [5 6]]]", dtype=dtype
+                ).to(device)
                 b = a.__getstate__()
                 assert isinstance(b, tuple)
                 assert len(b) == 5
                 # b contains (row_splits1, "row_ids1", row_splits2,
                 # "row_ids2", values)
                 b_0 = torch.tensor([0, 3, 5], dtype=torch.int32, device=device)
                 b_1 = "row_ids1"
-                b_2 = torch.tensor([0, 2, 3, 3, 4, 6],
-                                   dtype=torch.int32,
-                                   device=device)  # noqa
+                b_2 = torch.tensor(
+                    [0, 2, 3, 3, 4, 6], dtype=torch.int32, device=device
+                )  # noqa
                 b_3 = "row_ids2"
                 b_4 = a.values

@@ -255,7 +296,8 @@ def test_tot_size_3axes(self):
             for dtype in self.dtypes:
                 a = k2r.RaggedTensor(
                     "[ [[1 2 3] [] [5 8]] [[] [1 5 9 10 -1] [] [] []] ]",
-                    dtype=dtype)
+                    dtype=dtype,
+                )
                 a = a.to(device)

                 assert a.tot_size(0) == 2
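
The two new test loops above pin down an aliasing detail of the new constructor path: a contiguous input tensor shares its storage with the ragged tensor's values, while a non-contiguous input (for example a strided slice) is copied first. A small stand-alone sketch of the same behaviour on CPU, using only the constructors exercised in the tests and assuming the k2r alias is k2.ragged:

import torch
import k2.ragged as k2r  # assumed to match the test file's k2r import

# Contiguous input: values is a view of the same storage.
a = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
b = k2r.RaggedTensor(a)
a.reshape(-1)[0] = 10     # reshape of a contiguous tensor is a view
assert b.values[0] == 10  # the write is visible through the ragged tensor

# Non-contiguous input: the data is copied when the ragged tensor is built.
s = torch.arange(100, dtype=torch.float32).reshape(10, 10)[:, ::2]
t = k2r.RaggedTensor(s)
s[0, 0] = -10              # later writes to the source ...
assert t.values[0] != -10  # ... do not reach the copied values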
