Skip to content

Commit fa898b7

Browse files
committed
Revert "Remove ragged_test_util now that the additional functionality it provided has been added & launched in core."
This reverts commit 760b97f.
1 parent 706b36e commit fa898b7

14 files changed

+395
-235
lines changed

tensorflow_text/BUILD

+26-2
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ py_test(
197197
srcs_version = "PY2AND3",
198198
deps = [
199199
":greedy_constrained_sequence_op",
200+
":ragged_test_util",
200201
# numpy dep,
201202
# python:client_testlib tensorflow dep,
202203
# python:framework_test_lib tensorflow dep,
@@ -288,6 +289,7 @@ py_test(
288289
srcs_version = "PY2AND3",
289290
deps = [
290291
":ngrams_op",
292+
":ragged_test_util",
291293
# python:client_testlib tensorflow dep,
292294
# python:constant_op tensorflow dep,
293295
# python:errors tensorflow dep,
@@ -345,6 +347,7 @@ py_test(
345347
srcs_version = "PY2AND3",
346348
deps = [
347349
":pad_along_dimension_op",
350+
":ragged_test_util",
348351
"@absl_py//absl/testing:parameterized",
349352
# python:array_ops tensorflow dep,
350353
# python:client_testlib tensorflow dep,
@@ -397,6 +400,7 @@ py_test(
397400
srcs_version = "PY2AND3",
398401
deps = [
399402
":pointer_ops",
403+
":ragged_test_util",
400404
"@absl_py//absl/testing:parameterized",
401405
# python:client_testlib tensorflow dep,
402406
# python:framework_test_lib tensorflow dep,
@@ -410,6 +414,7 @@ py_test(
410414
srcs_version = "PY2AND3",
411415
deps = [
412416
":pointer_ops",
417+
":ragged_test_util",
413418
"@absl_py//absl/testing:parameterized",
414419
# python:array_ops tensorflow dep,
415420
# python:client_testlib tensorflow dep,
@@ -422,6 +427,19 @@ py_test(
422427
],
423428
)
424429

430+
py_library(
431+
name = "ragged_test_util",
432+
srcs = ["python/ops/ragged_test_util.py"],
433+
srcs_version = "PY2AND3",
434+
deps = [
435+
# numpy dep,
436+
# python:framework_ops tensorflow dep,
437+
# python:framework_test_lib tensorflow dep,
438+
# python/ops/ragged:ragged_tensor tensorflow dep,
439+
# python/ops/ragged:ragged_tensor_value tensorflow dep,
440+
],
441+
)
442+
425443
py_tf_text_library(
426444
name = "regex_split_ops",
427445
srcs = ["python/ops/regex_split_ops.py"],
@@ -465,12 +483,12 @@ py_test(
465483
srcs = ["python/metrics/text_similarity_metric_ops_test.py"],
466484
srcs_version = "PY2AND3",
467485
deps = [
486+
":ragged_test_util",
468487
":text_similarity_metric_ops",
469488
"@absl_py//absl/testing:parameterized",
470489
# python:array_ops tensorflow dep,
471490
# python:client_testlib tensorflow dep,
472491
# python:dtypes tensorflow dep,
473-
# python:framework_test_lib tensorflow dep,
474492
# python:lookup_ops tensorflow dep,
475493
# python:math_ops tensorflow dep,
476494
# python/ops/ragged:ragged_factory_ops tensorflow dep,
@@ -554,6 +572,7 @@ py_test(
554572
],
555573
srcs_version = "PY2AND3",
556574
deps = [
575+
":ragged_test_util",
557576
":sentencepiece_tokenizer",
558577
"@absl_py//absl/testing:parameterized",
559578
# python:client_testlib tensorflow dep,
@@ -581,6 +600,7 @@ py_test(
581600
srcs = ["python/ops/sliding_window_op_test.py"],
582601
srcs_version = "PY2AND3",
583602
deps = [
603+
":ragged_test_util",
584604
":sliding_window_op",
585605
"@absl_py//absl/testing:parameterized",
586606
# python:array_ops tensorflow dep,
@@ -672,6 +692,7 @@ py_test(
672692
shard_count = 5,
673693
srcs_version = "PY2AND3",
674694
deps = [
695+
":ragged_test_util",
675696
":unicode_char_tokenizer",
676697
# python:client_testlib tensorflow dep,
677698
# python:constant_op tensorflow dep,
@@ -707,6 +728,7 @@ py_test(
707728
shard_count = 5,
708729
srcs_version = "PY2AND3",
709730
deps = [
731+
":ragged_test_util",
710732
":unicode_script_tokenizer",
711733
# python:client_testlib tensorflow dep,
712734
# python:constant_op tensorflow dep,
@@ -735,6 +757,7 @@ py_test(
735757
srcs = ["python/ops/viterbi_constrained_sequence_op_test.py"],
736758
srcs_version = "PY2AND3",
737759
deps = [
760+
":ragged_test_util",
738761
":viterbi_constrained_sequence_op",
739762
":viterbi_decode",
740763
# numpy dep,
@@ -791,6 +814,7 @@ py_test(
791814
shard_count = 4,
792815
srcs_version = "PY2AND3",
793816
deps = [
817+
":ragged_test_util",
794818
":whitespace_tokenizer",
795819
# python:client_testlib tensorflow dep,
796820
# python:constant_op tensorflow dep,
@@ -852,12 +876,12 @@ py_test(
852876
srcs = ["python/ops/wordpiece_tokenizer_test.py"],
853877
srcs_version = "PY2AND3",
854878
deps = [
879+
":ragged_test_util",
855880
":wordpiece_tokenizer",
856881
"@absl_py//absl/testing:parameterized",
857882
# python:array_ops tensorflow dep,
858883
# python:client_testlib tensorflow dep,
859884
# python:dtypes tensorflow dep,
860-
# python:framework_test_lib tensorflow dep,
861885
# python:lookup_ops tensorflow dep,
862886
# python:math_ops tensorflow dep,
863887
# python/compat tensorflow dep,

tensorflow_text/python/metrics/text_similarity_metric_ops_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323
import re
2424
from absl.testing import parameterized
2525
from tensorflow.python.framework import dtypes
26-
from tensorflow.python.framework import test_util
2726
from tensorflow.python.ops.ragged import ragged_factory_ops
2827
from tensorflow.python.platform import test
2928
from tensorflow_text.python.metrics import text_similarity_metric_ops
29+
from tensorflow_text.python.ops import ragged_test_util
3030

3131

3232
def _tokenize_whitespace(text):
@@ -85,7 +85,7 @@ def _tokenize_155_compat(text):
8585
"wheel brake upon landing")
8686

8787

88-
class TextSimilarityMetricOpsTest(test_util.TensorFlowTestCase,
88+
class TextSimilarityMetricOpsTest(ragged_test_util.RaggedTensorTestCase,
8989
parameterized.TestCase):
9090

9191
@parameterized.parameters([

tensorflow_text/python/ops/greedy_constrained_sequence_op_test.py

+23-22
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,12 @@
2525
from tensorflow.python.ops.ragged import ragged_factory_ops
2626
from tensorflow.python.platform import test
2727
from tensorflow_text.python.ops import greedy_constrained_sequence_op as greedy_op
28+
from tensorflow_text.python.ops import ragged_test_util
2829

2930

3031
# TODO(b/122968457): Refactor this test logic.
3132
@test_util.run_all_in_graph_and_eager_modes
32-
class GreedyConstrainedSequenceOpTest(test_util.TensorFlowTestCase):
33+
class GreedyConstrainedSequenceOpTest(ragged_test_util.RaggedTensorTestCase):
3334

3435
def _last_max(self, array):
3536
"""Helper function that matches the maximum behaviour in the C++ op."""
@@ -134,7 +135,7 @@ def test_sequence_in_exp_space_with_start_end_states_single_batch_item(self):
134135
use_log_space=use_log_space,
135136
use_start_and_end_states=use_start_and_end_states)
136137
single_result = self.evaluate(single_sequence_op)
137-
self.assertAllEqual(single_result, [sequence])
138+
self.assertRaggedEqual(single_result, [sequence])
138139

139140
def test_sequence_in_exp_space_with_start_end_states_multi_batch_item(self):
140141
use_log_space = False
@@ -174,8 +175,8 @@ def test_sequence_in_exp_space_with_start_end_states_multi_batch_item(self):
174175
use_log_space=use_log_space,
175176
use_start_and_end_states=use_start_and_end_states)
176177
multiple_sequence_result = self.evaluate(multiple_sequence_op)
177-
self.assertAllEqual(multiple_sequence_result,
178-
[sequence, sequence, sequence])
178+
self.assertRaggedEqual(multiple_sequence_result,
179+
[sequence, sequence, sequence])
179180

180181
def test_sequence_in_exp_space_without_start_end_states_single_batch_item(
181182
self):
@@ -213,7 +214,7 @@ def test_sequence_in_exp_space_without_start_end_states_single_batch_item(
213214
use_log_space=use_log_space,
214215
use_start_and_end_states=use_start_and_end_states)
215216
single_result = self.evaluate(single_sequence_op)
216-
self.assertAllEqual(single_result, [sequence])
217+
self.assertRaggedEqual(single_result, [sequence])
217218

218219
def test_sequence_in_exp_space_without_start_end_states_multi_batch_item(
219220
self):
@@ -252,8 +253,8 @@ def test_sequence_in_exp_space_without_start_end_states_multi_batch_item(
252253
use_log_space=use_log_space,
253254
use_start_and_end_states=use_start_and_end_states)
254255
multiple_sequence_result = self.evaluate(multiple_sequence_op)
255-
self.assertAllEqual(multiple_sequence_result,
256-
[sequence, sequence, sequence])
256+
self.assertRaggedEqual(multiple_sequence_result,
257+
[sequence, sequence, sequence])
257258

258259
def test_sequence_in_log_space_with_start_end_states_single_batch_item(self):
259260
use_log_space = True
@@ -293,7 +294,7 @@ def test_sequence_in_log_space_with_start_end_states_single_batch_item(self):
293294
use_log_space=use_log_space,
294295
use_start_and_end_states=use_start_and_end_states)
295296
single_result = self.evaluate(single_sequence_op)
296-
self.assertAllEqual(single_result, [sequence])
297+
self.assertRaggedEqual(single_result, [sequence])
297298

298299
def test_sequence_in_log_space_with_start_end_states_multi_batch_item(self):
299300
use_log_space = True
@@ -334,8 +335,8 @@ def test_sequence_in_log_space_with_start_end_states_multi_batch_item(self):
334335
use_log_space=use_log_space,
335336
use_start_and_end_states=use_start_and_end_states)
336337
multiple_sequence_result = self.evaluate(multiple_sequence_op)
337-
self.assertAllEqual(multiple_sequence_result,
338-
[sequence, sequence, sequence])
338+
self.assertRaggedEqual(multiple_sequence_result,
339+
[sequence, sequence, sequence])
339340

340341
def test_sequence_in_log_space_without_start_end_states_single_batch_item(
341342
self):
@@ -373,7 +374,7 @@ def test_sequence_in_log_space_without_start_end_states_single_batch_item(
373374
use_log_space=use_log_space,
374375
use_start_and_end_states=use_start_and_end_states)
375376
single_result = self.evaluate(single_sequence_op)
376-
self.assertAllEqual(single_result, [sequence])
377+
self.assertRaggedEqual(single_result, [sequence])
377378

378379
def test_sequence_in_log_space_without_start_end_states_multi_batch_item(
379380
self):
@@ -412,8 +413,8 @@ def test_sequence_in_log_space_without_start_end_states_multi_batch_item(
412413
use_log_space=use_log_space,
413414
use_start_and_end_states=use_start_and_end_states)
414415
multiple_sequence_result = self.evaluate(multiple_sequence_op)
415-
self.assertAllEqual(multiple_sequence_result,
416-
[sequence, sequence, sequence])
416+
self.assertRaggedEqual(multiple_sequence_result,
417+
[sequence, sequence, sequence])
417418

418419
def test_sequence_with_none_weights_single_batch_item(self):
419420
use_log_space = True
@@ -450,7 +451,7 @@ def test_sequence_with_none_weights_single_batch_item(self):
450451
use_log_space=use_log_space,
451452
use_start_and_end_states=use_start_and_end_states)
452453
single_result = self.evaluate(single_sequence_op)
453-
self.assertAllEqual(single_result, [sequence])
454+
self.assertRaggedEqual(single_result, [sequence])
454455

455456
def test_sequence_with_none_weights_multi_batch_item(self):
456457
use_log_space = True
@@ -488,8 +489,8 @@ def test_sequence_with_none_weights_multi_batch_item(self):
488489
use_log_space=use_log_space,
489490
use_start_and_end_states=use_start_and_end_states)
490491
multiple_sequence_result = self.evaluate(multiple_sequence_op)
491-
self.assertAllEqual(multiple_sequence_result,
492-
[sequence, sequence, sequence])
492+
self.assertRaggedEqual(multiple_sequence_result,
493+
[sequence, sequence, sequence])
493494

494495
def test_sequence_with_none_permissions_single_batch_item(self):
495496
use_log_space = True
@@ -520,7 +521,7 @@ def test_sequence_with_none_permissions_single_batch_item(self):
520521
use_log_space=use_log_space,
521522
use_start_and_end_states=use_start_and_end_states)
522523
single_result = self.evaluate(single_sequence_op)
523-
self.assertAllEqual(single_result, [sequence])
524+
self.assertRaggedEqual(single_result, [sequence])
524525

525526
def test_sequence_with_none_permissions_multi_input(self):
526527
use_log_space = True
@@ -552,8 +553,8 @@ def test_sequence_with_none_permissions_multi_input(self):
552553
use_log_space=use_log_space,
553554
use_start_and_end_states=use_start_and_end_states)
554555
multiple_sequence_result = self.evaluate(multiple_sequence_op)
555-
self.assertAllEqual(multiple_sequence_result,
556-
[sequence, sequence, sequence])
556+
self.assertRaggedEqual(multiple_sequence_result,
557+
[sequence, sequence, sequence])
557558

558559
def test_sequence_with_implicit_sequence_lengths(self):
559560
use_log_space = True
@@ -591,8 +592,8 @@ def test_sequence_with_implicit_sequence_lengths(self):
591592
use_log_space=use_log_space,
592593
use_start_and_end_states=use_start_and_end_states)
593594
multiple_sequence_result = self.evaluate(multiple_sequence_op)
594-
self.assertAllEqual(multiple_sequence_result,
595-
[sequence, sequence, sequence])
595+
self.assertRaggedEqual(multiple_sequence_result,
596+
[sequence, sequence, sequence])
596597

597598
def test_ragged_inputs(self):
598599
use_log_space = True
@@ -640,7 +641,7 @@ def test_ragged_inputs(self):
640641
use_log_space=use_log_space,
641642
use_start_and_end_states=use_start_and_end_states)
642643
ragged_result = self.evaluate(ragged_op)
643-
self.assertAllEqual(ragged_result, expected_sequence)
644+
self.assertRaggedEqual(ragged_result, expected_sequence)
644645

645646

646647
if __name__ == "__main__":

0 commit comments

Comments
 (0)