2929from keras_nlp .utils .keras_utils import is_xla_compatible
3030from keras_nlp .utils .python_utils import classproperty
3131from keras_nlp .utils .tf_utils import tensor_to_string_list
32- from keras_nlp .utils .tf_utils import truncate_at_token
3332
3433
3534@keras_nlp_export ("keras_nlp.models.OPTCausalLM" )
@@ -49,7 +48,7 @@ class OPTCausalLM(Task):
4948 default, `"top_k"` sampling will be used.
5049
5150 This model can optionally be configured with a `preprocessor` layer, in
52- which case it will automatically apply preprocessing to raw inputs during
51+ which case it will automatically apply preprocessing to string inputs during
5352 `fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
5453 when creating the model with `from_preset()`.
5554
@@ -301,28 +300,23 @@ def make_generate_function(self):
301300
302301 def generate_step (
303302 self ,
304- token_ids ,
305- padding_mask ,
303+ inputs ,
306304 end_token_id = None ,
307305 ):
308306 """A compilable generation function for a single batch of inputs.
309307
310308 This function represents the inner, XLA-compilable, generation function
311- for a single batch of inputs. It takes in a dense `tf.Tensor` of token
312- ids, and return a dense `tf.Tensor` of token ids, and includes no
313- preprocessing. This function is wrapped by the `generate()` method.
309+ for a single batch of inputs. Inputs should have the same structure as
310+ model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.
314311
315312 Args:
316- token_ids: A dense int Tensor, with shape
317- `(batch_size, max_length)`. The user provided token ids
318- padded to `max_length`.
319- padding_mask: A dense boolean Tensor, with the same shape as
320- `token_ids`. Positions that are True in the `padding_mask`
321- are assumed to be user input and never updated.
313+ inputs: A dictionary with two keys `"token_ids"` and
314+ `"padding_mask"` and batched tensor values.
322315 end_token_id: The id of the end token to stop on. If all
323316 sequences have produced a new `end_token_id`, generation
324317 will stop.
325318 """
319+ token_ids , padding_mask = inputs ["token_ids" ], inputs ["padding_mask" ]
326320 # Create and seed cache with a single forward pass.
327321 hidden_states , cache = self ._build_cache (token_ids )
328322 # Compute the lengths of all user inputted tokens ids.
@@ -347,7 +341,7 @@ def next(prompt, cache, index):
347341 cache ,
348342 )
349343
350- return self ._sampler (
344+ token_ids = self ._sampler (
351345 next = next ,
352346 prompt = token_ids ,
353347 cache = cache ,
@@ -357,6 +351,78 @@ def next(prompt, cache, index):
357351 hidden_states = hidden_states ,
358352 )
359353
354+ # Compute an output padding mask with the token ids we updated.
355+ if end_token_id is not None :
356+ # Build a mask of `end_token_id` locations not in the original
357+ # prompt (not in locations where `padding_mask` is True).
358+ end_locations = (token_ids == end_token_id ) & (~ padding_mask )
359+ end_locations = tf .cast (end_locations , tf .int32 )
360+ # Use cumsum to get ones in all locations after end_locations.
361+ overflow = tf .math .cumsum (end_locations , exclusive = True )
362+ # Our padding mask is the inverse of these overflow locations.
363+ padding_mask = ~ tf .cast (overflow , tf .bool )
364+ else :
365+ # Without early stopping, all locations will have been updated.
366+ padding_mask = tf .ones_like (token_ids , dtype = tf .bool )
367+ return {
368+ "token_ids" : token_ids ,
369+ "padding_mask" : padding_mask ,
370+ }
371+
def _normalize_generate_inputs(
    self,
    inputs,
):
    """Put user input to the generate function into an iterable form.

    Strings and lists are converted to tensors, and a rank-0 tensor gets a
    leading batch dimension added. The result is an iterable "dataset like"
    object: either the `tf.data.Dataset` that was passed in, unchanged, or
    a list holding a single batch element.

    Args:
        inputs: The raw `inputs` argument passed to `generate()`.

    Returns:
        A tuple `(batches, input_is_scalar)`. `input_is_scalar` is True
        only when a batch dimension was added to a rank-0 tensor input.
    """
    # Datasets are already iterable batch-by-batch; hand them back as is.
    if isinstance(inputs, tf.data.Dataset):
        return inputs, False

    if isinstance(inputs, (str, list)):
        inputs = tf.convert_to_tensor(inputs)

    input_is_scalar = isinstance(inputs, tf.Tensor) and inputs.shape.rank == 0
    if input_is_scalar:
        inputs = inputs[tf.newaxis]

    # A plain list is used instead of a dataset purely for speed: building
    # a dataset for a single batch of input would add significant overhead.
    return [inputs], input_is_scalar
397+
def _normalize_generate_outputs(
    self,
    outputs,
    input_is_scalar,
):
    """Turn per-batch generate outputs into a friendly pythonic result.

    Batches are concatenated along the first axis. String tensors are
    passed through `tensor_to_string_list`; all other tensors are returned
    via `.numpy()`. If a batch dimension was added to the input, it is
    squeezed back out of the output (so generate can be string in, string
    out).

    Args:
        outputs: A list of per-batch outputs, each either a tensor or a
            dictionary with keys `"token_ids"` and `"padding_mask"`.
        input_is_scalar: Whether a batch dimension was added to the input
            and should be removed from the output.

    Returns:
        The normalized output, or a dictionary of normalized outputs keyed
        by `"token_ids"` and `"padding_mask"` when the batches are dicts.
    """

    def normalize(batches):
        merged = tf.concat(batches, axis=0)
        if input_is_scalar:
            merged = tf.squeeze(merged, 0)
        if merged.dtype == tf.string:
            return tensor_to_string_list(merged)
        return merged.numpy()

    # The first batch decides whether outputs are dicts or bare tensors.
    if not isinstance(outputs[0], dict):
        return normalize(list(outputs))
    return {
        key: normalize([batch[key] for batch in outputs])
        for key in ("token_ids", "padding_mask")
    }
425+
360426 def generate (
361427 self ,
362428 inputs ,
@@ -367,14 +433,14 @@ def generate(
367433 This method generates text based on given `inputs`. The sampling method
368434 used for generation can be set in the `compile` method.
369435
370- If `inputs` is a `tf.data.Dataset`, outputs will be generated
436+ If `inputs` is a `tf.data.Dataset`, outputs will be generated
371437 "batch-by-batch" and concatenated. Otherwise, all inputs will be handled
372438 as a single batch.
373439
374440 If a `preprocessor` is attached to the model, `inputs` should be
375441 strings and returned sequences will be strings. Otherwise, inputs should
376- be preprocessed into token ids before calling `generate()`, and returned
377- sequences will also be token ids.
442+ be preprocessed before calling `generate()`, and returned sequences will
443+ be token ids.
378444
379445 Args:
380446 inputs: a string `tf.Tensor`, a `tf.data.Dataset` of strings, a
@@ -383,73 +449,52 @@ def generate(
383449 `tf.Tensor` or `tf.data.Dataset` with keys `"token_ids"` and
384450 `"padding_mask"`.
385451 max_length: Optional. int. The max length of the generated sequence.
386- Will default to the configured `sequence_length` of the
452+ Will default to the max configured `sequence_length` of the
387453 `preprocessor`. If `preprocessor` is `None`, `inputs` should be
388- padded to the desired max length and this argument is ignored.
454+ padded to the desired maximum length and this argument
455+ will be ignored.
389456
390457 Returns:
391458 A string or string list if `preprocessor` is set, and an integer
392- tensor of token ids if `preprocessor is None`.
459+ tensor of token IDs if `preprocessor is None`.
393460 """
394- input_is_scalar = False
395-
461+ # Setup our three main passes.
462+ # 1. Optionally preprocessing strings to dense integer tensors.
463+ # 2. Generate new tokens via a compiled function on dense tensors.
464+ # 3. Optionally postprocess dense integer tensors back to string.
465+ generate_function = self .make_generate_function ()
466+ end_token_id = None
396467 if self .preprocessor is not None :
468+ end_token_id = self .preprocessor .tokenizer .end_token_id
397469
398- def preprocess (x ):
399- return self .preprocessor (
400- x ,
401- sequence_length = max_length ,
402- return_labels = False ,
403- # We do not append an end token by default during
404- # generation, as generating directly in the same sequence is
405- # the most common workflow. If an end token directly after
406- # a prompt is desired, it can be added to the prompt string.
407- add_end_token = False ,
408- )
409-
410- if not isinstance (inputs , tf .data .Dataset ):
411- inputs = tf .convert_to_tensor (inputs )
412- input_is_scalar = inputs .shape .rank == 0
413- inputs = inputs [tf .newaxis ] if input_is_scalar else inputs
414- # Wrap a list to avoid the overhead of converting to dataset.
415- inputs = [preprocess (inputs )]
416- else :
470+ def preprocess (x ):
471+ return self .preprocessor .generate_preprocess (
472+ x , sequence_length = max_length
473+ )
474+
475+ def generate (x ):
476+ return generate_function (x , end_token_id = end_token_id )
477+
478+ def postprocess (x ):
479+ return self .preprocessor .generate_postprocess (x )
480+
481+ # Normalize inputs, apply our three passes, and normalize outputs.
482+ inputs , input_is_scalar = self ._normalize_generate_inputs (inputs )
483+
484+ if self .preprocessor is not None :
485+ if isinstance (inputs , tf .data .Dataset ):
417486 inputs = inputs .map (preprocess , tf .data .AUTOTUNE )
418487 inputs = inputs .prefetch (tf .data .AUTOTUNE )
419- else :
420- if not isinstance (inputs , tf .data .Dataset ):
421- # Wrap a list to avoid the overhead of converting to dataset.
422- inputs = [inputs ]
488+ else :
489+ # Fast path for non-dataset, single-batch input.
490+ inputs = [preprocess (x ) for x in inputs ]
423491
424- generate_function = self .make_generate_function ()
425- outputs = []
426- for batch in inputs :
427- token_ids , padding_mask = batch ["token_ids" ], batch ["padding_mask" ]
428- # If `preprocessor` is attached, we can stop after `end_token_id``.
429- end_token_id = None
430- if self .preprocessor is not None :
431- end_token_id = self .preprocessor .tokenizer .end_token_id
432- # Run the compiled generate function.
433- output = generate_function (token_ids , padding_mask , end_token_id )
434-
435- if self .preprocessor is not None :
436- # Truncate to ragged by removing tokens after the first
437- # generated `end_token_id`.
438- output = truncate_at_token (output , end_token_id , padding_mask )
439- # Strip start token if added.
440- if self .preprocessor .add_start_token :
441- output = output [:, 1 :]
442- # Detokenize.
443- output = self .preprocessor .tokenizer .detokenize (output )
444- outputs .append (output )
445-
446- outputs = tf .concat (outputs , axis = 0 )
447- outputs = tf .squeeze (outputs , 0 ) if input_is_scalar else outputs
448- # Convert outputs to a friendly pythonic type. For numerical outputs
449- # that is numpy, for string outputs that is `list` and `str`.
450- if outputs .dtype == tf .string :
451- return tensor_to_string_list (outputs )
452- return outputs .numpy ()
492+ outputs = [generate (x ) for x in inputs ]
493+
494+ if self .preprocessor is not None :
495+ outputs = [postprocess (x ) for x in outputs ]
496+
497+ return self ._normalize_generate_outputs (outputs , input_is_scalar )
453498
454499 @classmethod
455500 def create_layout_map (cls , mesh ):
0 commit comments