diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 8f34b0ef3..10233df58 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -137,6 +137,11 @@ def decode_with_fallback(segment: torch.Tensor) -> List[DecodingResult]: all_segments = [] prompt_reset_since = 0 + initial_prompt = decode_options.pop("initial_prompt", None) or [] + if initial_prompt: + initial_prompt = tokenizer.encode(" " + initial_prompt.strip()) + all_tokens.extend(initial_prompt) + def add_segment( *, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult ): @@ -237,7 +242,7 @@ def add_segment( pbar.update(min(num_frames, seek) - previous_seek_value) previous_seek_value = seek - return dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language) + return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language) def cli(): @@ -260,6 +265,7 @@ def cli(): parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default") parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") + parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")