Add text tokenization preprocessing op.
PiperOrigin-RevId: 583802279
jpuigcerver authored and copybara-github committed Nov 19, 2023
1 parent f8b56f6 commit 702c66e
Showing 3 changed files with 73 additions and 2 deletions.
5 changes: 4 additions & 1 deletion vmoe/data/input_pipeline.py
@@ -61,6 +61,7 @@ def get_dataset(
    process: str,
    cache: Optional[str] = None,
    num_parallel_calls: int = 128,
    pre_filter_fn: Optional[Callable[[Data], bool]] = None,
    prefetch: Optional[Union[int, str]] = None,
    shuffle_buffer: int = DEFAULT_SHUFFLE_BUFFER,
    shuffle_seed: Optional[int] = None,
@@ -78,6 +79,7 @@ def get_dataset(
    cache: If 'loaded' caches the dataset after loading it. If 'batched',
      caches it after batching. If `None`, no caching is done.
    num_parallel_calls: Process this number of examples in parallel.
    pre_filter_fn: If given, filters the dataset according to this function.
    prefetch: If given, prefetches this number of batches.
    shuffle_buffer: Size of the shuffle buffer. Only used for training.
    shuffle_seed: Optional seed for shuffling files and examples.
@@ -104,6 +106,7 @@ def get_dataset(
        f'and {jax.device_count()} respectively.')
  batch_size_per_process = batch_size // jax.process_count()
  data = builder.as_dataset()
  data = data.filter(pre_filter_fn) if pre_filter_fn is not None else data
  # Optionally, cache loaded data.
  if cache == 'loaded':
    data = data.cache()
@@ -154,7 +157,7 @@ def get_data_num_examples(config: ml_collections.ConfigDict) -> int:
  # These are kwarg keys used when creating the pipeline, not the builder.
  pipeline_keys = ('variant', 'batch_size', 'process', 'cache',
                   'num_parallel_calls', 'prefetch', 'prefetch_device',
                   'shuffle_buffer')
                   'shuffle_buffer', 'pre_filter_fn')
  builder_kwargs = {k: v for k, v in config.items() if k not in pipeline_keys}
  builder = vmoe.data.builder.get_dataset_builder(**builder_kwargs)
  return builder.num_examples
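A brief usage note (not part of the diff above): pre_filter_fn is a predicate over a single un-batched example dict, exactly as in tf.data.Dataset.filter, and it is applied right after the builder loads the data, before the 'loaded' cache and the rest of the pipeline; get_data_num_examples now lists it among the pipeline-only kwargs, so it is never forwarded to the dataset builder. A minimal sketch of such a predicate, using a hypothetical 'text' field:

import tensorflow as tf

# Keep only examples whose 'text' field is non-empty. The field name is
# illustrative; any predicate over an example dict returning a boolean works.
def keep_nonempty_text(example):
  return tf.strings.length(example['text']) > 0

# Inside get_dataset this amounts to: data = data.filter(keep_nonempty_text).
# Stand-alone demonstration of the same predicate on a toy tf.data pipeline:
demo = tf.data.Dataset.from_tensor_slices(
    {'text': tf.constant(['', 'a photo of a cat'])})
assert len(list(demo.filter(keep_nonempty_text))) == 1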
51 changes: 50 additions & 1 deletion vmoe/data/pp_ops.py
@@ -24,6 +24,11 @@
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2

try:
  from big_vision.pp import ops_text as bv_ops_text  # pylint: disable=g-import-not-at-top
except ImportError:
  bv_ops_text = None

try:
  from cloud_tpu.models.efficientnet import autoaugment  # pylint: disable=g-import-not-at-top
except ImportError:
@@ -258,7 +263,7 @@ def randaug(num_layers: int = 2, magnitude: int = 10):
    a function that applies RandAugment.
  """
  if autoaugment is None:
    raise ValueError(
    raise NotImplementedError(
        "In order to use RandAugment you need to install the 'cloud_tpu' "
        "package. Clone the https://github.com/tensorflow/tpu repository, "
        "name it 'cloud_tpu', and add the corresponding directory to your "
@@ -350,6 +355,50 @@ def _reshape(image):
  return _reshape


@InKeyOutKey(indefault='text', outdefault='text')
def tokenize(
    max_len,
    eos,
    model='c4_en',
    lower=True,
    sample_if_multi=True,
    pad_value='<pad>',
):
  """Tokenizes text using big_vision.pp.ops_text.tokenize."""
  if bv_ops_text is None:
    raise NotImplementedError(
        "In order to tokenize text you must install the Big Vision package. "
        "Clone the https://github.com/google-research/big_vision repository, "
        "and add the 'big_vision' directory to your PYTHONPATH.")

  if eos not in ('yes', 'none', 'sticky'):
    raise ValueError(f"Invalid value for eos: '{eos}'.")

  tokenizer = bv_ops_text.create_tokenizer(model, add_eos=eos != 'none')

  if isinstance(pad_value, str):
    pad_value = tokenizer.string_to_id(pad_value)

  def _pp_tokenize(txt):
    if sample_if_multi:
      txt = bv_ops_text.ops_general.get_choice(
          empty_fallback='', key='t')({'t': txt})['t']

    if lower:
      txt = tf.strings.lower(txt) if sample_if_multi else tf.map_fn(
          tf.strings.lower, txt)

    return bv_ops_text.tokenize(
        txt,
        tokenizer,
        max_len,
        pad_value=pad_value,
        force_eos=eos == 'sticky',
        multi_text=not sample_if_multi)

  return _pp_tokenize


@InKeyOutKey()
def value_range(vmin, vmax, in_min=0, in_max=255.0, clip_values=False):
  """Transforms a [in_min,in_max] image to [vmin,vmax] range.
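As a usage sketch (not part of the diff above): like the other ops in this module, calling tokenize(...) returns a function over a features dict, and the InKeyOutKey decorator makes it read and write the 'text' key by default. Per the code above, eos='none' disables add_eos, 'yes' enables it, and 'sticky' additionally passes force_eos=True. The values below are illustrative and assume Big Vision is on PYTHONPATH so that bv_ops_text is not None:

import tensorflow as tf
from vmoe.data import pp_ops

# Pick one of the candidate captions (sample_if_multi=True), lowercase it, and
# tokenize it with the 'c4_en' vocabulary, padding/truncating to 16 ids.
op = pp_ops.tokenize(max_len=16, eos='yes', model='c4_en', lower=True,
                     sample_if_multi=True)
features = op({'text': tf.constant(['A photo of a cat', 'A photo of a dog'])})
# features['text'] is a token-id tensor of shape [16], as exercised in the new test.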
19 changes: 19 additions & 0 deletions vmoe/data/pp_ops_test.py
@@ -21,6 +21,7 @@
from vmoe.data import pp_ops



class PreprocessOpsTest(tf.test.TestCase):

  def get_data(self, dtype=tf.uint8):
@@ -126,5 +127,23 @@ def test_reshape(self):
    output_data = pp_ops.reshape(new_shape=(8, 32, 32, 3))(data)
    self.assertAllEqual(output_data['image'].shape, [8, 32, 32, 3])

  def test_tokenize(self):
    if pp_ops.bv_ops_text is None:
      self.skipTest('Big Vision is not installed.')
    model = 'c4_en'  # pylint: disable=unused-variable
    max_len = 5
    data = {'text': tf.constant(['FOO', 'BAR'], dtype=tf.string)}
    output_data = pp_ops.tokenize(
        max_len=max_len, eos='yes', model=model, sample_if_multi=True)(data)
    self.assertEqual(output_data['text'].shape, [max_len])

  def test_tokenize_raises(self):
    original_bv_ops_text = pp_ops.bv_ops_text
    pp_ops.bv_ops_text = None
    with self.assertRaisesRegex(NotImplementedError, 'you must install'):
      data = {'text': tf.constant(['FOO', 'BAR'], dtype=tf.string)}
      pp_ops.tokenize(max_len=5, eos='yes')(data)
    pp_ops.bv_ops_text = original_bv_ops_text


if __name__ == '__main__':
  tf.test.main()
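One design note on test_tokenize_raises: it nulls pp_ops.bv_ops_text by hand and restores it after the assertion, so the restore is skipped if the block raises something other than the expected NotImplementedError (or raises nothing at all). A hypothetical equivalent using unittest.mock restores the attribute automatically when the context exits:

from unittest import mock

class PreprocessOpsTest(tf.test.TestCase):  # same test class as above

  def test_tokenize_raises(self):
    # Hypothetical rewrite: mock.patch.object restores pp_ops.bv_ops_text on
    # exit, even if an unexpected exception escapes the assertRaises block.
    with mock.patch.object(pp_ops, 'bv_ops_text', None):
      with self.assertRaisesRegex(NotImplementedError, 'you must install'):
        pp_ops.tokenize(max_len=5, eos='yes')({'text': tf.constant(['FOO'])})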
