|
25 | 25 | from tensorflow.python.ops.ragged import ragged_factory_ops
|
26 | 26 | from tensorflow.python.platform import test
|
27 | 27 | from tensorflow_text.python.ops import greedy_constrained_sequence_op as greedy_op
|
| 28 | +from tensorflow_text.python.ops import ragged_test_util |
28 | 29 |
|
29 | 30 |
|
30 | 31 | # TODO(b/122968457): Refactor this test logic.
|
31 | 32 | @test_util.run_all_in_graph_and_eager_modes
|
32 |
| -class GreedyConstrainedSequenceOpTest(test_util.TensorFlowTestCase): |
| 33 | +class GreedyConstrainedSequenceOpTest(ragged_test_util.RaggedTensorTestCase): |
33 | 34 |
|
34 | 35 | def _last_max(self, array):
|
35 | 36 | """Helper function that matches the maximum behaviour in the C++ op."""
|
@@ -134,7 +135,7 @@ def test_sequence_in_exp_space_with_start_end_states_single_batch_item(self):
|
134 | 135 | use_log_space=use_log_space,
|
135 | 136 | use_start_and_end_states=use_start_and_end_states)
|
136 | 137 | single_result = self.evaluate(single_sequence_op)
|
137 |
| - self.assertAllEqual(single_result, [sequence]) |
| 138 | + self.assertRaggedEqual(single_result, [sequence]) |
138 | 139 |
|
139 | 140 | def test_sequence_in_exp_space_with_start_end_states_multi_batch_item(self):
|
140 | 141 | use_log_space = False
|
@@ -174,8 +175,8 @@ def test_sequence_in_exp_space_with_start_end_states_multi_batch_item(self):
|
174 | 175 | use_log_space=use_log_space,
|
175 | 176 | use_start_and_end_states=use_start_and_end_states)
|
176 | 177 | multiple_sequence_result = self.evaluate(multiple_sequence_op)
|
177 |
| - self.assertAllEqual(multiple_sequence_result, |
178 |
| - [sequence, sequence, sequence]) |
| 178 | + self.assertRaggedEqual(multiple_sequence_result, |
| 179 | + [sequence, sequence, sequence]) |
179 | 180 |
|
180 | 181 | def test_sequence_in_exp_space_without_start_end_states_single_batch_item(
|
181 | 182 | self):
|
@@ -213,7 +214,7 @@ def test_sequence_in_exp_space_without_start_end_states_single_batch_item(
|
213 | 214 | use_log_space=use_log_space,
|
214 | 215 | use_start_and_end_states=use_start_and_end_states)
|
215 | 216 | single_result = self.evaluate(single_sequence_op)
|
216 |
| - self.assertAllEqual(single_result, [sequence]) |
| 217 | + self.assertRaggedEqual(single_result, [sequence]) |
217 | 218 |
|
218 | 219 | def test_sequence_in_exp_space_without_start_end_states_multi_batch_item(
|
219 | 220 | self):
|
@@ -252,8 +253,8 @@ def test_sequence_in_exp_space_without_start_end_states_multi_batch_item(
|
252 | 253 | use_log_space=use_log_space,
|
253 | 254 | use_start_and_end_states=use_start_and_end_states)
|
254 | 255 | multiple_sequence_result = self.evaluate(multiple_sequence_op)
|
255 |
| - self.assertAllEqual(multiple_sequence_result, |
256 |
| - [sequence, sequence, sequence]) |
| 256 | + self.assertRaggedEqual(multiple_sequence_result, |
| 257 | + [sequence, sequence, sequence]) |
257 | 258 |
|
258 | 259 | def test_sequence_in_log_space_with_start_end_states_single_batch_item(self):
|
259 | 260 | use_log_space = True
|
@@ -293,7 +294,7 @@ def test_sequence_in_log_space_with_start_end_states_single_batch_item(self):
|
293 | 294 | use_log_space=use_log_space,
|
294 | 295 | use_start_and_end_states=use_start_and_end_states)
|
295 | 296 | single_result = self.evaluate(single_sequence_op)
|
296 |
| - self.assertAllEqual(single_result, [sequence]) |
| 297 | + self.assertRaggedEqual(single_result, [sequence]) |
297 | 298 |
|
298 | 299 | def test_sequence_in_log_space_with_start_end_states_multi_batch_item(self):
|
299 | 300 | use_log_space = True
|
@@ -334,8 +335,8 @@ def test_sequence_in_log_space_with_start_end_states_multi_batch_item(self):
|
334 | 335 | use_log_space=use_log_space,
|
335 | 336 | use_start_and_end_states=use_start_and_end_states)
|
336 | 337 | multiple_sequence_result = self.evaluate(multiple_sequence_op)
|
337 |
| - self.assertAllEqual(multiple_sequence_result, |
338 |
| - [sequence, sequence, sequence]) |
| 338 | + self.assertRaggedEqual(multiple_sequence_result, |
| 339 | + [sequence, sequence, sequence]) |
339 | 340 |
|
340 | 341 | def test_sequence_in_log_space_without_start_end_states_single_batch_item(
|
341 | 342 | self):
|
@@ -373,7 +374,7 @@ def test_sequence_in_log_space_without_start_end_states_single_batch_item(
|
373 | 374 | use_log_space=use_log_space,
|
374 | 375 | use_start_and_end_states=use_start_and_end_states)
|
375 | 376 | single_result = self.evaluate(single_sequence_op)
|
376 |
| - self.assertAllEqual(single_result, [sequence]) |
| 377 | + self.assertRaggedEqual(single_result, [sequence]) |
377 | 378 |
|
378 | 379 | def test_sequence_in_log_space_without_start_end_states_multi_batch_item(
|
379 | 380 | self):
|
@@ -412,8 +413,8 @@ def test_sequence_in_log_space_without_start_end_states_multi_batch_item(
|
412 | 413 | use_log_space=use_log_space,
|
413 | 414 | use_start_and_end_states=use_start_and_end_states)
|
414 | 415 | multiple_sequence_result = self.evaluate(multiple_sequence_op)
|
415 |
| - self.assertAllEqual(multiple_sequence_result, |
416 |
| - [sequence, sequence, sequence]) |
| 416 | + self.assertRaggedEqual(multiple_sequence_result, |
| 417 | + [sequence, sequence, sequence]) |
417 | 418 |
|
418 | 419 | def test_sequence_with_none_weights_single_batch_item(self):
|
419 | 420 | use_log_space = True
|
@@ -450,7 +451,7 @@ def test_sequence_with_none_weights_single_batch_item(self):
|
450 | 451 | use_log_space=use_log_space,
|
451 | 452 | use_start_and_end_states=use_start_and_end_states)
|
452 | 453 | single_result = self.evaluate(single_sequence_op)
|
453 |
| - self.assertAllEqual(single_result, [sequence]) |
| 454 | + self.assertRaggedEqual(single_result, [sequence]) |
454 | 455 |
|
455 | 456 | def test_sequence_with_none_weights_multi_batch_item(self):
|
456 | 457 | use_log_space = True
|
@@ -488,8 +489,8 @@ def test_sequence_with_none_weights_multi_batch_item(self):
|
488 | 489 | use_log_space=use_log_space,
|
489 | 490 | use_start_and_end_states=use_start_and_end_states)
|
490 | 491 | multiple_sequence_result = self.evaluate(multiple_sequence_op)
|
491 |
| - self.assertAllEqual(multiple_sequence_result, |
492 |
| - [sequence, sequence, sequence]) |
| 492 | + self.assertRaggedEqual(multiple_sequence_result, |
| 493 | + [sequence, sequence, sequence]) |
493 | 494 |
|
494 | 495 | def test_sequence_with_none_permissions_single_batch_item(self):
|
495 | 496 | use_log_space = True
|
@@ -520,7 +521,7 @@ def test_sequence_with_none_permissions_single_batch_item(self):
|
520 | 521 | use_log_space=use_log_space,
|
521 | 522 | use_start_and_end_states=use_start_and_end_states)
|
522 | 523 | single_result = self.evaluate(single_sequence_op)
|
523 |
| - self.assertAllEqual(single_result, [sequence]) |
| 524 | + self.assertRaggedEqual(single_result, [sequence]) |
524 | 525 |
|
525 | 526 | def test_sequence_with_none_permissions_multi_input(self):
|
526 | 527 | use_log_space = True
|
@@ -552,8 +553,8 @@ def test_sequence_with_none_permissions_multi_input(self):
|
552 | 553 | use_log_space=use_log_space,
|
553 | 554 | use_start_and_end_states=use_start_and_end_states)
|
554 | 555 | multiple_sequence_result = self.evaluate(multiple_sequence_op)
|
555 |
| - self.assertAllEqual(multiple_sequence_result, |
556 |
| - [sequence, sequence, sequence]) |
| 556 | + self.assertRaggedEqual(multiple_sequence_result, |
| 557 | + [sequence, sequence, sequence]) |
557 | 558 |
|
558 | 559 | def test_sequence_with_implicit_sequence_lengths(self):
|
559 | 560 | use_log_space = True
|
@@ -591,8 +592,8 @@ def test_sequence_with_implicit_sequence_lengths(self):
|
591 | 592 | use_log_space=use_log_space,
|
592 | 593 | use_start_and_end_states=use_start_and_end_states)
|
593 | 594 | multiple_sequence_result = self.evaluate(multiple_sequence_op)
|
594 |
| - self.assertAllEqual(multiple_sequence_result, |
595 |
| - [sequence, sequence, sequence]) |
| 595 | + self.assertRaggedEqual(multiple_sequence_result, |
| 596 | + [sequence, sequence, sequence]) |
596 | 597 |
|
597 | 598 | def test_ragged_inputs(self):
|
598 | 599 | use_log_space = True
|
@@ -640,7 +641,7 @@ def test_ragged_inputs(self):
|
640 | 641 | use_log_space=use_log_space,
|
641 | 642 | use_start_and_end_states=use_start_and_end_states)
|
642 | 643 | ragged_result = self.evaluate(ragged_op)
|
643 |
| - self.assertAllEqual(ragged_result, expected_sequence) |
| 644 | + self.assertRaggedEqual(ragged_result, expected_sequence) |
644 | 645 |
|
645 | 646 |
|
646 | 647 | if __name__ == "__main__":
|
|
0 commit comments