@@ -160,6 +160,156 @@ def token_probability_fn(inputs):
     return prompt
 
 
+def beam_search(
+    token_probability_fn,
+    prompt,
+    max_length,
+    num_beams,
+    from_logits=False,
+    end_token_id=None,
+    pad_token_id=0,
+):
+    """
+    Text generation utility based on the beam search algorithm.
+
+    At each time-step, beam search keeps the `num_beams` sequences (beams)
+    with the highest accumulated probabilities, and uses each beam to
+    predict candidate next tokens.
+
+    Args:
+        token_probability_fn: a callable, which takes in a tensor of input
+            sequences and outputs the probability distribution of the next
+            token. If `from_logits` is set to True, it should output the
+            logits of the next token. The input shape is
+            `[batch_size, length]` and the output shape should be
+            `[batch_size, vocab_size]`, where batch_size is variable.
+        prompt: a list or a Tensor, 1D or 2D; the initial tokens to which
+            generated tokens are appended. Serves as the initial beam for
+            beam search.
+        max_length: int. The maximum length of the generated text.
+        num_beams: int. The number of beams that should be kept at each
+            time-step. `num_beams` should be strictly positive.
+        from_logits: bool. Indicates whether `token_probability_fn` outputs
+            logits or probabilities.
+        end_token_id: int, defaults to None. The token ID marking the end of
+            a sequence; once encountered, generation is finished for that
+            sequence. If None, every sequence is generated up to
+            `max_length`. If set, all tokens after `end_token_id` is
+            encountered will be replaced with `pad_token_id`.
+        pad_token_id: int, defaults to 0. The token ID used to pad the
+            sequence after `end_token_id` is encountered.
+
+    Returns:
+        A 1D or 2D int Tensor representing the generated sequences.
+
+    Examples:
+    ```python
+    BATCH_SIZE = 8
+    VOCAB_SIZE = 10
+    FEATURE_SIZE = 16
+    START_ID = 1
+    END_ID = 2
+
+    # Create a dummy model to predict the next token.
+    model = tf.keras.Sequential(
+        [
+            tf.keras.Input(shape=[None]),
+            tf.keras.layers.Embedding(
+                input_dim=VOCAB_SIZE,
+                output_dim=FEATURE_SIZE,
+            ),
+            tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
+        ]
+    )
+
+    # Define a function that outputs the next token's probability given the
+    # input sequence.
+    def token_probability_fn(inputs):
+        return model(inputs)[:, -1, :]
+
+    prompt = tf.fill((BATCH_SIZE, 1), START_ID)
+
+    # Print the generated sequence (token ids).
+    keras_nlp.utils.beam_search(
+        token_probability_fn,
+        prompt,
+        max_length=10,
+        num_beams=5,
+        end_token_id=END_ID,
+    )
+    ```
+
+    """
+    if not tf.executing_eagerly():
+        raise RuntimeError(
+            "`keras_nlp.utils.beam_search` currently requires an eager "
+            "execution context. Please call `beam_search` outside "
+            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
+            "tf.function in eager mode."
+        )
+    if num_beams <= 0:
+        raise ValueError(
+            f"`num_beams` should be strictly positive. Received: `num_beams={num_beams}`."
+        )
+
+    prompt = validate_prompt(prompt)
+
+    input_is_1d = prompt.shape.rank == 1
+    if input_is_1d:
+        prompt = prompt[tf.newaxis, :]
+    validate_token_probability_fn(token_probability_fn, prompt)
+
+    batch_size, length = prompt.shape
+    if length < max_length:
+        # Initialize beam.
+        beams = tf.expand_dims(prompt, 1)
+        beams_prob = tf.zeros([batch_size, 1])
+        i = length
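+        # At each step, every beam proposes `vocab_size` continuations; the
+        # top `num_beams` of all `beam_size * vocab_size` candidates, ranked
+        # by accumulated log-probability, survive to the next step.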
+        while i < max_length:
+            beam_size = beams.shape[1]
+            beam_preds = []
+            for j in range(beam_size):
+                preds = token_probability_fn(beams[:, j, :])
+                if from_logits:
+                    preds = tf.keras.activations.softmax(preds, axis=-1)
+                beam_preds.append(preds)
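+            # `stacked_preds` has shape `(batch_size, beam_size, vocab_size)`
+            # and holds probabilities (not logits) at this point.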
+            stacked_preds = tf.stack(beam_preds, axis=1)
+            vocab_size = stacked_preds.shape[2]
+            logits = tf.reshape(
+                stacked_preds, [batch_size, beam_size * vocab_size]
+            )
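+            # Convert to log-probabilities and add each candidate's parent
+            # beam log-probability, repeated across the vocabulary axis.
+            # (Despite its name, `logits` holds probabilities here.)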
+            probs = tf.math.log(logits) + tf.repeat(
+                beams_prob, repeats=vocab_size, axis=1
+            )
+            num_beams = min(beam_size * vocab_size, num_beams)
+            candidate_prob, candidate_indexes = tf.math.top_k(
+                probs, k=num_beams, sorted=False
+            )
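+            # Each flat index encodes a (parent beam, token) pair: the
+            # quotient recovers the parent beam, the remainder the token id.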
+            candidate_beam_indexes = candidate_indexes // vocab_size
+            next_token = candidate_indexes % vocab_size
+
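+            # Reorder beams so each surviving candidate extends its parent,
+            # then append the chosen token to every beam.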
+            beams = tf.gather(
+                beams, candidate_beam_indexes, axis=1, batch_dims=1
+            )
+            beams = tf.concat([beams, next_token[..., tf.newaxis]], axis=-1)
+            beams_prob = candidate_prob
+            i += 1
+        # Get the beam with the maximum probability.
+        max_indexes = tf.math.argmax(beams_prob, axis=-1)
+        max_beams = tf.gather(
+            beams, max_indexes[:, tf.newaxis], axis=1, batch_dims=1
+        )
+        # Squeeze only the beam axis so a batch of size one keeps its rank.
+        prompt = tf.squeeze(max_beams, axis=1)
+
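+    # Once `end_token_id` is encountered, replace every later token with
+    # `pad_token_id`.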
+    if end_token_id is not None:
+        prompt = mask_tokens_after_end_token(
+            prompt, max_length, end_token_id, pad_token_id
+        )
+    if input_is_1d:
+        return tf.squeeze(prompt, axis=0)
+    return prompt
+
+
 def random_search(
     token_probability_fn,
     prompt,
@@ -361,7 +511,7 @@ def token_probability_fn(inputs):
             "tf.function in eager mode."
         )
     if k <= 0:
-        raise ValueError(f"`k` should strictly positive. Received: `k={k}`.")
+        raise ValueError(f"`k` should be strictly positive. Received: `k={k}`.")
 
     prompt = validate_prompt(prompt)
     input_is_1d = prompt.shape.rank == 1
@@ -378,7 +528,7 @@ def token_probability_fn(inputs):
         # If k is greater than the vocabulary size, use the entire vocabulary.
         k = min(k, pred.shape[1])
         # Filter out top-k tokens.
-        top_k_pred, top_k_indices = tf.math.top_k(pred, k=k)
+        top_k_pred, top_k_indices = tf.math.top_k(pred, k=k, sorted=False)
         # Sample the next token from the probability distribution.
         next_token = tf.random.categorical(
             tf.math.log(top_k_pred), 1, seed=seed