31
31
32
32
33
33
class TestRaggedTensor (unittest .TestCase ):
34
-
35
34
@classmethod
36
35
def setUpClass (cls ):
37
36
cls .devices = [torch .device ("cpu" )]
@@ -65,6 +64,46 @@ def test_create_ragged_tensor_from_string(self):
65
64
assert b .num_axes == 3
66
65
assert b .dim0 == 2
67
66
67
+ def test_create_ragged_tensor_from_torch_tensor (self ):
68
+ for device in self .devices :
69
+ for func in [k2r .create_ragged_tensor , k2r .RaggedTensor ]:
70
+ for dtype in self .dtypes :
71
+ a = torch .arange (24 , dtype = dtype , device = device ).reshape (
72
+ 2 , 3 , 4
73
+ )
74
+ b = func (a )
75
+
76
+ # a is contiguous, so memory is shared
77
+ c = a .reshape (- 1 )
78
+ c [0 ] = 10
79
+ assert b .values [0 ] == 10
80
+ b .values [1 ] = 100
81
+ assert c [1 ] == 100
82
+
83
+ assert b .dtype == dtype
84
+ assert b .device == device
85
+
86
+ assert torch .all (torch .eq (c , b .values ))
87
+
88
+ for device in self .devices :
89
+ for func in [k2r .create_ragged_tensor , k2r .RaggedTensor ]:
90
+ for dtype in self .dtypes :
91
+ a = torch .arange (100 , dtype = dtype , device = device ).reshape (
92
+ 10 , 10
93
+ )[:, ::2 ]
94
+ b = func (a )
95
+ assert b .dtype == dtype
96
+ assert b .device == device
97
+
98
+ c = a .reshape (- 1 )
99
+ assert torch .all (torch .eq (c , b .values ))
100
+
101
+ # a is not contiguous, so memory is copied
102
+ c [0 ] = - 10
103
+ assert b .values [0 ] != - 10
104
+ b .values [1 ] = - 100
105
+ assert c [1 ] != - 100
106
+
68
107
def test_property_values (self ):
69
108
a = k2r .RaggedTensor ([[1 ], [2 ], [], [3 , 4 ]])
70
109
assert torch .all (torch .eq (a .values , torch .tensor ([1 , 2 , 3 , 4 ])))
@@ -128,17 +167,17 @@ def test_sum_with_grad(self):
128
167
a = a .to (device )
129
168
a .requires_grad_ (True )
130
169
b = a .sum ()
131
- expected_sum = torch .tensor ([ 3 , 0 , 5 ],
132
- dtype = dtype ,
133
- device = device )
170
+ expected_sum = torch .tensor (
171
+ [ 3 , 0 , 5 ], dtype = dtype , device = device
172
+ )
134
173
135
174
assert torch .all (torch .eq (b , expected_sum ))
136
175
137
176
c = b [0 ] * 10 + b [1 ] * 20 + b [2 ] * 30
138
177
c .backward ()
139
- expected_grad = torch .tensor ([ 10 , 10 , 30 ],
140
- device = device ,
141
- dtype = dtype )
178
+ expected_grad = torch .tensor (
179
+ [ 10 , 10 , 30 ], device = device , dtype = dtype
180
+ )
142
181
assert torch .all (torch .eq (a .grad , expected_grad ))
143
182
144
183
def test_sum_no_grad (self ):
@@ -147,26 +186,27 @@ def test_sum_no_grad(self):
147
186
a = k2r .RaggedTensor ([[1 , 2 ], [], [5 ]], dtype = dtype )
148
187
a = a .to (device )
149
188
b = a .sum ()
150
- expected_sum = torch .tensor ([ 3 , 0 , 5 ],
151
- dtype = dtype ,
152
- device = device )
189
+ expected_sum = torch .tensor (
190
+ [ 3 , 0 , 5 ], dtype = dtype , device = device
191
+ )
153
192
154
193
assert torch .all (torch .eq (b , expected_sum ))
155
194
156
195
def test_getitem (self ):
157
196
for device in self .devices :
158
197
for dtype in self .dtypes :
159
- a = k2r .RaggedTensor ("[ [[1 2] [] [10]] [[3] [5]] ]" ,
160
- dtype = dtype )
198
+ a = k2r .RaggedTensor (
199
+ "[ [[1 2] [] [10]] [[3] [5]] ]" , dtype = dtype
200
+ )
161
201
a = a .to (device )
162
202
b = a [0 ]
163
- expected = k2r .RaggedTensor ("[[1 2] [] [10]]" ,
164
- dtype = dtype ).to (device )
203
+ expected = k2r .RaggedTensor ("[[1 2] [] [10]]" , dtype = dtype ).to (
204
+ device
205
+ )
165
206
assert b == expected
166
207
167
208
b = a [1 ]
168
- expected = k2r .RaggedTensor ("[[3] [5]]" ,
169
- dtype = dtype ).to (device )
209
+ expected = k2r .RaggedTensor ("[[3] [5]]" , dtype = dtype ).to (device )
170
210
assert b == expected
171
211
172
212
def test_getstate_2axes (self ):
@@ -177,9 +217,9 @@ def test_getstate_2axes(self):
177
217
assert isinstance (b , tuple )
178
218
assert len (b ) == 3
179
219
# b contains (row_splits, "row_ids1", values)
180
- b_0 = torch .tensor ([ 0 , 2 , 3 , 3 ],
181
- dtype = torch .int32 ,
182
- device = device )
220
+ b_0 = torch .tensor (
221
+ [ 0 , 2 , 3 , 3 ], dtype = torch .int32 , device = device
222
+ )
183
223
b_1 = "row_ids1"
184
224
b_2 = a .values
185
225
@@ -190,18 +230,19 @@ def test_getstate_2axes(self):
190
230
def test_getstate_3axes (self ):
191
231
for device in self .devices :
192
232
for dtype in self .dtypes :
193
- a = k2r .RaggedTensor ("[[[1 2] [3] []] [[4] [5 6]]]" ,
194
- dtype = dtype ).to (device )
233
+ a = k2r .RaggedTensor (
234
+ "[[[1 2] [3] []] [[4] [5 6]]]" , dtype = dtype
235
+ ).to (device )
195
236
b = a .__getstate__ ()
196
237
assert isinstance (b , tuple )
197
238
assert len (b ) == 5
198
239
# b contains (row_splits1, "row_ids1", row_splits2,
199
240
# "row_ids2", values)
200
241
b_0 = torch .tensor ([0 , 3 , 5 ], dtype = torch .int32 , device = device )
201
242
b_1 = "row_ids1"
202
- b_2 = torch .tensor ([ 0 , 2 , 3 , 3 , 4 , 6 ],
203
- dtype = torch .int32 ,
204
- device = device ) # noqa
243
+ b_2 = torch .tensor (
244
+ [ 0 , 2 , 3 , 3 , 4 , 6 ], dtype = torch .int32 , device = device
245
+ ) # noqa
205
246
b_3 = "row_ids2"
206
247
b_4 = a .values
207
248
@@ -255,7 +296,8 @@ def test_tot_size_3axes(self):
255
296
for dtype in self .dtypes :
256
297
a = k2r .RaggedTensor (
257
298
"[ [[1 2 3] [] [5 8]] [[] [1 5 9 10 -1] [] [] []] ]" ,
258
- dtype = dtype )
299
+ dtype = dtype ,
300
+ )
259
301
a = a .to (device )
260
302
261
303
assert a .tot_size (0 ) == 2
0 commit comments