Commit ee2015a

Standalone functions for generate pre/post processing for OPT (#1015)
* Standalone functions for generate pre/post processing for OPT
* A few more doc updates
* Fix
1 parent f1ef72c commit ee2015a

6 files changed: +210 −129 lines changed

keras_nlp/models/opt/opt_causal_lm.py

Lines changed: 119 additions & 74 deletions
@@ -29,7 +29,6 @@
 from keras_nlp.utils.keras_utils import is_xla_compatible
 from keras_nlp.utils.python_utils import classproperty
 from keras_nlp.utils.tf_utils import tensor_to_string_list
-from keras_nlp.utils.tf_utils import truncate_at_token


 @keras_nlp_export("keras_nlp.models.OPTCausalLM")
@@ -49,7 +48,7 @@ class OPTCausalLM(Task):
     default, `"top_k"` sampling will be used.

     This model can optionally be configured with a `preprocessor` layer, in
-    which case it will automatically apply preprocessing to raw inputs during
+    which case it will automatically apply preprocessing to string inputs during
     `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
     when creating the model with `from_preset()`.

@@ -301,28 +300,23 @@ def make_generate_function(self):

     def generate_step(
         self,
-        token_ids,
-        padding_mask,
+        inputs,
         end_token_id=None,
     ):
         """A compilable generation function for a single batch of inputs.

         This function represents the inner, XLA-compilable, generation function
-        for a single batch of inputs. It takes in a dense `tf.Tensor` of token
-        ids, and return a dense `tf.Tensor` of token ids, and includes no
-        preprocessing. This function is wrapped by the `generate()` method.
+        for a single batch of inputs. Inputs should have the same structure as
+        model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.

         Args:
-            token_ids: A dense int Tensor, with shape
-                `(batch_size, max_length)`. The user provided token ids
-                padded to `max_length`.
-            padding_mask: A dense boolean Tensor, with the same shape as
-                `token_ids`. Positions that are True in the `padding_mask`
-                are assumed to be user input and never updated.
+            inputs: A dictionary with two keys, `"token_ids"` and
+                `"padding_mask"`, and batched tensor values.
             end_token_id: The id of the end token to stop on. If all
                 sequences have produced a new `end_token_id`, generation
                 will stop.
         """
+        token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
         # Create and seed cache with a single forward pass.
         hidden_states, cache = self._build_cache(token_ids)
         # Compute the lengths of all user inputted token ids.
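With this change, `generate_step` takes the same dictionary structure as regular model inputs. A minimal sketch of the new calling convention, invoked eagerly for illustration (the token id values here are made up, and `generate()` normally calls this through `make_generate_function()` rather than directly):

causal_lm = keras_nlp.models.OPTCausalLM.from_preset("opt_125m_en")
prompt = {
    "token_ids": tf.constant([[2, 133, 2119, 6219, 23602, 1, 1, 1]]),
    "padding_mask": tf.constant([[1, 1, 1, 1, 1, 0, 0, 0]], dtype="bool"),
}
output = causal_lm.generate_step(
    prompt, end_token_id=causal_lm.preprocessor.tokenizer.end_token_id
)
# `output` is again a dict with dense `"token_ids"` and `"padding_mask"`.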
@@ -347,7 +341,7 @@ def next(prompt, cache, index):
                 cache,
             )

-        return self._sampler(
+        token_ids = self._sampler(
             next=next,
             prompt=token_ids,
             cache=cache,
@@ -357,6 +351,78 @@ def next(prompt, cache, index):
             hidden_states=hidden_states,
         )

+        # Compute an output padding mask with the token ids we updated.
+        if end_token_id is not None:
+            # Build a mask of `end_token_id` locations not in the original
+            # prompt (not in locations where `padding_mask` is True).
+            end_locations = (token_ids == end_token_id) & (~padding_mask)
+            end_locations = tf.cast(end_locations, tf.int32)
+            # Use cumsum to get ones in all locations after end_locations.
+            overflow = tf.math.cumsum(end_locations, exclusive=True, axis=-1)
+            # Our padding mask is the inverse of these overflow locations.
+            padding_mask = ~tf.cast(overflow, tf.bool)
+        else:
+            # Without early stopping, all locations will have been updated.
+            padding_mask = tf.ones_like(token_ids, dtype=tf.bool)
+        return {
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+        }
+
+    def _normalize_generate_inputs(
+        self,
+        inputs,
+    ):
+        """Normalize user input to the generate function.
+
+        This function converts all inputs to tensors, adds a batch dimension if
+        necessary, and returns an iterable "dataset like" object (either an
+        actual `tf.data.Dataset` or a list with a single batch element).
+        """
+        input_is_scalar = False
+
+        if isinstance(inputs, tf.data.Dataset):
+            return inputs, input_is_scalar
+
+        if isinstance(inputs, str) or isinstance(inputs, list):
+            inputs = tf.convert_to_tensor(inputs)
+
+        if isinstance(inputs, tf.Tensor) and inputs.shape.rank == 0:
+            input_is_scalar = True
+            inputs = inputs[tf.newaxis]
+
+        # We avoid converting to a dataset purely for speed; for a single batch
+        # of input, creating a dataset would add significant overhead.
+        return [inputs], input_is_scalar
+
+    def _normalize_generate_outputs(
+        self,
+        outputs,
+        input_is_scalar,
+    ):
+        """Normalize user output from the generate function.
+
+        This function converts all output to numpy (for integer output), or
+        python strings (for string output). If a batch dimension was added to
+        the input, it is removed from the output (so generate can be string in,
+        string out).
+        """

+        def normalize(x):
+            x = tf.concat(x, axis=0)
+            x = tf.squeeze(x, 0) if input_is_scalar else x
+            is_string = x.dtype == tf.string
+            # Convert outputs to a friendly pythonic type. For numerical outputs
+            # that is numpy, for string outputs that is `list` and `str`.
+            return tensor_to_string_list(x) if is_string else x.numpy()
+
+        if isinstance(outputs[0], dict):
+            return {
+                "token_ids": normalize([x["token_ids"] for x in outputs]),
+                "padding_mask": normalize([x["padding_mask"] for x in outputs]),
+            }
+        return normalize([x for x in outputs])
+
     def generate(
         self,
         inputs,
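The cumulative-sum trick in the new output masking is easiest to see on a concrete batch. A small standalone illustration (all values made up), assuming `end_token_id = 2` and a two-token prompt:

import tensorflow as tf

token_ids = tf.constant([[5, 7, 2, 9, 9]])  # A `2` generated after the prompt.
padding_mask = tf.constant([[True, True, False, False, False]])  # Prompt is [5, 7].
end_locations = tf.cast((token_ids == 2) & (~padding_mask), tf.int32)
# end_locations == [[0, 0, 1, 0, 0]]
overflow = tf.math.cumsum(end_locations, exclusive=True, axis=-1)
# overflow == [[0, 0, 0, 1, 1]]; the exclusive cumsum flags positions
# strictly after the first generated end token.
print(~tf.cast(overflow, tf.bool))  # [[True, True, True, False, False]]

The end token itself stays inside the mask, while everything generated after it is masked out.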
@@ -367,14 +433,14 @@ def generate(
         This method generates text based on given `inputs`. The sampling method
         used for generation can be set in the `compile` method.

-        If `inputs` is a `tf.data.Dataset`, outputs will be generated
+        If `inputs` are a `tf.data.Dataset`, outputs will be generated
         "batch-by-batch" and concatenated. Otherwise, all inputs will be handled
         as a single batch.

         If a `preprocessor` is attached to the model, `inputs` should be
         strings and returned sequences will be strings. Otherwise, inputs should
-        be preprocessed into token ids before calling `generate()`, and returned
-        sequences will also be token ids.
+        be preprocessed before calling `generate()`, and returned sequences will
+        be token ids.

         Args:
             inputs: a string `tf.Tensor`, a `tf.data.Dataset` of strings, a
@@ -383,73 +449,52 @@ def generate(
                 `tf.Tensor` or `tf.data.Dataset` with keys `"token_ids"` and
                 `"padding_mask"`.
             max_length: Optional. int. The max length of the generated sequence.
-                Will default to the configured `sequence_length` of the
+                Will default to the max configured `sequence_length` of the
                 `preprocessor`. If `preprocessor` is `None`, `inputs` should be
-                padded to the desired max length and this argument is ignored.
+                padded to the desired maximum length and this argument
+                will be ignored.

         Returns:
             A string or string list if `preprocessor` is set, and an integer
-            tensor of token ids if `preprocessor is None`.
+            tensor of token IDs if `preprocessor is None`.
         """
-        input_is_scalar = False
-
+        # Setup our three main passes.
+        # 1. Optionally preprocessing strings to dense integer tensors.
+        # 2. Generate new tokens via a compiled function on dense tensors.
+        # 3. Optionally postprocess dense integer tensors back to string.
+        generate_function = self.make_generate_function()
+        end_token_id = None
         if self.preprocessor is not None:
+            end_token_id = self.preprocessor.tokenizer.end_token_id

-            def preprocess(x):
-                return self.preprocessor(
-                    x,
-                    sequence_length=max_length,
-                    return_labels=False,
-                    # We do not append an end token by default during
-                    # generation, as generating directly in the same sequence is
-                    # the most common workflow. If an end token directly after
-                    # a prompt is desired, it can be added to the prompt string.
-                    add_end_token=False,
-                )
-
-            if not isinstance(inputs, tf.data.Dataset):
-                inputs = tf.convert_to_tensor(inputs)
-                input_is_scalar = inputs.shape.rank == 0
-                inputs = inputs[tf.newaxis] if input_is_scalar else inputs
-                # Wrap a list to avoid the overhead of converting to dataset.
-                inputs = [preprocess(inputs)]
-            else:
+        def preprocess(x):
+            return self.preprocessor.generate_preprocess(
+                x, sequence_length=max_length
+            )
+
+        def generate(x):
+            return generate_function(x, end_token_id=end_token_id)
+
+        def postprocess(x):
+            return self.preprocessor.generate_postprocess(x)
+
+        # Normalize inputs, apply our three passes, and normalize outputs.
+        inputs, input_is_scalar = self._normalize_generate_inputs(inputs)
+
+        if self.preprocessor is not None:
+            if isinstance(inputs, tf.data.Dataset):
                 inputs = inputs.map(preprocess, tf.data.AUTOTUNE)
                 inputs = inputs.prefetch(tf.data.AUTOTUNE)
-        else:
-            if not isinstance(inputs, tf.data.Dataset):
-                # Wrap a list to avoid the overhead of converting to dataset.
-                inputs = [inputs]
+            else:
+                # Fast path for non-dataset, single-batch input.
+                inputs = [preprocess(x) for x in inputs]

-        generate_function = self.make_generate_function()
-        outputs = []
-        for batch in inputs:
-            token_ids, padding_mask = batch["token_ids"], batch["padding_mask"]
-            # If `preprocessor` is attached, we can stop after `end_token_id`.
-            end_token_id = None
-            if self.preprocessor is not None:
-                end_token_id = self.preprocessor.tokenizer.end_token_id
-            # Run the compiled generate function.
-            output = generate_function(token_ids, padding_mask, end_token_id)
-
-            if self.preprocessor is not None:
-                # Truncate to ragged by removing tokens after the first
-                # generated `end_token_id`.
-                output = truncate_at_token(output, end_token_id, padding_mask)
-                # Strip start token if added.
-                if self.preprocessor.add_start_token:
-                    output = output[:, 1:]
-                # Detokenize.
-                output = self.preprocessor.tokenizer.detokenize(output)
-            outputs.append(output)
-
-        outputs = tf.concat(outputs, axis=0)
-        outputs = tf.squeeze(outputs, 0) if input_is_scalar else outputs
-        # Convert outputs to a friendly pythonic type. For numerical outputs
-        # that is numpy, for string outputs that is `list` and `str`.
-        if outputs.dtype == tf.string:
-            return tensor_to_string_list(outputs)
-        return outputs.numpy()
+        outputs = [generate(x) for x in inputs]
+
+        if self.preprocessor is not None:
+            outputs = [postprocess(x) for x in outputs]
+
+        return self._normalize_generate_outputs(outputs, input_is_scalar)

     @classmethod
     def create_layout_map(cls, mesh):
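Taken together, `generate()` now runs three composable passes: preprocess, generate, postprocess. A sketch of the end-to-end calls this enables (the preset name is an assumption; any OPT preset would work the same way):

causal_lm = keras_nlp.models.OPTCausalLM.from_preset("opt_125m_en")
# String in, string out when a preprocessor is attached.
causal_lm.generate("The quick brown fox", max_length=30)
# Datasets are generated batch-by-batch and concatenated.
ds = tf.data.Dataset.from_tensor_slices(["a prompt", "another prompt"]).batch(2)
causal_lm.generate(ds, max_length=30)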

keras_nlp/models/opt/opt_causal_lm_preprocessor.py

Lines changed: 65 additions & 24 deletions
@@ -14,10 +14,14 @@

 """OPT Causal LM preprocessor layer."""

+import tensorflow as tf
 from absl import logging

 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.models.opt.opt_preprocessor import OPTPreprocessor
+from keras_nlp.utils.keras_utils import (
+    convert_inputs_to_list_of_tensor_segments,
+)
 from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@@ -95,36 +99,73 @@ def call(
         y=None,
         sample_weight=None,
         sequence_length=None,
-        add_start_token=None,
-        add_end_token=None,
-        return_labels=True,
     ):
         if y is not None or sample_weight is not None:
             logging.warning(
                 "`OPTCausalLMPreprocessor` generates `y` and `sample_weight` "
                 "based on your input data, but your data already contains `y` "
                 "or `sample_weight`. Your `y` and `sample_weight` will be "
                 "ignored."
             )
-        if return_labels:
-            # Tokenize with one extra token to account for the truncation below.
-            sequence_length = (sequence_length or self.sequence_length) + 1
-        x = super().call(
+        sequence_length = sequence_length or self.sequence_length
+
+        x = convert_inputs_to_list_of_tensor_segments(x)[0]
+        x = self.tokenizer(x)
+        # Pad with one extra token to account for the truncation below.
+        token_ids, padding_mask = self.packer(
             x,
-            sequence_length=sequence_length,
-            add_start_token=add_start_token,
-            add_end_token=add_end_token,
+            sequence_length=sequence_length + 1,
+            add_start_value=self.add_start_token,
+            add_end_value=self.add_end_token,
+        )
+        # The last token does not have a next token, so we truncate it out.
+        x = {
+            "token_ids": token_ids[..., :-1],
+            "padding_mask": padding_mask[..., :-1],
+        }
+        # Target `y` will be the next token.
+        y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
+        return pack_x_y_sample_weight(x, y, sample_weight)
+
+    def generate_preprocess(
+        self,
+        x,
+        sequence_length=None,
+    ):
+        """Convert strings to integer token input for generation.
+
+        Similar to calling the layer for training, this method takes in strings
+        or tensor strings, tokenizes and packs the input, and computes a padding
+        mask masking all inputs not filled in with a padded value.
+
+        Unlike calling the layer for training, this method does not compute
+        labels and will never append a `tokenizer.end_token_id` to the end of
+        the sequence (as generation is expected to continue at the end of the
+        inputted prompt).
+        """
+        x = convert_inputs_to_list_of_tensor_segments(x)[0]
+        x = self.tokenizer(x)
+        token_ids, padding_mask = self.packer(
+            x, sequence_length=sequence_length, add_end_value=False
         )
-        if return_labels:
-            token_ids, padding_mask = x["token_ids"], x["padding_mask"]
-            # The last token does not have a next token, so we truncate it out.
-            x = {
-                "token_ids": token_ids[..., :-1],
-                "padding_mask": padding_mask[..., :-1],
-            }
-            # Target `y` will be the next token.
-            y = token_ids[..., 1:]
-            sample_weight = padding_mask[..., 1:]
-            return pack_x_y_sample_weight(x, y, sample_weight)
-        else:
-            return x
+        return {
+            "token_ids": token_ids,
+            "padding_mask": padding_mask,
+        }
+
+    def generate_postprocess(
+        self,
+        x,
+    ):
+        """Convert integer token output to strings for generation.
+
+        This method reverses `generate_preprocess()`, by first removing all
+        padding and start/end tokens, and then converting the integer sequence
+        back to a string.
+        """
+        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
+        # Strip any special tokens during detokenization (e.g. the start and
+        # end markers). In the future we could make this configurable.
+        padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
+        token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
+        return self.tokenizer.detokenize(token_ids)
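The two standalone hooks can also be driven directly, which is what `generate()` does internally. A hedged round-trip sketch (the preset name is an assumption, and the exact tokenization will depend on the vocabulary):

preprocessor = keras_nlp.models.OPTCausalLMPreprocessor.from_preset("opt_125m_en")
x = preprocessor.generate_preprocess(
    ["The quick brown fox"], sequence_length=8
)
# x["token_ids"] and x["padding_mask"] are dense, shape (1, 8); no end
# token is appended, so generation can continue from the prompt.
text = preprocessor.generate_postprocess(x)
# Padding and start/end tokens are stripped before detokenizing back to strings.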
