diff --git a/.flake8 b/.flake8 new file mode 100644 index 000000000..c61c928fa --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +per-file-ignores = + */__init__.py: F401 + diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f06bff79a..bca49b575 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,4 +22,7 @@ jobs: - uses: actions/checkout@v2 - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH - run: pip install .["dev"] + - run: black --check --diff -t py38 --include '(\.pyi?)$' . + - run: isort --check --diff . + - run: flake8 --ignore E203,W503,W504,E501,E731,E741 . - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda' diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..84637eb2e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,8 @@ +[tool.black] + +[tool.isort] +profile = "black" +include_trailing_comma = true +line_length = 88 +multi_line_output = 3 + diff --git a/setup.py b/setup.py index a548c8d35..736c1ac67 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ import sys import pkg_resources -from setuptools import setup, find_packages +from setuptools import find_packages, setup def read_version(fname="whisper/version.py"): @@ -16,7 +16,10 @@ def read_version(fname="whisper/version.py"): try: import re import subprocess - version_line = subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1] + + version_line = ( + subprocess.check_output(["nvcc", "--version"]).strip().split(b"\n")[-1] + ) major, minor = re.findall(rb"([\d]+)\.([\d]+)", version_line)[0] if (int(major), int(minor)) < (11, 4): # the last version supporting CUDA < 11.4 @@ -38,7 +41,8 @@ def read_version(fname="whisper/version.py"): url="https://github.com/openai/whisper", license="MIT", packages=find_packages(exclude=["tests*"]), - install_requires=requirements + [ + install_requires=requirements + + [ str(r) for r in pkg_resources.parse_requirements( open(os.path.join(os.path.dirname(__file__), "requirements.txt")) @@ -48,5 +52,5 @@ def read_version(fname="whisper/version.py"): "console_scripts": ["whisper=whisper.transcribe:cli"], }, include_package_data=True, - extras_require={"dev": ["pytest", "scipy"]}, + extras_require={"dev": ["pytest", "scipy", "black", "flake8", "isort"]}, ) diff --git a/tests/test_audio.py b/tests/test_audio.py index ad6a6dacb..dfd78bc09 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -2,7 +2,7 @@ import numpy as np -from whisper.audio import load_audio, log_mel_spectrogram, SAMPLE_RATE +from whisper.audio import SAMPLE_RATE, load_audio, log_mel_spectrogram def test_audio(): diff --git a/tests/test_normalizer.py b/tests/test_normalizer.py index 3785e6cab..3bc65744c 100644 --- a/tests/test_normalizer.py +++ b/tests/test_normalizer.py @@ -1,7 +1,10 @@ import pytest from whisper.normalizers import EnglishTextNormalizer -from whisper.normalizers.english import EnglishNumberNormalizer, EnglishSpellingNormalizer +from whisper.normalizers.english import ( + EnglishNumberNormalizer, + EnglishSpellingNormalizer, +) @pytest.mark.parametrize("std", [EnglishNumberNormalizer(), EnglishTextNormalizer()]) diff --git a/tests/test_timing.py b/tests/test_timing.py index 50a2583f6..9bab838e8 100644 --- a/tests/test_timing.py +++ b/tests/test_timing.py @@ -1,16 +1,21 @@ -import pytest import numpy as np +import pytest import scipy.ndimage import torch from whisper.timing import dtw_cpu, dtw_cuda, median_filter - sizes = [ - (10, 20), (32, 16), (123, 1500), (234, 189), + (10, 20), + (32, 16), + (123, 1500), + (234, 189), ] shapes = [ - (10,), (1, 15), (4, 5, 345), (6, 12, 240, 512), + (10,), + (1, 15), + (4, 5, 345), + (6, 12, 240, 512), ] @@ -68,8 +73,12 @@ def test_median_filter(shape): # using np.pad to reflect-pad, because Scipy's behavior is different near the edges. pad_width = filter_width // 2 - padded_x = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect") - scipy_filtered = scipy.ndimage.median_filter(padded_x, [1] * (x.ndim - 1) + [filter_width]) + padded_x = np.pad( + x, [(0, 0)] * (x.ndim - 1) + [(pad_width, pad_width)], mode="reflect" + ) + scipy_filtered = scipy.ndimage.median_filter( + padded_x, [1] * (x.ndim - 1) + [filter_width] + ) scipy_filtered = scipy_filtered[..., pad_width:-pad_width] assert np.allclose(filtered, scipy_filtered) diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index 9802f734b..e5d530783 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -13,7 +13,9 @@ def test_transcribe(model_name: str): audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac") language = "en" if model_name.endswith(".en") else None - result = model.transcribe(audio_path, language=language, temperature=0.0, word_timestamps=True) + result = model.transcribe( + audio_path, language=language, temperature=0.0, word_timestamps=True + ) assert result["language"] == "en" transcription = result["text"].lower() diff --git a/whisper/__init__.py b/whisper/__init__.py index 26d1e0eaf..379133b6a 100644 --- a/whisper/__init__.py +++ b/whisper/__init__.py @@ -10,11 +10,10 @@ from .audio import load_audio, log_mel_spectrogram, pad_or_trim from .decoding import DecodingOptions, DecodingResult, decode, detect_language -from .model import Whisper, ModelDimensions +from .model import ModelDimensions, Whisper from .transcribe import transcribe from .version import __version__ - _MODELS = { "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt", @@ -41,12 +40,11 @@ "medium.en": b"ABzY8usPae0{>%R7%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9", "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj', - "large": b'ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj', + "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", + "large": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", } - def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: os.makedirs(root, exist_ok=True) @@ -62,10 +60,18 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: return model_bytes if in_memory else download_target else: - warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + warnings.warn( + f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" + ) with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: - with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: + with tqdm( + total=int(source.info().get("Content-Length")), + ncols=80, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as loop: while True: buffer = source.read(8192) if not buffer: @@ -76,7 +82,9 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: model_bytes = open(download_target, "rb").read() if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: - raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.") + raise RuntimeError( + "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." + ) return model_bytes if in_memory else download_target @@ -86,7 +94,12 @@ def available_models() -> List[str]: return list(_MODELS.keys()) -def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper: +def load_model( + name: str, + device: Optional[Union[str, torch.device]] = None, + download_root: str = None, + in_memory: bool = False, +) -> Whisper: """ Load a Whisper ASR model @@ -111,15 +124,8 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" if download_root is None: - download_root = os.path.join( - os.getenv( - "XDG_CACHE_HOME", - os.path.join( - os.path.expanduser("~"), ".cache" - ) - ), - "whisper" - ) + default = os.path.join(os.path.expanduser("~"), ".cache") + download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper") if name in _MODELS: checkpoint_file = _download(_MODELS[name], download_root, in_memory) @@ -128,9 +134,13 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow checkpoint_file = open(name, "rb").read() if in_memory else name alignment_heads = None else: - raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + raise RuntimeError( + f"Model {name} not found; available models = {available_models()}" + ) - with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp: + with ( + io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") + ) as fp: checkpoint = torch.load(fp, map_location=device) del checkpoint_file diff --git a/whisper/__main__.py b/whisper/__main__.py index bc9b04a39..d14f2058e 100644 --- a/whisper/__main__.py +++ b/whisper/__main__.py @@ -1,4 +1,3 @@ from .transcribe import cli - cli() diff --git a/whisper/audio.py b/whisper/audio.py index 964d41579..a19b7ab0d 100644 --- a/whisper/audio.py +++ b/whisper/audio.py @@ -16,11 +16,13 @@ HOP_LENGTH = 160 CHUNK_LENGTH = 30 N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk -N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input +N_FRAMES = exact_div( + N_SAMPLES, HOP_LENGTH +) # 3000: number of frames in a mel spectrogram input N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 -FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 100 mel frames in 1s (10ms each) -TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 50 audio tokens in 1s (20ms each) +FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame +TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token def load_audio(file: str, sr: int = SAMPLE_RATE): @@ -59,7 +61,9 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): """ if torch.is_tensor(array): if array.shape[axis] > length: - array = array.index_select(dim=axis, index=torch.arange(length, device=array.device)) + array = array.index_select( + dim=axis, index=torch.arange(length, device=array.device) + ) if array.shape[axis] < length: pad_widths = [(0, 0)] * array.ndim @@ -89,11 +93,15 @@ def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: ) """ assert n_mels == 80, f"Unsupported n_mels: {n_mels}" - with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f: + with np.load( + os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") + ) as f: return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) -def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS): +def log_mel_spectrogram( + audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS +): """ Compute the log-Mel spectrogram of diff --git a/whisper/decoding.py b/whisper/decoding.py index 7c51f251c..ff9261e04 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -16,7 +16,9 @@ @torch.no_grad() -def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]: +def detect_language( + model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None +) -> Tuple[Tensor, List[dict]]: """ Detect the spoken language in the audio, and return them as list of strings, along with the ids of the most probable language tokens and the probability distribution over all language tokens. @@ -31,8 +33,13 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) """ if tokenizer is None: tokenizer = get_tokenizer(model.is_multilingual) - if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence: - raise ValueError(f"This model doesn't have language tokens so it can't perform lang id") + if ( + tokenizer.language is None + or tokenizer.language_token not in tokenizer.sot_sequence + ): + raise ValueError( + "This model doesn't have language tokens so it can't perform lang id" + ) single = mel.ndim == 2 if single: @@ -70,31 +77,36 @@ def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) @dataclass(frozen=True) class DecodingOptions: - task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate" - language: Optional[str] = None # language that the audio is in; uses detected language if None + # whether to perform X->X "transcribe" or X->English "translate" + task: str = "transcribe" + + # language that the audio is in; uses detected language if None + language: Optional[str] = None # sampling-related options temperature: float = 0.0 sample_len: Optional[int] = None # maximum number of tokens to sample - best_of: Optional[int] = None # number of independent samples to collect, when t > 0 - beam_size: Optional[int] = None # number of beams in beam search, when t == 0 - patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424) + best_of: Optional[int] = None # number of independent sample trajectories, if t > 0 + beam_size: Optional[int] = None # number of beams in beam search, if t == 0 + patience: Optional[float] = None # patience in beam search (arxiv:2204.05424) - # options for ranking generations (either beams or best-of-N samples) - length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm + # "alpha" in Google NMT, or None for length norm, when ranking generations + # to select which to return among the beams or best-of-N samples + length_penalty: Optional[float] = None - # prompt, prefix, and token suppression - prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context - prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context - suppress_blank: bool = True # this will suppress blank outputs + # text or tokens to feed as the prompt or the prefix; for more info: + # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051 + prompt: Optional[Union[str, List[int]]] = None # for the previous context + prefix: Optional[Union[str, List[int]]] = None # to prefix the current context # list of tokens ids (or comma-separated token ids) to suppress # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()` suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1" + suppress_blank: bool = True # this will suppress blank outputs # timestamp sampling options - without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only - max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this + without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only + max_initial_timestamp: Optional[float] = 1.0 # implementation details fp16: bool = True # use fp16 for most of the calculation @@ -158,7 +170,9 @@ def rearrange_kv_cache(self, source_indices): class SequenceRanker: - def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]: + def rank( + self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]] + ) -> List[int]: """ Given a list of groups of samples and their cumulative log probabilities, return the indices of the samples in each group to select as the final result @@ -196,7 +210,9 @@ class TokenDecoder: def reset(self): """Initialize any stateful variables for decoding a new sequence""" - def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: + def update( + self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor + ) -> Tuple[Tensor, bool]: """Specify how to select the next token, based on the current trace and logits Parameters @@ -251,7 +267,9 @@ def __init__(self, temperature: float, eot: int): self.temperature = temperature self.eot = eot - def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: + def update( + self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor + ) -> Tuple[Tensor, bool]: if self.temperature == 0: next_tokens = logits.argmax(dim=-1) else: @@ -274,7 +292,13 @@ def finalize(self, tokens: Tensor, sum_logprobs: Tensor): class BeamSearchDecoder(TokenDecoder): - def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None): + def __init__( + self, + beam_size: int, + eot: int, + inference: Inference, + patience: Optional[float] = None, + ): self.beam_size = beam_size self.eot = eot self.inference = inference @@ -282,12 +306,16 @@ def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Opt self.max_candidates: int = round(beam_size * self.patience) self.finished_sequences = None - assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})" + assert ( + self.max_candidates > 0 + ), f"Invalid beam size ({beam_size}) or patience ({patience})" def reset(self): self.finished_sequences = None - def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: + def update( + self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor + ) -> Tuple[Tensor, bool]: if tokens.shape[0] % self.beam_size != 0: raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0") @@ -331,7 +359,9 @@ def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[ # add newly finished sequences to self.finished_sequences assert len(self.finished_sequences) == len(finished_sequences) - for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences): + for previously_finished, newly_finished in zip( + self.finished_sequences, finished_sequences + ): for seq in sorted(newly_finished, key=newly_finished.get, reverse=True): if len(previously_finished) >= self.max_candidates: break # the candidate list is full @@ -339,7 +369,8 @@ def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[ # mark as completed if all audio has enough number of samples completed = all( - len(sequences) >= self.max_candidates for sequences in self.finished_sequences + len(sequences) >= self.max_candidates + for sequences in self.finished_sequences ) return tokens, completed @@ -347,7 +378,9 @@ def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor): # collect all finished sequences, including patience, and add unfinished ones if not enough sum_logprobs = sum_logprobs.cpu() for i, sequences in enumerate(self.finished_sequences): - if len(sequences) < self.beam_size: # when not enough sequences are finished + if ( + len(sequences) < self.beam_size + ): # when not enough sequences are finished for j in list(np.argsort(sum_logprobs[i]))[::-1]: sequence = preceding_tokens[i, j].tolist() + [self.eot] sequences[tuple(sequence)] = sum_logprobs[i][j].item() @@ -355,7 +388,8 @@ def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor): break tokens: List[List[Tensor]] = [ - [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences + [torch.tensor(seq) for seq in sequences.keys()] + for sequences in self.finished_sequences ] sum_logprobs: List[List[float]] = [ list(sequences.values()) for sequences in self.finished_sequences @@ -399,7 +433,10 @@ def apply(self, logits: Tensor, tokens: Tensor): class ApplyTimestampRules(LogitFilter): def __init__( - self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int] + self, + tokenizer: Tokenizer, + sample_begin: int, + max_initial_timestamp_index: Optional[int], ): self.tokenizer = tokenizer self.sample_begin = sample_begin @@ -414,8 +451,12 @@ def apply(self, logits: Tensor, tokens: Tensor): for k in range(tokens.shape[0]): sampled_tokens = tokens[k, self.sample_begin :] seq = [t for t in sampled_tokens.tolist()] - last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin - penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin + last_was_timestamp = ( + len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin + ) + penultimate_was_timestamp = ( + len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin + ) if last_was_timestamp: if penultimate_was_timestamp: # has to be non-timestamp @@ -423,7 +464,9 @@ def apply(self, logits: Tensor, tokens: Tensor): else: # cannot be normal text tokens logits[k, : self.tokenizer.eot] = -np.inf - timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)] + timestamps = sampled_tokens[ + sampled_tokens.ge(self.tokenizer.timestamp_begin) + ] if timestamps.numel() > 0: # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf @@ -434,13 +477,17 @@ def apply(self, logits: Tensor, tokens: Tensor): # apply the `max_initial_timestamp` option if self.max_initial_timestamp_index is not None: - last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index + last_allowed = ( + self.tokenizer.timestamp_begin + self.max_initial_timestamp_index + ) logits[:, last_allowed + 1 :] = -np.inf # if sum of probability over timestamps is above any other token, sample timestamp logprobs = F.log_softmax(logits.float(), dim=-1) for k in range(tokens.shape[0]): - timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1) + timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp( + dim=-1 + ) max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max() if timestamp_logprob > max_text_token_logprob: logits[k, : self.tokenizer.timestamp_begin] = -np.inf @@ -456,7 +503,9 @@ def __init__(self, model: "Whisper", options: DecodingOptions): self.model = model language = options.language or "en" - tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task) + tokenizer = get_tokenizer( + model.is_multilingual, language=language, task=options.task + ) self.tokenizer: Tokenizer = tokenizer self.options: DecodingOptions = self._verify_options(options) @@ -496,9 +545,13 @@ def __init__(self, model: "Whisper", options: DecodingOptions): precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds max_initial_timestamp_index = None if options.max_initial_timestamp: - max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision) + max_initial_timestamp_index = round( + self.options.max_initial_timestamp / precision + ) self.logit_filters.append( - ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index) + ApplyTimestampRules( + tokenizer, self.sample_begin, max_initial_timestamp_index + ) ) def _verify_options(self, options: DecodingOptions) -> DecodingOptions: @@ -509,7 +562,9 @@ def _verify_options(self, options: DecodingOptions) -> DecodingOptions: raise ValueError("best_of with greedy sampling (T=0) is not compatible") if options.patience is not None and options.beam_size is None: raise ValueError("patience requires beam_size to be given") - if options.length_penalty is not None and not (0 <= options.length_penalty <= 1): + if options.length_penalty is not None and not ( + 0 <= options.length_penalty <= 1 + ): raise ValueError("length_penalty (alpha) should be a value between 0 and 1") return options @@ -519,7 +574,9 @@ def _get_initial_tokens(self) -> Tuple[int]: if prefix := self.options.prefix: prefix_tokens = ( - self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix + self.tokenizer.encode(" " + prefix.strip()) + if isinstance(prefix, str) + else prefix ) if self.sample_len is not None: max_prefix_len = self.n_ctx // 2 - self.sample_len @@ -528,9 +585,15 @@ def _get_initial_tokens(self) -> Tuple[int]: if prompt := self.options.prompt: prompt_tokens = ( - self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt + self.tokenizer.encode(" " + prompt.strip()) + if isinstance(prompt, str) + else prompt + ) + tokens = ( + [self.tokenizer.sot_prev] + + prompt_tokens[-(self.n_ctx // 2 - 1) :] + + tokens ) - tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens return tuple(tokens) @@ -554,7 +617,7 @@ def _get_suppress_tokens(self) -> Tuple[int]: self.tokenizer.translate, self.tokenizer.sot, self.tokenizer.sot_prev, - self.tokenizer.sot_lm + self.tokenizer.sot_lm, ] ) if self.tokenizer.no_speech is not None: @@ -567,14 +630,21 @@ def _get_audio_features(self, mel: Tensor): if self.options.fp16: mel = mel.half() - if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state): + if mel.shape[-2:] == ( + self.model.dims.n_audio_ctx, + self.model.dims.n_audio_state, + ): # encoded audio features are given; skip audio encoding audio_features = mel else: audio_features = self.model.encoder(mel) - if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32): - return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}") + if audio_features.dtype != ( + torch.float16 if self.options.fp16 else torch.float32 + ): + return TypeError( + f"audio_features has an incorrect dtype: {audio_features.dtype}" + ) return audio_features @@ -583,7 +653,9 @@ def _detect_language(self, audio_features: Tensor, tokens: Tensor): lang_probs = None if self.options.language is None or self.options.task == "lang_id": - lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer) + lang_tokens, lang_probs = self.model.detect_language( + audio_features, self.tokenizer + ) languages = [max(probs, key=probs.get) for probs in lang_probs] if self.options.language is None: tokens[:, self.sot_index + 1] = lang_tokens # write language tokens @@ -600,7 +672,9 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor): for i in range(self.sample_len): logits = self.inference.logits(tokens, audio_features) - if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs + if ( + i == 0 and self.tokenizer.no_speech is not None + ): # save no_speech_probs probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1) no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist() @@ -634,8 +708,12 @@ def run(self, mel: Tensor) -> List[DecodingResult]: languages, language_probs = self._detect_language(audio_features, tokens) if self.options.task == "lang_id": return [ - DecodingResult(audio_features=features, language=language, language_probs=probs) - for features, language, probs in zip(audio_features, languages, language_probs) + DecodingResult( + audio_features=features, language=language, language_probs=probs + ) + for features, language, probs in zip( + audio_features, languages, language_probs + ) ] # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling @@ -656,7 +734,8 @@ def run(self, mel: Tensor) -> List[DecodingResult]: # get the final candidates for each group, and slice between the first sampled token and EOT tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs) tokens: List[List[Tensor]] = [ - [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens + [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] + for s in tokens ] # select the top-ranked sample in each group @@ -665,9 +744,18 @@ def run(self, mel: Tensor) -> List[DecodingResult]: texts: List[str] = [tokenizer.decode(t).strip() for t in tokens] sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)] - avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)] + avg_logprobs: List[float] = [ + lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs) + ] - fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs) + fields = ( + texts, + languages, + tokens, + audio_features, + avg_logprobs, + no_speech_probs, + ) if len(set(map(len, fields))) != 1: raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}") @@ -682,12 +770,16 @@ def run(self, mel: Tensor) -> List[DecodingResult]: temperature=self.options.temperature, compression_ratio=compression_ratio(text), ) - for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields) + for text, language, tokens, features, avg_logprob, no_speech_prob in zip( + *fields + ) ] @torch.no_grad() -def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]: +def decode( + model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions() +) -> Union[DecodingResult, List[DecodingResult]]: """ Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s). diff --git a/whisper/model.py b/whisper/model.py index a1ab2e349..3457fcfc6 100644 --- a/whisper/model.py +++ b/whisper/model.py @@ -1,16 +1,15 @@ import base64 import gzip from dataclasses import dataclass -from typing import Dict -from typing import Iterable, Optional +from typing import Dict, Iterable, Optional import numpy as np import torch import torch.nn.functional as F -from torch import Tensor -from torch import nn +from torch import Tensor, nn -from .decoding import detect_language as detect_language_function, decode as decode_function +from .decoding import decode as decode_function +from .decoding import detect_language as detect_language_function from .transcribe import transcribe as transcribe_function @@ -36,12 +35,16 @@ def forward(self, x: Tensor) -> Tensor: class Linear(nn.Linear): def forward(self, x: Tensor) -> Tensor: return F.linear( - x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype) + x, + self.weight.to(x.dtype), + None if self.bias is None else self.bias.to(x.dtype), ) class Conv1d(nn.Conv1d): - def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: + def _conv_forward( + self, x: Tensor, weight: Tensor, bias: Optional[Tensor] + ) -> Tensor: return super()._conv_forward( x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) ) @@ -87,7 +90,9 @@ def forward( wv, qk = self.qkv_attention(q, k, v, mask) return self.out(wv), qk - def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None): + def qkv_attention( + self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None + ): n_batch, n_ctx, n_state = q.shape scale = (n_state // self.n_head) ** -0.25 q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale @@ -110,11 +115,15 @@ def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): self.attn = MultiHeadAttention(n_state, n_head) self.attn_ln = LayerNorm(n_state) - self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None + self.cross_attn = ( + MultiHeadAttention(n_state, n_head) if cross_attention else None + ) self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None n_mlp = n_state * 4 - self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)) + self.mlp = nn.Sequential( + Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state) + ) self.mlp_ln = LayerNorm(n_state) def forward( @@ -132,7 +141,9 @@ def forward( class AudioEncoder(nn.Module): - def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): + def __init__( + self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int + ): super().__init__() self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) @@ -163,14 +174,19 @@ def forward(self, x: Tensor): class TextDecoder(nn.Module): - def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): + def __init__( + self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int + ): super().__init__() self.token_embedding = nn.Embedding(n_vocab, n_state) self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( - [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)] + [ + ResidualAttentionBlock(n_state, n_head, cross_attention=True) + for _ in range(n_layer) + ] ) self.ln = LayerNorm(n_state) @@ -185,14 +201,19 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): the encoded audio features to be attended on """ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 - x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]] + x = ( + self.token_embedding(x) + + self.positional_embedding[offset : offset + x.shape[-1]] + ) x = x.to(xa.dtype) for block in self.blocks: x = block(x, xa, mask=self.mask, kv_cache=kv_cache) x = self.ln(x) - logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float() + logits = ( + x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) + ).float() return logits @@ -216,13 +237,19 @@ def __init__(self, dims: ModelDimensions): self.dims.n_text_layer, ) # use the last half layers for alignment by default; see `set_alignment_heads()` below - all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool) - all_heads[self.dims.n_text_layer // 2:] = True + all_heads = torch.zeros( + self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool + ) + all_heads[self.dims.n_text_layer // 2 :] = True self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False) def set_alignment_heads(self, dump: bytes): - array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy() - mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head) + array = np.frombuffer( + gzip.decompress(base64.b85decode(dump)), dtype=bool + ).copy() + mask = torch.from_numpy(array).reshape( + self.dims.n_text_layer, self.dims.n_text_head + ) self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False) def embed_audio(self, mel: torch.Tensor): @@ -231,7 +258,9 @@ def embed_audio(self, mel: torch.Tensor): def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): return self.decoder(tokens, audio_features) - def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]: + def forward( + self, mel: torch.Tensor, tokens: torch.Tensor + ) -> Dict[str, torch.Tensor]: return self.decoder(tokens, self.encoder(mel)) @property @@ -260,8 +289,9 @@ def install_kv_cache_hooks(self, cache: Optional[dict] = None): hooks = [] def save_to_cache(module, _, output): - if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]: - cache[module] = output # save as-is, for the first token or cross attention + if module not in cache or output.shape[1] > self.dims.n_text_ctx: + # save as-is, for the first token or cross attention + cache[module] = output else: cache[module] = torch.cat([cache[module], output], dim=1).detach() return cache[module] diff --git a/whisper/normalizers/__init__.py b/whisper/normalizers/__init__.py index 0b10d5a9d..896d5e336 100644 --- a/whisper/normalizers/__init__.py +++ b/whisper/normalizers/__init__.py @@ -1,2 +1,2 @@ -from .basic import BasicTextNormalizer -from .english import EnglishTextNormalizer +from .basic import BasicTextNormalizer as BasicTextNormalizer +from .english import EnglishTextNormalizer as EnglishTextNormalizer diff --git a/whisper/normalizers/basic.py b/whisper/normalizers/basic.py index ef8d249bf..a82403203 100644 --- a/whisper/normalizers/basic.py +++ b/whisper/normalizers/basic.py @@ -48,13 +48,16 @@ def remove_symbols(s: str): Replace any other markers, symbols, punctuations with a space, keeping diacritics """ return "".join( - " " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s) + " " if unicodedata.category(c)[0] in "MSP" else c + for c in unicodedata.normalize("NFKC", s) ) class BasicTextNormalizer: def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): - self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols + self.clean = ( + remove_symbols_and_diacritics if remove_diacritics else remove_symbols + ) self.split_letters = split_letters def __call__(self, s: str): @@ -66,6 +69,8 @@ def __call__(self, s: str): if self.split_letters: s = " ".join(regex.findall(r"\X", s, regex.U)) - s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space + s = re.sub( + r"\s+", " ", s + ) # replace any successive whitespace characters with a space return s diff --git a/whisper/normalizers/english.py b/whisper/normalizers/english.py index d5c2bb4eb..4932042bc 100644 --- a/whisper/normalizers/english.py +++ b/whisper/normalizers/english.py @@ -84,7 +84,8 @@ def __init__(self): name.replace("y", "ies"): (value, "s") for name, value in self.tens.items() } self.tens_ordinal = { - name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items() + name.replace("y", "ieth"): (value, "th") + for name, value in self.tens.items() } self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal} @@ -108,7 +109,10 @@ def __init__(self): self.multipliers_ordinal = { name + "th": (value, "th") for name, value in self.multipliers.items() } - self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal} + self.multipliers_suffixed = { + **self.multipliers_plural, + **self.multipliers_ordinal, + } self.decimals = {*self.ones, *self.tens, *self.zeros} self.preceding_prefixers = { @@ -128,7 +132,8 @@ def __init__(self): "cents": "¢", } self.prefixes = set( - list(self.preceding_prefixers.values()) + list(self.following_prefixers.values()) + list(self.preceding_prefixers.values()) + + list(self.following_prefixers.values()) ) self.suffixers = { "per": {"cent": "%"}, @@ -218,7 +223,9 @@ def output(result: Union[str, int]): if value is None: value = ones elif isinstance(value, str) or prev in self.ones: - if prev in self.tens and ones < 10: # replace the last zero with the digit + if ( + prev in self.tens and ones < 10 + ): # replace the last zero with the digit assert value[-1] == "0" value = value[:-1] + str(ones) else: @@ -522,14 +529,14 @@ def __call__(self, s: str): s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis s = re.sub(self.ignore_patterns, "", s) - s = re.sub(r"\s+'", "'", s) # standardize when there's a space before an apostrophe + s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe for pattern, replacement in self.replacers.items(): s = re.sub(pattern, replacement, s) s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers - s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep some symbols for numerics + s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols s = self.standardize_numbers(s) s = self.standardize_spellings(s) @@ -538,6 +545,6 @@ def __call__(self, s: str): s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s) s = re.sub(r"([^0-9])%", r"\1 ", s) - s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space + s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space return s diff --git a/whisper/timing.py b/whisper/timing.py index 98927aa04..39f38728a 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -1,7 +1,7 @@ import subprocess import warnings from dataclasses import dataclass -from typing import List, TYPE_CHECKING +from typing import TYPE_CHECKING, List import numba import numpy as np @@ -26,13 +26,16 @@ def median_filter(x: torch.Tensor, filter_width: int): # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D x = x[None, None, :] - assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number" + assert ( + filter_width > 0 and filter_width % 2 == 1 + ), "`filter_width` should be an odd number" result = None x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect") if x.is_cuda: try: from .triton_ops import median_filter_cuda + result = median_filter_cuda(x, filter_width) except (RuntimeError, subprocess.CalledProcessError): warnings.warn( @@ -49,6 +52,7 @@ def median_filter(x: torch.Tensor, filter_width: int): return result + @numba.jit def backtrace(trace: np.ndarray): i = trace.shape[0] - 1 @@ -106,7 +110,9 @@ def dtw_cuda(x, BLOCK_SIZE=1024): M, N = x.shape assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}" - x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M) + x_skew = ( + F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M) + ) x_skew = x_skew.T.contiguous() cost = torch.ones(N + M + 2, M + 2) * np.inf cost[0, 0] = 0 @@ -122,10 +128,12 @@ def dtw_cuda(x, BLOCK_SIZE=1024): trace.stride(0), N, M, - BLOCK_SIZE=BLOCK_SIZE + BLOCK_SIZE=BLOCK_SIZE, ) - trace = trace.T.flatten()[:(M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, :N + 1] + trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[ + :, : N + 1 + ] return backtrace(trace.cpu().numpy()) @@ -181,8 +189,10 @@ def find_alignment( with torch.no_grad(): logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0] - token_probs = logits[len(tokenizer.sot_sequence):, :tokenizer.eot].softmax(dim=-1) - text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist() + sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot] + token_probs = sampled_logits.softmax(dim=-1) + text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens] + text_token_probs = text_token_probs.tolist() for hook in hooks: hook.remove() @@ -196,7 +206,7 @@ def find_alignment( weights = median_filter(weights, medfilt_width) matrix = weights.mean(axis=0) - matrix = matrix[len(tokenizer.sot_sequence):-1] + matrix = matrix[len(tokenizer.sot_sequence) : -1] text_indices, time_indices = dtw(-matrix) words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot]) @@ -207,7 +217,8 @@ def find_alignment( start_times = jump_times[word_boundaries[:-1]] end_times = jump_times[word_boundaries[1:]] word_probabilities = [ - np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) + np.mean(text_token_probs[i:j]) + for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) ] # hack: ensure the first and second word is not longer than twice the median word duration. @@ -218,7 +229,8 @@ def find_alignment( median_duration = np.median(word_durations) max_duration = median_duration * 2 if len(word_durations) >= 2 and word_durations[1] > max_duration: - end_times[0] = start_times[1] = max(end_times[2] / 2, end_times[2] - max_duration) + boundary = max(end_times[2] / 2, end_times[2] - max_duration) + end_times[0] = start_times[1] = boundary if len(word_durations) >= 1 and end_times[0] - start_times[0] > max_duration: start_times[0] = max(0, end_times[0] - max_duration) @@ -271,19 +283,20 @@ def add_word_timestamps( tokenizer: Tokenizer, mel: torch.Tensor, num_frames: int, - prepend_punctuations: str = "\"\'“¿([{-", - append_punctuations: str = "\"\'.。,,!!??::”)]}、", - **hyperparams, + prepend_punctuations: str = "\"'“¿([{-", + append_punctuations: str = "\"'.。,,!!??::”)]}、", + **kwargs, ): if len(segments) == 0: return text_tokens = [t for segment in segments for t in segment["tokens"]] - alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **hyperparams) + alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs) merge_punctuations(alignment, prepend_punctuations, append_punctuations) time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE - token_sources = np.repeat(np.arange(len(segments)), [len(s["tokens"]) for s in segments]) + segment_lengths = [len(s["tokens"]) for s in segments] + token_sources = np.repeat(np.arange(len(segments)), segment_lengths) for segment in segments: segment["words"] = [] @@ -295,7 +308,12 @@ def add_word_timestamps( start = round(time_offset + timing.start, 2) end = round(time_offset + timing.end, 2) segment["words"].append( - dict(word=timing.word, start=start, end=end, probability=timing.probability) + dict( + word=timing.word, + start=start, + end=end, + probability=timing.probability, + ) ) for segment in segments: diff --git a/whisper/tokenizer.py b/whisper/tokenizer.py index ea1117f76..cacb18d3c 100644 --- a/whisper/tokenizer.py +++ b/whisper/tokenizer.py @@ -1,7 +1,7 @@ import os import string from dataclasses import dataclass -from functools import lru_cache, cached_property +from functools import cached_property, lru_cache from typing import List, Optional, Tuple, Union import numpy as np @@ -138,7 +138,9 @@ class Tokenizer: def encode(self, text, **kwargs): return self.tokenizer.encode(text, **kwargs) - def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs): + def decode( + self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs + ): return self.tokenizer.decode(token_ids, **kwargs) def decode_with_timestamps(self, tokens) -> str: @@ -154,8 +156,9 @@ def decode_with_timestamps(self, tokens) -> str: outputs.append([]) else: outputs[-1].append(token) - outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs] - return "".join(outputs) + return "".join( + [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs] + ) @cached_property def eot(self) -> int: @@ -197,7 +200,7 @@ def timestamp_begin(self) -> int: def language_token(self) -> int: """Returns the token id corresponding to the value of the `language` field""" if self.language is None: - raise ValueError(f"This tokenizer does not have language token configured") + raise ValueError("This tokenizer does not have language token configured") additional_tokens = dict( zip( @@ -242,8 +245,10 @@ def non_speech_tokens(self) -> Tuple[int]: keeping basic punctuations like commas, periods, question marks, exclamation points, etc. """ - symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』") - symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split() + symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』') + symbols += ( + "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split() + ) # symbols that may be a single token or multiple tokens depending on the tokenizer. # In case they're multiple tokens, suppress the first token, which is safe because: @@ -255,7 +260,10 @@ def non_speech_tokens(self) -> Tuple[int]: # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]} for symbol in symbols + list(miscellaneous): - for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]: + for tokens in [ + self.tokenizer.encode(symbol), + self.tokenizer.encode(" " + symbol), + ]: if len(tokens) == 1 or symbol in miscellaneous: result.add(tokens[0]) @@ -367,4 +375,6 @@ def get_tokenizer( if task is not None: sot_sequence.append(transcribe if task == "transcribe" else translate) - return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)) + return Tokenizer( + tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence) + ) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index d7c048749..20f01477e 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -1,17 +1,32 @@ import argparse import os import warnings -from typing import Optional, Tuple, Union, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Tuple, Union import numpy as np import torch import tqdm -from .audio import HOP_LENGTH, N_FRAMES, SAMPLE_RATE, FRAMES_PER_SECOND, log_mel_spectrogram, pad_or_trim +from .audio import ( + FRAMES_PER_SECOND, + HOP_LENGTH, + N_FRAMES, + SAMPLE_RATE, + log_mel_spectrogram, + pad_or_trim, +) from .decoding import DecodingOptions, DecodingResult from .timing import add_word_timestamps from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer -from .utils import exact_div, format_timestamp, make_safe, optional_int, optional_float, str2bool, get_writer +from .utils import ( + exact_div, + format_timestamp, + get_writer, + make_safe, + optional_float, + optional_int, + str2bool, +) if TYPE_CHECKING: from .model import Whisper @@ -29,8 +44,8 @@ def transcribe( condition_on_previous_text: bool = True, initial_prompt: Optional[str] = None, word_timestamps: bool = False, - prepend_punctuations: str = "\"\'“¿([{-", - append_punctuations: str = "\"\'.。,,!!??::”)]}、", + prepend_punctuations: str = "\"'“¿([{-", + append_punctuations: str = "\"'.。,,!!??::”)]}、", **decode_options, ): """ @@ -108,12 +123,16 @@ def transcribe( decode_options["language"] = "en" else: if verbose: - print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language") + print( + "Detecting language using up to the first 30 seconds. Use `--language` to specify the language" + ) mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype) _, probs = model.detect_language(mel_segment) decode_options["language"] = max(probs, key=probs.get) if verbose is not None: - print(f"Detected language: {LANGUAGES[decode_options['language']].title()}") + print( + f"Detected language: {LANGUAGES[decode_options['language']].title()}" + ) language: str = decode_options["language"] task: str = decode_options.get("task", "transcribe") @@ -123,7 +142,9 @@ def transcribe( warnings.warn("Word-level timestamps on translations may not be reliable.") def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: - temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature + temperatures = ( + [temperature] if isinstance(temperature, (int, float)) else temperature + ) decode_result = None for t in temperatures: @@ -140,9 +161,15 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: decode_result = model.decode(segment, options) needs_fallback = False - if compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold: + if ( + compression_ratio_threshold is not None + and decode_result.compression_ratio > compression_ratio_threshold + ): needs_fallback = True # too repetitive - if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold: + if ( + logprob_threshold is not None + and decode_result.avg_logprob < logprob_threshold + ): needs_fallback = True # average log probability is too low if not needs_fallback: @@ -186,7 +213,9 @@ def new_segment( # show the progress bar when verbose is False (if True, transcribed text will be printed) num_frames = mel.shape[-1] - with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar: + with tqdm.tqdm( + total=num_frames, unit="frames", disable=verbose is not False + ) as pbar: while seek < num_frames: time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) mel_segment = mel[:, seek:] @@ -201,7 +230,10 @@ def new_segment( if no_speech_threshold is not None: # no voice activity check should_skip = result.no_speech_prob > no_speech_threshold - if logprob_threshold is not None and result.avg_logprob > logprob_threshold: + if ( + logprob_threshold is not None + and result.avg_logprob > logprob_threshold + ): # don't skip if the logprob is high enough, despite the no_speech_prob should_skip = False @@ -214,22 +246,35 @@ def new_segment( current_tokens = [] timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin) - consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1) - if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens - if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [False, True]: + consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[ + 0 + ].add_(1) + if ( + len(consecutive) > 0 + ): # if the output contains two consecutive timestamp tokens + if ended_with_single_timestamp := timestamp_tokens[-2:].tolist() == [ + False, + True, + ]: consecutive = consecutive.tolist() + [len(tokens)] last_slice = 0 for current_slice in consecutive: sliced_tokens = tokens[last_slice:current_slice] - start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin - end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin - current_segments.append(new_segment( - start=time_offset + start_timestamp_pos * time_precision, - end=time_offset + end_timestamp_pos * time_precision, - tokens=sliced_tokens, - result=result, - )) + start_timestamp_pos = ( + sliced_tokens[0].item() - tokenizer.timestamp_begin + ) + end_timestamp_pos = ( + sliced_tokens[-1].item() - tokenizer.timestamp_begin + ) + current_segments.append( + new_segment( + start=time_offset + start_timestamp_pos * time_precision, + end=time_offset + end_timestamp_pos * time_precision, + tokens=sliced_tokens, + result=result, + ) + ) current_tokens.append(sliced_tokens.tolist()) last_slice = current_slice @@ -238,23 +283,32 @@ def new_segment( seek += segment_size else: # otherwise, ignore the unfinished segment and seek to the last timestamp - last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin + last_timestamp_pos = ( + tokens[last_slice - 1].item() - tokenizer.timestamp_begin + ) seek += last_timestamp_pos * input_stride all_tokens.extend(tokens[: last_slice + 1].tolist()) else: duration = segment_duration timestamps = tokens[timestamp_tokens.nonzero().flatten()] - if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin: + if ( + len(timestamps) > 0 + and timestamps[-1].item() != tokenizer.timestamp_begin + ): # no consecutive timestamps but it has a timestamp; use the last one. - last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin + last_timestamp_pos = ( + timestamps[-1].item() - tokenizer.timestamp_begin + ) duration = last_timestamp_pos * time_precision - current_segments.append(new_segment( - start=time_offset, - end=time_offset + duration, - tokens=tokens, - result=result, - )) + current_segments.append( + new_segment( + start=time_offset, + end=time_offset + duration, + tokens=tokens, + result=result, + ) + ) current_tokens.append(tokens.tolist()) seek += segment_size @@ -272,9 +326,13 @@ def new_segment( prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, ) - word_end_timestamps = [w["end"] for s in current_segments for w in s["words"]] + word_end_timestamps = [ + w["end"] for s in current_segments for w in s["words"] + ] if len(consecutive) > 0 and len(word_end_timestamps) > 0: - seek_shift = round((word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND) + seek_shift = round( + (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND + ) if seek_shift > 0: seek = previous_seek + seek_shift @@ -293,21 +351,24 @@ def new_segment( current_tokens[i] = [] all_segments.extend(current_segments) - all_tokens.extend([token for segment in current_tokens for token in segment]) + all_tokens.extend( + [token for segment in current_tokens for token in segment] + ) # update progress bar pbar.update(min(num_frames, seek) - previous_seek) return dict( - text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]), + text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), segments=all_segments, - language=language + language=language, ) def cli(): from . import available_models + # fmt: off parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") @@ -339,6 +400,7 @@ def cli(): parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word") parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word") parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") + # fmt: on args = parser.parse_args().__dict__ model_name: str = args.pop("model") @@ -350,7 +412,9 @@ def cli(): if model_name.endswith(".en") and args["language"] not in {"en", "English"}: if args["language"] is not None: - warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.") + warnings.warn( + f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead." + ) args["language"] = "en" temperature = args.pop("temperature") @@ -363,6 +427,7 @@ def cli(): torch.set_num_threads(threads) from . import load_model + model = load_model(model_name, device=device, download_root=model_dir) writer = get_writer(output_format, output_dir) @@ -371,5 +436,5 @@ def cli(): writer(result, audio_path) -if __name__ == '__main__': +if __name__ == "__main__": cli() diff --git a/whisper/triton_ops.py b/whisper/triton_ops.py index d829e204a..edd456414 100644 --- a/whisper/triton_ops.py +++ b/whisper/triton_ops.py @@ -1,8 +1,7 @@ -import math +from functools import lru_cache import numpy as np import torch -from functools import lru_cache try: import triton @@ -12,7 +11,9 @@ @triton.jit -def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr): +def dtw_kernel( + cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr +): offsets = tl.arange(0, BLOCK_SIZE) mask = offsets < M @@ -42,37 +43,53 @@ def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_ @lru_cache(maxsize=None) def median_kernel(filter_width: int): @triton.jit - def kernel(y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr): # x.shape[-1] == filter_width + def kernel( + y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr + ): # x.shape[-1] == filter_width row_idx = tl.program_id(0) offsets = tl.arange(0, BLOCK_SIZE) mask = offsets < y_stride - x_ptr = x + row_idx * x_stride + x_ptr = x + row_idx * x_stride # noqa: F841 y_ptr = y + row_idx * y_stride - LOAD_ALL_ROWS_HERE + LOAD_ALL_ROWS_HERE # noqa: F821 - BUBBLESORT_HERE + BUBBLESORT_HERE # noqa: F821 - tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) + tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821 kernel = triton.JITFunction(kernel.fn) - kernel.src = kernel.src.replace(" LOAD_ALL_ROWS_HERE", "\n".join([ - f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)" - for i in range(filter_width) - ])) - kernel.src = kernel.src.replace(" BUBBLESORT_HERE", "\n\n".join([ - "\n\n".join([ - "\n".join([ - f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})", - f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})", - f" row{j} = smaller", - f" row{j + 1} = larger", - ]) - for j in range(filter_width - i - 1) - ]) - for i in range(filter_width // 2 + 1) - ])) + kernel.src = kernel.src.replace( + " LOAD_ALL_ROWS_HERE", + "\n".join( + [ + f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)" + for i in range(filter_width) + ] + ), + ) + kernel.src = kernel.src.replace( + " BUBBLESORT_HERE", + "\n\n".join( + [ + "\n\n".join( + [ + "\n".join( + [ + f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})", + f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})", + f" row{j} = smaller", + f" row{j + 1} = larger", + ] + ) + for j in range(filter_width - i - 1) + ] + ) + for i in range(filter_width // 2 + 1) + ] + ), + ) kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}") return kernel diff --git a/whisper/utils.py b/whisper/utils.py index 8ee912932..7712d3131 100644 --- a/whisper/utils.py +++ b/whisper/utils.py @@ -7,11 +7,14 @@ system_encoding = sys.getdefaultencoding() if system_encoding != "utf-8": + def make_safe(string): # replaces any character not representable using the system default encoding with an '?', # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729). return string.encode(system_encoding, errors="replace").decode(system_encoding) + else: + def make_safe(string): # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding return string @@ -43,7 +46,9 @@ def compression_ratio(text) -> float: return len(text_bytes) / len(zlib.compress(text_bytes)) -def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'): +def format_timestamp( + seconds: float, always_include_hours: bool = False, decimal_marker: str = "." +): assert seconds >= 0, "non-negative timestamp expected" milliseconds = round(seconds * 1000.0) @@ -57,7 +62,9 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal milliseconds -= seconds * 1_000 hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" - return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" + return ( + f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" + ) class ResultWriter: @@ -68,7 +75,9 @@ def __init__(self, output_dir: str): def __call__(self, result: dict, audio_path: str): audio_basename = os.path.basename(audio_path) - output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension) + output_path = os.path.join( + self.output_dir, audio_basename + "." + self.extension + ) with open(output_path, "w", encoding="utf-8") as f: self.write_result(result, file=f) @@ -82,7 +91,7 @@ class WriteTXT(ResultWriter): def write_result(self, result: dict, file: TextIO): for segment in result["segments"]: - print(segment['text'].strip(), file=file, flush=True) + print(segment["text"].strip(), file=file, flush=True) class SubtitlesWriter(ResultWriter): @@ -93,7 +102,7 @@ def iterate_result(self, result: dict): for segment in result["segments"]: segment_start = self.format_timestamp(segment["start"]) segment_end = self.format_timestamp(segment["end"]) - segment_text = segment['text'].strip().replace('-->', '->') + segment_text = segment["text"].strip().replace("-->", "->") if word_timings := segment.get("words", None): all_words = [timing["word"] for timing in word_timings] @@ -106,7 +115,10 @@ def iterate_result(self, result: dict): yield last, start, segment_text yield start, end, "".join( - [f"{word}" if j == i else word for j, word in enumerate(all_words)] + [ + f"{word}" if j == i else word + for j, word in enumerate(all_words) + ] ) last = end @@ -126,7 +138,7 @@ def format_timestamp(self, seconds: float): class WriteVTT(SubtitlesWriter): extension: str = "vtt" always_include_hours: bool = False - decimal_marker: str = '.' + decimal_marker: str = "." def write_result(self, result: dict, file: TextIO): print("WEBVTT\n", file=file) @@ -137,7 +149,7 @@ def write_result(self, result: dict, file: TextIO): class WriteSRT(SubtitlesWriter): extension: str = "srt" always_include_hours: bool = True - decimal_marker: str = ',' + decimal_marker: str = "," def write_result(self, result: dict, file: TextIO): for i, (start, end, text) in enumerate(self.iterate_result(result), start=1): @@ -153,14 +165,15 @@ class WriteTSV(ResultWriter): an environment setting a language encoding that causes the decimal in a floating point number to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++. """ + extension: str = "tsv" def write_result(self, result: dict, file: TextIO): print("start", "end", "text", sep="\t", file=file) for segment in result["segments"]: - print(round(1000 * segment['start']), file=file, end="\t") - print(round(1000 * segment['end']), file=file, end="\t") - print(segment['text'].strip().replace("\t", " "), file=file, flush=True) + print(round(1000 * segment["start"]), file=file, end="\t") + print(round(1000 * segment["end"]), file=file, end="\t") + print(segment["text"].strip().replace("\t", " "), file=file, flush=True) class WriteJSON(ResultWriter): @@ -189,4 +202,3 @@ def write_all(result: dict, file: TextIO): return write_all return writers[output_format](output_dir) -