diff --git a/speaker_diarization.txt b/requirements/speaker_diarization.txt similarity index 100% rename from speaker_diarization.txt rename to requirements/speaker_diarization.txt diff --git a/whisperplus/pipelines/mlx_whisper/assets/gpt2.tiktoken b/whisperplus/pipelines/assets/gpt2.tiktoken similarity index 100% rename from whisperplus/pipelines/mlx_whisper/assets/gpt2.tiktoken rename to whisperplus/pipelines/assets/gpt2.tiktoken diff --git a/whisperplus/pipelines/mlx_whisper/assets/mel_filters.npz b/whisperplus/pipelines/assets/mel_filters.npz similarity index 100% rename from whisperplus/pipelines/mlx_whisper/assets/mel_filters.npz rename to whisperplus/pipelines/assets/mel_filters.npz diff --git a/whisperplus/pipelines/mlx_whisper/assets/multilingual.tiktoken b/whisperplus/pipelines/assets/multilingual.tiktoken similarity index 100% rename from whisperplus/pipelines/mlx_whisper/assets/multilingual.tiktoken rename to whisperplus/pipelines/assets/multilingual.tiktoken diff --git a/whisperplus/pipelines/lightning_whisper_mlx/__init__.py b/whisperplus/pipelines/lightning_whisper_mlx/__init__.py new file mode 100644 index 0000000..e84f556 --- /dev/null +++ b/whisperplus/pipelines/lightning_whisper_mlx/__init__.py @@ -0,0 +1 @@ +from .lightning import LightningWhisperMLX diff --git a/whisperplus/pipelines/lightning_whisper_mlx/audio.py b/whisperplus/pipelines/lightning_whisper_mlx/audio.py new file mode 100644 index 0000000..71f03c8 --- /dev/null +++ b/whisperplus/pipelines/lightning_whisper_mlx/audio.py @@ -0,0 +1,169 @@ +# Copyright © 2023 Apple Inc. + +import os +from functools import lru_cache +from subprocess import CalledProcessError, run +from typing import Optional, Union + +import mlx.core as mx +import numpy as np + +# hard-coded audio hyperparameters +SAMPLE_RATE = 16000 +N_FFT = 400 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE +N_FRAMES = N_SAMPLES // HOP_LENGTH + +N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 +FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame +TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token + + +def load_audio(file: str, sr: int = SAMPLE_RATE): + """ + Open an audio file and read as mono waveform, resampling as necessary. + + Parameters + ---------- + file: str + The audio file to open + + sr: int + The sample rate to resample the audio if necessary + + Returns + ------- + A NumPy array containing the audio waveform, in float32 dtype. + """ + + # This launches a subprocess to decode audio while down-mixing + # and resampling as necessary. Requires the ffmpeg CLI in PATH. 
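+    # ffmpeg writes 16-bit little-endian PCM (mono, resampled to `sr`) to stdout;
+    # below it is read with NumPy and scaled to float32 in [-1.0, 1.0].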
+ # fmt: off + cmd = [ + "ffmpeg", + "-nostdin", + "-threads", "0", + "-i", file, + "-f", "s16le", + "-ac", "1", + "-acodec", "pcm_s16le", + "-ar", str(sr), + "-" + ] + # fmt: on + try: + out = run(cmd, capture_output=True, check=True).stdout + except CalledProcessError as e: + raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + + return mx.array(np.frombuffer(out, np.int16)).flatten().astype(mx.float32) / 32768.0 + + +def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): + """Pad or trim the audio array to N_SAMPLES, as expected by the encoder.""" + if array.shape[axis] > length: + sl = [slice(None)] * array.ndim + sl[axis] = slice(0, length) + array = array[tuple(sl)] + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = mx.pad(array, pad_widths) + + return array + + +@lru_cache(maxsize=None) +def mel_filters(n_mels: int) -> mx.array: + """ + Load the mel filterbank matrix for projecting STFT into a Mel spectrogram. Allows decoupling librosa + dependency; saved using: + + np.savez_compressed( "mel_filters.npz", mel_80=librosa.filters.mel(sr=16000, n_fft=400, + n_mels=80), mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), ) + """ + assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" + + filename = os.path.join(os.path.dirname(__file__), "../assets", "mel_filters.npz") + return mx.load(filename)[f"mel_{n_mels}"] + + +@lru_cache(maxsize=None) +def hanning(size): + return mx.array(np.hanning(size + 1)[:-1]) + + +def stft(x, window, nperseg=256, noverlap=None, nfft=None, axis=-1, pad_mode="reflect"): + if nfft is None: + nfft = nperseg + if noverlap is None: + noverlap = nfft // 4 + + def _pad(x, padding, pad_mode="constant"): + if pad_mode == "constant": + return mx.pad(x, [(padding, padding)]) + elif pad_mode == "reflect": + prefix = x[1:padding + 1][::-1] + suffix = x[-(padding + 1):-1][::-1] + return mx.concatenate([prefix, x, suffix]) + else: + raise ValueError(f"Invalid pad_mode {pad_mode}") + + padding = nperseg // 2 + x = _pad(x, padding, pad_mode) + + strides = [noverlap, 1] + t = (x.size - nperseg + noverlap) // noverlap + shape = [t, nfft] + x = mx.as_strided(x, shape=shape, strides=strides) + return mx.fft.rfft(x * window) + + +def log_mel_spectrogram( + audio: Union[str, np.ndarray], + n_mels: int = 80, + padding: int = 0, +): + """ + Compute the log-Mel spectrogram of. 
+ + Parameters + ---------- + audio: Union[str, np.ndarray, mx.array], shape = (*) + The path to audio or either a NumPy or mlx array containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 is supported + + padding: int + Number of zero samples to pad to the right + + Returns + ------- + mx.array, shape = (80, n_frames) + An array that contains the Mel spectrogram + """ + device = mx.default_device() + mx.set_default_device(mx.cpu) + if isinstance(audio, str): + audio = load_audio(audio) + elif not isinstance(audio, mx.array): + audio = mx.array(audio) + + if padding > 0: + audio = mx.pad(audio, (0, padding)) + window = hanning(N_FFT) + freqs = stft(audio, window, nperseg=N_FFT, noverlap=HOP_LENGTH) + magnitudes = freqs[:-1, :].abs().square() + + filters = mel_filters(n_mels) + mel_spec = magnitudes @ filters.T + + log_spec = mx.maximum(mel_spec, 1e-10).log10() + log_spec = mx.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + mx.set_default_device(device) + return log_spec diff --git a/whisperplus/pipelines/lightning_whisper_mlx/decoding.py b/whisperplus/pipelines/lightning_whisper_mlx/decoding.py new file mode 100644 index 0000000..411c43d --- /dev/null +++ b/whisperplus/pipelines/lightning_whisper_mlx/decoding.py @@ -0,0 +1,661 @@ +# Copyright © 2023 Apple Inc. + +import zlib +from dataclasses import dataclass, field, replace +from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from mlx.utils import tree_map + +from .audio import CHUNK_LENGTH +from .tokenizer import Tokenizer, get_tokenizer + + +def compression_ratio(text) -> float: + text_bytes = text.encode("utf-8") + return len(text_bytes) / len(zlib.compress(text_bytes)) + + +def detect_language(model, mel: mx.array, tokenizer: Tokenizer = None) -> Tuple[mx.array, List[dict]]: + """ + Detect the spoken language in the audio, and return them as list of strings, along with the ids of the + most probable language tokens and the probability distribution over all language tokens. This is performed + outside the main decode loop in order to not interfere with kv-caching. + + Returns + ------- + language_tokens : mx.array, shape = (n_audio,) + ids of the most probable language tokens, which appears after the startoftranscript token. + language_probs : List[Dict[str, float]], length = n_audio + list of dictionaries containing the probability distribution over all languages. 
+ """ + if tokenizer is None: + tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages) + if (tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence): + raise ValueError("This model doesn't have language tokens so it can't perform lang id") + + single = mel.ndim == 2 + if single: + mel = mel[None] + + # skip encoder forward pass if already-encoded audio features were given + if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state): + mel = model.encoder(mel) + + # forward pass using a single token, startoftranscript + n_audio = mel.shape[0] + x = mx.array([[tokenizer.sot]] * n_audio) # [n_audio, 1] + logits = model.logits(x, mel)[:, 0] + + # collect detected languages; suppress all non-language tokens + mask = np.full(logits.shape[-1], -np.inf, dtype=np.float32) + mask[list(tokenizer.all_language_tokens)] = 0.0 + logits += mx.array(mask) + language_tokens = mx.argmax(logits, axis=-1) + language_token_probs = mx.softmax(logits, axis=-1) + language_probs = [{ + c: language_token_probs[i, j].item() + for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes) + } for i in range(n_audio)] + + if single: + language_tokens = language_tokens[0] + language_probs = language_probs[0] + + return language_tokens, language_probs + + +@dataclass(frozen=True) +class DecodingOptions: + # whether to perform X->X "transcribe" or X->English "translate" + task: str = "transcribe" + + # language that the audio is in; uses detected language if None + language: Optional[str] = None + + # sampling-related options + temperature: float = 0.0 + sample_len: Optional[int] = None # maximum number of tokens to sample + best_of: Optional[int] = None # number of independent sample trajectories, if t > 0 + beam_size: Optional[int] = None # number of beams in beam search, if t == 0 + patience: Optional[float] = None # patience in beam search (arxiv:2204.05424) + + # "alpha" in Google NMT, or None for length norm, when ranking generations + # to select which to return among the beams or best-of-N samples + length_penalty: Optional[float] = None + + # text or tokens to feed as the prompt or the prefix; for more info: + # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051 + prompt: Optional[Union[str, List[int]]] = None # for the previous context + prefix: Optional[Union[str, List[int]]] = None # to prefix the current context + + # list of tokens ids (or comma-separated token ids) to suppress + # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()` + suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1" + suppress_blank: bool = True # this will suppress blank outputs + + # timestamp sampling options + without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only + max_initial_timestamp: Optional[float] = 1.0 + + # implementation details + fp16: bool = True # use fp16 for most of the calculation + + +@dataclass(frozen=True) +class DecodingResult: + audio_features: mx.array + language: str + language_probs: Optional[Dict[str, float]] = None + tokens: List[int] = field(default_factory=list) + text: str = "" + avg_logprob: float = np.nan + no_speech_prob: float = np.nan + temperature: float = np.nan + compression_ratio: float = np.nan + + +class Inference: + + def __init__(self, model, initial_token_length: int): + self.model = model + self.initial_token_length = initial_token_length + self.kv_cache = None + + def logits(self, tokens: mx.array, audio_features: 
mx.array) -> mx.array: + """Perform a forward pass on the decoder and return per-token logits.""" + if tokens.shape[-1] > self.initial_token_length: + # only need to use the last token except in the first forward pass + tokens = tokens[:, -1:] + + logits, self.kv_cache, _ = self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) + return logits.astype(mx.float32) + + def rearrange_kv_cache(self, source_indices): + """Update the key-value cache according to the updated beams.""" + # update the key/value cache to contain the selected sequences + if source_indices != list(range(len(source_indices))): + self.kv_cache = tree_map(lambda x: x[source_indices], self.kv_cache) + + def reset(self): + self.kv_cache = None + + +class SequenceRanker: + + def rank(self, tokens: List[List[mx.array]], sum_logprobs: List[List[float]]) -> List[int]: + """Given a list of groups of samples and their cumulative log probabilities, return the indices of the + samples in each group to select as the final result. + """ + raise NotImplementedError + + +class MaximumLikelihoodRanker(SequenceRanker): + """Select the sample with the highest log probabilities, penalized using either a simple length + normalization or Google NMT paper's length penalty. + """ + + def __init__(self, length_penalty: Optional[float]): + self.length_penalty = length_penalty + + def rank(self, tokens: List[List[List[int]]], sum_logprobs: List[List[float]]): + + def scores(logprobs, lengths): + result = [] + for logprob, length in zip(logprobs, lengths): + if self.length_penalty is None: + penalty = length + else: + # from the Google NMT paper + penalty = ((5 + length) / 6)**self.length_penalty + result.append(logprob / penalty) + return result + + # get the sequence with the highest score + lengths = [[len(t) for t in s] for s in tokens] + return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)] + + +class TokenDecoder: + + def reset(self): + """Initialize any stateful variables for decoding a new sequence.""" + + def update(self, tokens: mx.array, logits: mx.array, + sum_logprobs: mx.array) -> Tuple[mx.array, bool, mx.array]: + """ + Specify how to select the next token, based on the current trace and logits. + + Parameters + ---------- + tokens : mx.array, shape = (n_batch, current_sequence_length) + all tokens in the context so far, including the prefix and sot_sequence tokens + + logits : mx.array, shape = (n_batch, vocab_size) + per-token logits of the probability distribution at the current step + + sum_logprobs : mx.array, shape = (n_batch) + cumulative log probabilities for each sequence + + Returns + ------- + tokens : mx.array, shape = (n_batch, current_sequence_length + 1) + the tokens, appended with the selected next token + + completed : bool + True if all sequences has reached the end of text + + sum_logprobs: mx.array, shape = (n_batch) + updated cumulative log probabilities for each sequence + """ + raise NotImplementedError + + def finalize(self, tokens: mx.array, + sum_logprobs: mx.array) -> Tuple[Sequence[Sequence[mx.array]], List[List[float]]]: + """ + Finalize search and return the final candidate sequences. 
+ + Parameters + ---------- + tokens : mx.array, shape = (n_audio, n_group, current_sequence_length) + all tokens in the context so far, including the prefix and sot_sequence + + sum_logprobs : mx.array, shape = (n_audio, n_group) + cumulative log probabilities for each sequence + + Returns + ------- + tokens : Sequence[Sequence[mx.array]], length = n_audio + sequence of mx.arrays containing candidate token sequences, for each audio input + + sum_logprobs : List[List[float]], length = n_audio + sequence of cumulative log probabilities corresponding to the above + """ + raise NotImplementedError + + +class GreedyDecoder(TokenDecoder): + + def __init__(self, temperature: float, eot: int): + self.temperature = temperature + self.eot = eot + + def update(self, tokens: mx.array, logits: mx.array, + sum_logprobs: mx.array) -> Tuple[mx.array, bool, mx.array]: + if self.temperature == 0: + next_tokens = logits.argmax(axis=-1) + else: + next_tokens = mx.random.categorical(logits=logits / self.temperature) + + next_tokens = mx.argmax(logits, axis=-1) + logits = logits.astype(mx.float32) + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + + current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens] + sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot) + + eot_mask = tokens[:, -1] == self.eot + next_tokens = next_tokens * (1 - eot_mask) + self.eot * eot_mask + tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=-1) + + completed = mx.all(tokens[:, -1] == self.eot) + return tokens, completed, sum_logprobs + + def finalize(self, tokens: mx.array, sum_logprobs: mx.array): + # make sure each sequence has at least one EOT token at the end + tokens = mx.pad(tokens, [(0, 0), (0, 0), (0, 1)], constant_values=self.eot) + return tokens, sum_logprobs.tolist() + + +class LogitFilter: + + def apply(self, logits: mx.array, tokens: mx.array) -> mx.array: + """ + Apply any filtering or masking to logits. 
+ + Parameters + ---------- + logits : mx.array, shape = (n_batch, vocab_size) + per-token logits of the probability distribution at the current step + + tokens : mx.array, shape = (n_batch, current_sequence_length) + all tokens in the context so far, including the prefix and sot_sequence tokens + """ + raise NotImplementedError + + +class SuppressBlank(LogitFilter): + + def __init__(self, tokenizer: Tokenizer, sample_begin: int, n_vocab: int): + self.sample_begin = sample_begin + mask = np.zeros(n_vocab, np.float32) + mask[tokenizer.encode(" ") + [tokenizer.eot]] = -np.inf + self.mask = mx.array(mask) + + def apply(self, logits: mx.array, tokens: mx.array) -> mx.array: + if tokens.shape[1] == self.sample_begin: + return logits + self.mask + return logits + + +class SuppressTokens(LogitFilter): + + def __init__(self, suppress_tokens: Sequence[int], n_vocab: int): + mask = np.zeros(n_vocab, np.float32) + mask[list(suppress_tokens)] = -np.inf + self.mask = mx.array(mask) + + def apply(self, logits: mx.array, tokens: mx.array) -> mx.array: + return logits + self.mask + + +class ApplyTimestampRules(LogitFilter): + + def __init__( + self, + tokenizer: Tokenizer, + sample_begin: int, + max_initial_timestamp_index: Optional[int], + ): + self.tokenizer = tokenizer + self.sample_begin = sample_begin + self.max_initial_timestamp_index = max_initial_timestamp_index + + def apply(self, logits: mx.array, tokens: mx.array) -> mx.array: + mask = np.zeros(logits.shape, np.float32) + # suppress <|notimestamps|> which is handled by without_timestamps + if self.tokenizer.no_timestamps is not None: + mask[:, self.tokenizer.no_timestamps] = -np.inf + + # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly + for k in range(tokens.shape[0]): + sampled_tokens = tokens[k, self.sample_begin:] + seq = sampled_tokens.tolist() + last_was_timestamp = (len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin) + penultimate_was_timestamp = (len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin) + + if last_was_timestamp: + if penultimate_was_timestamp: # has to be non-timestamp + mask[k, self.tokenizer.timestamp_begin:] = -np.inf + else: # cannot be normal text tokens + mask[k, :self.tokenizer.eot] = -np.inf + + timestamps = [i for i, v in enumerate(seq) if v > self.tokenizer.timestamp_begin] + if len(timestamps) > 0: + # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last + # also force each segment to have a nonzero length, to prevent infinite looping + last_timestamp = timestamps[-1] + if not last_timestamp or penultimate_was_timestamp: + last_timestamp += 1 + mask[k, self.tokenizer.timestamp_begin:last_timestamp] = -np.inf + + if tokens.shape[1] == self.sample_begin: + # suppress generating non-timestamp tokens at the beginning + mask[:, :self.tokenizer.timestamp_begin] = -np.inf + + # apply the `max_initial_timestamp` option + if self.max_initial_timestamp_index is not None: + last_allowed = (self.tokenizer.timestamp_begin + self.max_initial_timestamp_index) + mask[:, last_allowed + 1:] = -np.inf + + # if sum of probability over timestamps is above any other token, sample timestamp + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + for k in range(tokens.shape[0]): + timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin:].logsumexp(axis=-1) + max_text_token_logprob = logprobs[k, :self.tokenizer.timestamp_begin].max() + if timestamp_logprob > max_text_token_logprob: + mask[k, :self.tokenizer.timestamp_begin] = -np.inf + + 
return logits + mx.array(mask, logits.dtype) + + +class DecodingTask: + inference: Inference + sequence_ranker: SequenceRanker + decoder: TokenDecoder + logit_filters: List[LogitFilter] + + def __init__(self, model, options: DecodingOptions): + self.model = model + + language = options.language or "en" + tokenizer = get_tokenizer( + model.is_multilingual, + num_languages=model.num_languages, + language=language, + task=options.task, + ) + self.tokenizer: Tokenizer = tokenizer + self.options: DecodingOptions = self._verify_options(options) + + self.n_group: int = options.beam_size or options.best_of or 1 + self.n_ctx: int = model.dims.n_text_ctx + self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2 + + self.sot_sequence: Tuple[int] = tokenizer.sot_sequence + if self.options.without_timestamps: + self.sot_sequence = tokenizer.sot_sequence_including_notimestamps + + self.initial_tokens: Tuple[int] = self._get_initial_tokens() + self.sample_begin: int = len(self.initial_tokens) + self.sot_index: int = self.initial_tokens.index(tokenizer.sot) + + # inference: implements the forward pass through the decoder, including kv caching + self.inference = Inference(model, len(self.initial_tokens)) + + # sequence ranker: implements how to rank a group of sampled sequences + self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty) + + # decoder: implements how to select the next tokens, given the autoregressive distribution + if options.beam_size is not None: + raise NotImplementedError("Beam search decoder is not yet implemented") + # self.decoder = BeamSearchDecoder( + # options.beam_size, tokenizer.eot, self.inference, options.patience + # ) + else: + self.decoder = GreedyDecoder(options.temperature, tokenizer.eot) + + # logit filters: applies various rules to suppress or penalize certain tokens + self.logit_filters = [] + if self.options.suppress_blank: + self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin, model.dims.n_vocab)) + if self.options.suppress_tokens: + self.logit_filters.append(SuppressTokens(self._get_suppress_tokens(), model.dims.n_vocab)) + if not options.without_timestamps: + precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds + max_initial_timestamp_index = None + if options.max_initial_timestamp: + max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision) + self.logit_filters.append( + ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)) + + def _verify_options(self, options: DecodingOptions) -> DecodingOptions: + if options.beam_size is not None and options.best_of is not None: + raise ValueError("beam_size and best_of can't be given together") + if options.temperature == 0: + if options.best_of is not None: + raise ValueError("best_of with greedy sampling (T=0) is not compatible") + if options.patience is not None and options.beam_size is None: + raise ValueError("patience requires beam_size to be given") + if options.length_penalty is not None and not (0 <= options.length_penalty <= 1): + raise ValueError("length_penalty (alpha) should be a value between 0 and 1") + + return options + + def _get_initial_tokens(self) -> Tuple[int]: + tokens = list(self.sot_sequence) + + if prefix := self.options.prefix: + prefix_tokens = ( + self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix) + if self.sample_len is not None: + max_prefix_len = self.n_ctx // 2 - self.sample_len + prefix_tokens = prefix_tokens[-max_prefix_len:] + tokens = 
tokens + prefix_tokens + + if prompt := self.options.prompt: + prompt_tokens = ( + self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt) + tokens = ([self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1):] + tokens) + + return tuple(tokens) + + def _get_suppress_tokens(self) -> Tuple[int]: + suppress_tokens = self.options.suppress_tokens + + if isinstance(suppress_tokens, str): + suppress_tokens = [int(t) for t in suppress_tokens.split(",")] + + if -1 in suppress_tokens: + suppress_tokens = [t for t in suppress_tokens if t >= 0] + suppress_tokens.extend(self.tokenizer.non_speech_tokens) + elif suppress_tokens is None or len(suppress_tokens) == 0: + suppress_tokens = [] # interpret empty string as an empty list + else: + assert isinstance(suppress_tokens, list), "suppress_tokens must be a list" + + suppress_tokens.extend([ + self.tokenizer.transcribe, + self.tokenizer.translate, + self.tokenizer.sot, + self.tokenizer.sot_prev, + self.tokenizer.sot_lm, + ]) + if self.tokenizer.no_speech is not None: + # no-speech probability is collected separately + suppress_tokens.append(self.tokenizer.no_speech) + + return tuple(sorted(set(suppress_tokens))) + + def _get_audio_features(self, mel: mx.array): + if self.options.fp16: + mel = mel.astype(mx.float16) + + if mel.shape[-2:] == ( + self.model.dims.n_audio_ctx, + self.model.dims.n_audio_state, + ): + # encoded audio features are given; skip audio encoding + audio_features = mel + else: + audio_features = self.model.encoder(mel) + + if audio_features.dtype != (mx.float16 if self.options.fp16 else mx.float32): + raise TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}") + + return audio_features + + def _detect_language(self, audio_features: mx.array, tokens: np.array): + languages = [self.options.language] * audio_features.shape[0] + lang_probs = None + + if self.options.language is None or self.options.task == "lang_id": + lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer) + languages = [max(probs, key=probs.get) for probs in lang_probs] + if self.options.language is None: + # write language tokens + tokens[:, self.sot_index + 1] = np.array(lang_tokens) + + return languages, lang_probs + + def _main_loop(self, audio_features: mx.array, tokens: mx.array): + n_batch = tokens.shape[0] + sum_logprobs: mx.array = mx.zeros(n_batch) + no_speech_probs = [np.nan] * n_batch + + try: + for i in range(self.sample_len): + logits = self.inference.logits(tokens, audio_features) + + if (i == 0 and self.tokenizer.no_speech is not None): # save no_speech_probs + probs_at_sot = mx.softmax(logits[:, self.sot_index].astype(mx.float32), axis=-1) + no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist() + + # now we need to consider the logits at the last token only + logits = logits[:, -1] + + # apply the logit filters, e.g. 
for suppressing or applying penalty to + for logit_filter in self.logit_filters: + logits = logit_filter.apply(logits, tokens) + + # expand the tokens tensor with the selected next tokens + tokens, completed, sum_logprobs = self.decoder.update(tokens, logits, sum_logprobs) + + if completed or tokens.shape[-1] > self.n_ctx: + break + finally: + self.inference.reset() + + return tokens, sum_logprobs, no_speech_probs + + def run(self, mel: mx.array) -> List[DecodingResult]: + self.decoder.reset() + tokenizer: Tokenizer = self.tokenizer + n_audio: int = mel.shape[0] + + mel = mx.repeat(mel, repeats=1, axis=0) + n_audio *= 1 + + audio_features: mx.array = self._get_audio_features(mel) # encoder forward pass + tokens: np.array = np.array(self.initial_tokens) + tokens = np.broadcast_to(tokens, (n_audio, len(self.initial_tokens))).copy() + + # detect language if requested, overwriting the language token + languages, language_probs = self._detect_language(audio_features, tokens) + if self.options.task == "lang_id": + return [ + DecodingResult(audio_features=features, language=language, language_probs=probs) + for features, language, probs in zip(audio_features, languages, language_probs) + ] + + # repeat tokens by the group size, for beam search or best-of-n sampling + tokens = mx.array(tokens) + if self.n_group > 1: + tokens = tokens[:, None, :] + tokens = mx.broadcast_to(tokens, [n_audio, self.n_group, len(self.initial_tokens)]) + tokens = tokens.reshape(tokens, (n_audio * self.n_group, len(self.initial_tokens))) + + # call the main sampling loop + tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens) + + # reshape the tensors to have (n_audio, n_group) as the first two dimensions + audio_features = audio_features[::self.n_group] + no_speech_probs = no_speech_probs[::self.n_group] + assert audio_features.shape[0] == len(no_speech_probs) == n_audio + + tokens = tokens.reshape(n_audio, self.n_group, -1) + sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group) + + # get the final candidates for each group, and slice between the first sampled token and EOT + tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs) + tokens = tokens[..., self.sample_begin:].tolist() + tokens = [[t[:t.index(tokenizer.eot)] for t in s] for s in tokens] + + # select the top-ranked sample in each group + selected = self.sequence_ranker.rank(tokens, sum_logprobs) + tokens: List[List[int]] = [t[i] for i, t in zip(selected, tokens)] + texts: List[str] = [tokenizer.decode(t).strip() for t in tokens] + + sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)] + avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)] + + fields = ( + texts, + languages, + tokens, + audio_features, + avg_logprobs, + no_speech_probs, + ) + if len(set(map(len, fields))) != 1: + raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}") + + return [ + DecodingResult( + audio_features=features, + language=language, + tokens=tokens, + text=text, + avg_logprob=avg_logprob, + no_speech_prob=no_speech_prob, + temperature=self.options.temperature, + compression_ratio=compression_ratio(text), + ) for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields) + ] + + +def decode( + model, + mel: mx.array, + options: DecodingOptions = DecodingOptions(), + **kwargs, +) -> Union[DecodingResult, List[DecodingResult]]: + """ + Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s). 
+ + Parameters + ---------- + model: Whisper + the Whisper model instance + + mel: mx.array, shape = (80, 3000) or (*, 80, 3000) + An array containing the Mel spectrogram(s) + + options: DecodingOptions + A dataclass that contains all necessary options for decoding 30-second segments + + Returns + ------- + result: Union[DecodingResult, List[DecodingResult]] + The result(s) of decoding contained in `DecodingResult` dataclass instance(s) + """ + if single := mel.ndim == 2: + mel = mel[None] + + if kwargs: + options = replace(options, **kwargs) + + result = DecodingTask(model, options).run(mel) + return result[0] if single else result diff --git a/whisperplus/pipelines/lightning_whisper_mlx/lightning.py b/whisperplus/pipelines/lightning_whisper_mlx/lightning.py new file mode 100644 index 0000000..aa6763e --- /dev/null +++ b/whisperplus/pipelines/lightning_whisper_mlx/lightning.py @@ -0,0 +1,99 @@ +from huggingface_hub import hf_hub_download + +from .transcribe import transcribe_audio + +models = { + "tiny": { + "base": "mlx-community/whisper-tiny", + "4bit": "mlx-community/whisper-tiny-mlx-4bit", + "8bit": "mlx-community/whisper-tiny-mlx-8bit" + }, + "small": { + "base": "mlx-community/whisper-small-mlx", + "4bit": "mlx-community/whisper-small-mlx-4bit", + "8bit": "mlx-community/whisper-small-mlx-8bit" + }, + "distil-small.en": { + "base": "mustafaaljadery/distil-whisper-mlx", + }, + "base": { + "base": "mlx-community/whisper-base-mlx", + "4bit": "mlx-community/whisper-base-mlx-4bit", + "8bit": "mlx-community/whisper-base-mlx-8bit" + }, + "medium": { + "base": "mlx-community/whisper-medium-mlx", + "4bit": "mlx-community/whisper-medium-mlx-4bit", + "8bit": "mlx-community/whisper-medium-mlx-8bit" + }, + "distil-medium.en": { + "base": "mustafaaljadery/distil-whisper-mlx", + }, + "large": { + "base": "mlx-community/whisper-large-mlx", + "4bit": "mlx-community/whisper-large-mlx-4bit", + "8bit": "mlx-community/whisper-large-mlx-8bit", + }, + "large-v2": { + "base": "mlx-community/whisper-large-v2-mlx", + "4bit": "mlx-community/whisper-large-v2-mlx-4bit", + "8bit": "mlx-community/whisper-large-v2-mlx-8bit", + }, + "distil-large-v2": { + "base": "mustafaaljadery/distil-whisper-mlx", + }, + "large-v3": { + "base": "mlx-community/whisper-large-v3-mlx", + "4bit": "mlx-community/whisper-large-v3-mlx-4bit", + "8bit": "mlx-community/whisper-large-v3-mlx-8bit", + }, + "distil-large-v3": { + "base": "mustafaaljadery/distil-whisper-mlx", + }, +} + + +class LightningWhisperMLX(): + + def __init__(self, model, batch_size=12, quant=None): + if quant and (quant != "4bit" and quant != "8bit"): + raise ValueError("Quantization must be `4bit` or `8bit`") + + if model not in models: + raise ValueError("Please select a valid model") + + self.name = model + self.batch_size = batch_size + + repo_id = "" + + if quant and "distil" not in model: + repo_id = models[model][quant] + else: + repo_id = models[model]['base'] + + if quant and "distil" in model: + if quant == "4bit": + self.name += "-4-bit" + else: + self.name += "-8-bit" + + if "distil" in model: + filename1 = f"./mlx_models/{self.name}/weights.npz" + filename2 = f"./mlx_models/{self.name}/config.json" + local_dir = "./" + else: + filename1 = "weights.npz" + filename2 = "config.json" + local_dir = f"./mlx_models/{self.name}" + + hf_hub_download(repo_id=repo_id, filename=filename1, local_dir=local_dir) + hf_hub_download(repo_id=repo_id, filename=filename2, local_dir=local_dir) + + def transcribe(self, audio_path, language=None): + result = transcribe_audio( + 
audio_path, + path_or_hf_repo=f'./mlx_models/{self.name}', + language=language, + batch_size=self.batch_size) + return result diff --git a/whisperplus/pipelines/lightning_whisper_mlx/load_models.py b/whisperplus/pipelines/lightning_whisper_mlx/load_models.py new file mode 100644 index 0000000..9c774ce --- /dev/null +++ b/whisperplus/pipelines/lightning_whisper_mlx/load_models.py @@ -0,0 +1,40 @@ +# Copyright © 2023 Apple Inc. + +import json +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn +from huggingface_hub import snapshot_download +from mlx.utils import tree_unflatten + +from . import whisper + + +def load_model( + path_or_hf_repo: str, + dtype: mx.Dtype = mx.float32, +) -> whisper.Whisper: + model_path = Path(path_or_hf_repo) + if not model_path.exists(): + model_path = Path(snapshot_download(repo_id=path_or_hf_repo)) + + with open(str(model_path / "config.json")) as f: + config = json.loads(f.read()) + config.pop("model_type", None) + quantization = config.pop("quantization", None) + + model_args = whisper.ModelDimensions(**config) + + weights = mx.load(str(model_path / "weights.npz")) + weights = tree_unflatten(list(weights.items())) + + model = whisper.Whisper(model_args, dtype) + + if quantization is not None: + class_predicate = (lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)) and f"{p}.scales" in weights) + nn.quantize(model, **quantization, class_predicate=class_predicate) + + model.update(weights) + mx.eval(model.parameters()) + return model diff --git a/whisperplus/pipelines/lightning_whisper_mlx/timing.py b/whisperplus/pipelines/lightning_whisper_mlx/timing.py new file mode 100644 index 0000000..2682c85 --- /dev/null +++ b/whisperplus/pipelines/lightning_whisper_mlx/timing.py @@ -0,0 +1,295 @@ +# Copyright © 2023 Apple Inc. 
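+# Word-level timing utilities: align decoded tokens to audio frames via DTW over cross-attention weights.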
+ +import itertools +from dataclasses import dataclass +from typing import TYPE_CHECKING, List + +import mlx.core as mx +import numba +import numpy as np +from scipy import signal + +from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND +from .tokenizer import Tokenizer + +if TYPE_CHECKING: + from .model import Whisper + + +def median_filter(x: np.ndarray, filter_width: int): + """Apply a median filter of width `filter_width` along the last dimension of `x`""" + pad_width = filter_width // 2 + if x.shape[-1] <= pad_width: + # F.pad requires the padding width to be smaller than the input dimension + return x + + if (ndim := x.ndim) <= 2: + # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D + x = x[None, None, :] + + assert (filter_width > 0 and filter_width % 2 == 1), "`filter_width` should be an odd number" + + x = np.pad(x, ((0, 0), (0, 0), (pad_width, pad_width)), mode="reflect") + + # todo: more efficient version in mlx + result = signal.medfilt(x.astype(np.float32), kernel_size=(1, 1, filter_width))[..., pad_width:-pad_width] + + if ndim <= 2: + result = result[0, 0] + + return result + + +@numba.jit(nopython=True) +def backtrace(trace: np.ndarray): + i = trace.shape[0] - 1 + j = trace.shape[1] - 1 + trace[0, :] = 2 + trace[:, 0] = 1 + + result = [] + while i > 0 or j > 0: + result.append((i - 1, j - 1)) + + if trace[i, j] == 0: + i -= 1 + j -= 1 + elif trace[i, j] == 1: + i -= 1 + elif trace[i, j] == 2: + j -= 1 + else: + raise ValueError("Unexpected trace[i, j]") + + result = np.array(result) + return result[::-1, :].T + + +@numba.jit(nopython=True, parallel=True) +def dtw_cpu(x: np.ndarray): + N, M = x.shape + cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf + trace = -np.ones((N + 1, M + 1), dtype=np.float32) + + cost[0, 0] = 0 + for j in range(1, M + 1): + for i in range(1, N + 1): + c0 = cost[i - 1, j - 1] + c1 = cost[i - 1, j] + c2 = cost[i, j - 1] + + if c0 < c1 and c0 < c2: + c, t = c0, 0 + elif c1 < c0 and c1 < c2: + c, t = c1, 1 + else: + c, t = c2, 2 + + cost[i, j] = x[i - 1, j - 1] + c + trace[i, j] = t + + return backtrace(trace) + + +def dtw(x: np.ndarray) -> np.ndarray: + # todo: more efficient version in mlx + return dtw_cpu(x) + + +@dataclass +class WordTiming: + word: str + tokens: List[int] + start: float + end: float + probability: float + + +def find_alignment( + model: "Whisper", + tokenizer: Tokenizer, + text_tokens: List[int], + mel: mx.array, + num_frames: int, + *, + medfilt_width: int = 7, + qk_scale: float = 1.0, +) -> List[WordTiming]: + if len(text_tokens) == 0: + return [] + + tokens = mx.array([ + *tokenizer.sot_sequence, + tokenizer.no_timestamps, + *text_tokens, + tokenizer.eot, + ]) + + logits, cross_qk = model.forward_with_cross_qk(mel[None, :], tokens[None, :]) + # consider only the logits associated with predicting text + sampled_logits = logits[0][len(tokenizer.sot_sequence):-2, :tokenizer.eot] + token_probs = mx.softmax(sampled_logits.astype(mx.float32), axis=-1).astype(sampled_logits.dtype) + text_token_probs = mx.take_along_axis(token_probs, mx.array(text_tokens)[:, None], axis=1).squeeze(1) + text_token_probs = np.array(text_token_probs) + + # heads * tokens * frames + weights = mx.stack([cross_qk[_l.item()][0, _h.item()] for _l, _h in model.alignment_heads]) + weights = weights[:, :, :num_frames // 2] + weights = mx.softmax(weights * qk_scale, axis=-1) + mean = mx.mean(weights, axis=-2, keepdims=True) + std = mx.var(weights, axis=-2, keepdims=True, ddof=0).sqrt() + weights = (weights - 
mean) / std + weights = median_filter(np.array(weights), medfilt_width) + + matrix = weights.mean(axis=0) + matrix = matrix[len(tokenizer.sot_sequence):-1] + text_indices, time_indices = dtw(-matrix) + + words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot]) + if len(word_tokens) <= 1: + # return on eot only + # >>> np.pad([], (1, 0)) + # array([0.]) + # This results in crashes when we lookup jump_times with float, like + # IndexError: arrays used as indices must be of integer (or boolean) type + return [] + word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) + + jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) + jump_times = time_indices[jumps] / TOKENS_PER_SECOND + start_times = jump_times[word_boundaries[:-1]] + end_times = jump_times[word_boundaries[1:]] + word_probabilities = [ + np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) + ] + + return [ + WordTiming(word, tokens, start, end, probability) for word, tokens, start, end, probability in zip( + words, word_tokens, start_times, end_times, word_probabilities) + ] + + +def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str): + # merge prepended punctuations + i = len(alignment) - 2 + j = len(alignment) - 1 + while i >= 0: + previous = alignment[i] + following = alignment[j] + if previous.word.startswith(" ") and previous.word.strip() in prepended: + # prepend it to the following word + following.word = previous.word + following.word + following.tokens = previous.tokens + following.tokens + previous.word = "" + previous.tokens = [] + else: + j = i + i -= 1 + + # merge appended punctuations + i = 0 + j = 1 + while j < len(alignment): + previous = alignment[i] + following = alignment[j] + if not previous.word.endswith(" ") and following.word in appended: + # append it to the previous word + previous.word = previous.word + following.word + previous.tokens = previous.tokens + following.tokens + following.word = "" + following.tokens = [] + else: + i = j + j += 1 + + +def add_word_timestamps( + *, + segments: List[dict], + model: "Whisper", + tokenizer: Tokenizer, + mel: mx.array, + num_frames: int, + prepend_punctuations: str = "\"'“¿([{-", + append_punctuations: str = "\"'.。,,!!??::”)]}、", + last_speech_timestamp: float, + **kwargs, +): + if len(segments) == 0: + return + + text_tokens_per_segment = [[token for token in segment["tokens"] if token < tokenizer.eot] + for segment in segments] + + text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment)) + alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs) + word_durations = np.array([t.end - t.start for t in alignment]) + word_durations = word_durations[word_durations.nonzero()] + median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0 + median_duration = min(0.7, float(median_duration)) + max_duration = median_duration * 2 + + # hack: truncate long words at sentence boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. + if len(word_durations) > 0: + sentence_end_marks = ".。!!??" + # ensure words at sentence boundaries are not longer than twice the median word duration. 
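+        # clamp an overlong word to max_duration, keeping its start when the word itself is a
+        # sentence-ending mark and its end when the previous word is one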
+ for i in range(1, len(alignment)): + if alignment[i].end - alignment[i].start > max_duration: + if alignment[i].word in sentence_end_marks: + alignment[i].end = alignment[i].start + max_duration + elif alignment[i - 1].word in sentence_end_marks: + alignment[i].start = alignment[i].end - max_duration + + merge_punctuations(alignment, prepend_punctuations, append_punctuations) + + time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE + word_index = 0 + + for segment, text_tokens in zip(segments, text_tokens_per_segment): + saved_tokens = 0 + words = [] + + while word_index < len(alignment) and saved_tokens < len(text_tokens): + timing = alignment[word_index] + + if timing.word: + words.append( + dict( + word=timing.word, + start=round(time_offset + timing.start, 2), + end=round(time_offset + timing.end, 2), + probability=timing.probability, + )) + + saved_tokens += len(timing.tokens) + word_index += 1 + + # hack: truncate long words at segment boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. + if len(words) > 0: + # ensure the first and second word after a pause is not longer than + # twice the median word duration. + if (words[0]["end"] - last_speech_timestamp > median_duration * 4 and + (words[0]["end"] - words[0]["start"] > max_duration or + (len(words) > 1 and words[1]["end"] - words[0]["start"] > max_duration * 2))): + if len(words) > 1 and words[1]["end"] - words[1]["start"] > max_duration: + boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration) + words[0]["end"] = words[1]["start"] = boundary + words[0]["start"] = max(0, words[0]["end"] - max_duration) + + # prefer the segment-level start timestamp if the first word is too long. + if (segment["start"] < words[0]["end"] and segment["start"] - 0.5 > words[0]["start"]): + words[0]["start"] = max(0, min(words[0]["end"] - median_duration, segment["start"])) + else: + segment["start"] = words[0]["start"] + + # prefer the segment-level end timestamp if the last word is too long. + if (segment["end"] > words[-1]["start"] and segment["end"] + 0.5 < words[-1]["end"]): + words[-1]["end"] = max(words[-1]["start"] + median_duration, segment["end"]) + else: + segment["end"] = words[-1]["end"] + + last_speech_timestamp = segment["end"] + + segment["words"] = words diff --git a/whisperplus/pipelines/lightning_whisper_mlx/tokenizer.py b/whisperplus/pipelines/lightning_whisper_mlx/tokenizer.py new file mode 100644 index 0000000..b8c0378 --- /dev/null +++ b/whisperplus/pipelines/lightning_whisper_mlx/tokenizer.py @@ -0,0 +1,392 @@ +# Copyright © 2023 Apple Inc. 
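+# Whisper tokenizer built on tiktoken, with language codes and special-token helpers.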
+ +import base64 +import os +import string +from dataclasses import dataclass, field +from functools import cached_property, lru_cache +from typing import Dict, List, Optional, Tuple + +import tiktoken + +LANGUAGES = { + "en": "english", + "zh": "chinese", + "de": "german", + "es": "spanish", + "ru": "russian", + "ko": "korean", + "fr": "french", + "ja": "japanese", + "pt": "portuguese", + "tr": "turkish", + "pl": "polish", + "ca": "catalan", + "nl": "dutch", + "ar": "arabic", + "sv": "swedish", + "it": "italian", + "id": "indonesian", + "hi": "hindi", + "fi": "finnish", + "vi": "vietnamese", + "he": "hebrew", + "uk": "ukrainian", + "el": "greek", + "ms": "malay", + "cs": "czech", + "ro": "romanian", + "da": "danish", + "hu": "hungarian", + "ta": "tamil", + "no": "norwegian", + "th": "thai", + "ur": "urdu", + "hr": "croatian", + "bg": "bulgarian", + "lt": "lithuanian", + "la": "latin", + "mi": "maori", + "ml": "malayalam", + "cy": "welsh", + "sk": "slovak", + "te": "telugu", + "fa": "persian", + "lv": "latvian", + "bn": "bengali", + "sr": "serbian", + "az": "azerbaijani", + "sl": "slovenian", + "kn": "kannada", + "et": "estonian", + "mk": "macedonian", + "br": "breton", + "eu": "basque", + "is": "icelandic", + "hy": "armenian", + "ne": "nepali", + "mn": "mongolian", + "bs": "bosnian", + "kk": "kazakh", + "sq": "albanian", + "sw": "swahili", + "gl": "galician", + "mr": "marathi", + "pa": "punjabi", + "si": "sinhala", + "km": "khmer", + "sn": "shona", + "yo": "yoruba", + "so": "somali", + "af": "afrikaans", + "oc": "occitan", + "ka": "georgian", + "be": "belarusian", + "tg": "tajik", + "sd": "sindhi", + "gu": "gujarati", + "am": "amharic", + "yi": "yiddish", + "lo": "lao", + "uz": "uzbek", + "fo": "faroese", + "ht": "haitian creole", + "ps": "pashto", + "tk": "turkmen", + "nn": "nynorsk", + "mt": "maltese", + "sa": "sanskrit", + "lb": "luxembourgish", + "my": "myanmar", + "bo": "tibetan", + "tl": "tagalog", + "mg": "malagasy", + "as": "assamese", + "tt": "tatar", + "haw": "hawaiian", + "ln": "lingala", + "ha": "hausa", + "ba": "bashkir", + "jw": "javanese", + "su": "sundanese", + "yue": "cantonese", +} + +# language code lookup by name, with a few language aliases +TO_LANGUAGE_CODE = { + **{ + language: code + for code, language in LANGUAGES.items() + }, + "burmese": "my", + "valencian": "ca", + "flemish": "nl", + "haitian": "ht", + "letzeburgesch": "lb", + "pushto": "ps", + "panjabi": "pa", + "moldavian": "ro", + "moldovan": "ro", + "sinhalese": "si", + "castilian": "es", + "mandarin": "zh", +} + + +@dataclass +class Tokenizer: + """A thin wrapper around `tiktoken` providing quick access to special tokens.""" + + encoding: tiktoken.Encoding + num_languages: int + language: Optional[str] = None + task: Optional[str] = None + sot_sequence: Tuple[int] = () + special_tokens: Dict[str, int] = field(default_factory=dict) + + def __post_init__(self): + for special in self.encoding.special_tokens_set: + special_token = self.encoding.encode_single_token(special) + self.special_tokens[special] = special_token + + sot: int = self.special_tokens["<|startoftranscript|>"] + translate: int = self.special_tokens["<|translate|>"] + transcribe: int = self.special_tokens["<|transcribe|>"] + + langs = tuple(LANGUAGES.keys())[:self.num_languages] + sot_sequence = [sot] + if self.language is not None: + sot_sequence.append(sot + 1 + langs.index(self.language)) + if self.task is not None: + task_token: int = transcribe if self.task == "transcribe" else translate + sot_sequence.append(task_token) + + 
self.sot_sequence = tuple(sot_sequence) + + def encode(self, text, **kwargs): + return self.encoding.encode(text, **kwargs) + + def decode(self, token_ids: List[int], **kwargs) -> str: + token_ids = [t for t in token_ids if t < self.timestamp_begin] + return self.encoding.decode(token_ids, **kwargs) + + def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str: + """ + Timestamp tokens are above other special tokens' id range and are ignored by `decode()`. + + This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". + """ + return self.encoding.decode(token_ids, **kwargs) + + @cached_property + def eot(self) -> int: + return self.encoding.eot_token + + @cached_property + def transcribe(self) -> int: + return self.special_tokens["<|transcribe|>"] + + @cached_property + def translate(self) -> int: + return self.special_tokens["<|translate|>"] + + @cached_property + def sot(self) -> int: + return self.special_tokens["<|startoftranscript|>"] + + @cached_property + def sot_lm(self) -> int: + return self.special_tokens["<|startoflm|>"] + + @cached_property + def sot_prev(self) -> int: + return self.special_tokens["<|startofprev|>"] + + @cached_property + def no_speech(self) -> int: + return self.special_tokens["<|nospeech|>"] + + @cached_property + def no_timestamps(self) -> int: + return self.special_tokens["<|notimestamps|>"] + + @cached_property + def timestamp_begin(self) -> int: + return self.special_tokens["<|0.00|>"] + + @cached_property + def language_token(self) -> int: + """Returns the token id corresponding to the value of the `language` field.""" + if self.language is None: + raise ValueError("This tokenizer does not have language token configured") + + return self.to_language_token(self.language) + + def to_language_token(self, language): + if token := self.special_tokens.get(f"<|{language}|>", None): + return token + + raise KeyError(f"Language {language} not found in tokenizer.") + + @cached_property + def all_language_tokens(self) -> Tuple[int]: + result = [] + for token, token_id in self.special_tokens.items(): + if token.strip("<|>") in LANGUAGES: + result.append(token_id) + return tuple(result)[:self.num_languages] + + @cached_property + def all_language_codes(self) -> Tuple[str]: + return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens) + + @cached_property + def sot_sequence_including_notimestamps(self) -> Tuple[int]: + return tuple(list(self.sot_sequence) + [self.no_timestamps]) + + @cached_property + def non_speech_tokens(self) -> Tuple[int]: + """ + Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech annotations, + to prevent sampling texts that are not actually spoken in the audio, e.g. + + - ♪♪♪ + - ( SPEAKING FOREIGN LANGUAGE ) + - [DAVID] Hey there, + + keeping basic punctuations like commas, periods, question marks, exclamation points, etc. + """ + symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』') + symbols += ("<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()) + + # symbols that may be a single token or multiple tokens depending on the tokenizer. + # In case they're multiple tokens, suppress the first token, which is safe because: + # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress + # in generations, and in the 3-byte UTF-8 representation they share the first two bytes. 
+ miscellaneous = set("♩♪♫♬♭♮♯") + assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous) + + # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word + result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]} + for symbol in symbols + list(miscellaneous): + for tokens in [ + self.encoding.encode(symbol), + self.encoding.encode(" " + symbol), + ]: + if len(tokens) == 1 or symbol in miscellaneous: + result.add(tokens[0]) + + return tuple(sorted(result)) + + def split_to_word_tokens(self, tokens: List[int]): + if self.language in {"zh", "ja", "th", "lo", "my", "yue"}: + # These languages don't typically use spaces, so it is difficult to split words + # without morpheme analysis. Here, we instead split words at any + # position where the tokens are decoded as valid unicode points + return self.split_tokens_on_unicode(tokens) + + return self.split_tokens_on_spaces(tokens) + + def split_tokens_on_unicode(self, tokens: List[int]): + decoded_full = self.decode_with_timestamps(tokens) + replacement_char = "\ufffd" + + words = [] + word_tokens = [] + current_tokens = [] + unicode_offset = 0 + + for token in tokens: + current_tokens.append(token) + decoded = self.decode_with_timestamps(current_tokens) + + if (replacement_char not in decoded or + decoded_full[unicode_offset + decoded.index(replacement_char)] == replacement_char): + words.append(decoded) + word_tokens.append(current_tokens) + current_tokens = [] + unicode_offset += len(decoded) + + return words, word_tokens + + def split_tokens_on_spaces(self, tokens: List[int]): + subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens) + words = [] + word_tokens = [] + + for subword, subword_tokens in zip(subwords, subword_tokens_list): + special = subword_tokens[0] >= self.eot + with_space = subword.startswith(" ") + punctuation = subword.strip() in string.punctuation + if special or with_space or punctuation or len(words) == 0: + words.append(subword) + word_tokens.append(subword_tokens) + else: + words[-1] = words[-1] + subword + word_tokens[-1].extend(subword_tokens) + + return words, word_tokens + + +@lru_cache(maxsize=None) +def get_encoding(name: str = "gpt2", num_languages: int = 99): + vocab_path = os.path.join(os.path.dirname(__file__), "../assets", f"{name}.tiktoken") + with open(vocab_path) as fid: + ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in fid if line)} + n_vocab = len(ranks) + special_tokens = {} + + specials = [ + "<|endoftext|>", + "<|startoftranscript|>", + *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], + "<|translate|>", + "<|transcribe|>", + "<|startoflm|>", + "<|startofprev|>", + "<|nospeech|>", + "<|notimestamps|>", + *[f"<|{i * 0.02:.2f}|>" for i in range(1501)], + ] + + for token in specials: + special_tokens[token] = n_vocab + n_vocab += 1 + + return tiktoken.Encoding( + name=os.path.basename(vocab_path), + explicit_n_vocab=n_vocab, + pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", + mergeable_ranks=ranks, + special_tokens=special_tokens, + ) + + +@lru_cache(maxsize=None) +def get_tokenizer( + multilingual: bool, + *, + num_languages: int = 99, + language: Optional[str] = None, + task: Optional[str] = None, # Literal["transcribe", "translate", None] +) -> Tokenizer: + if language is not None: + language = language.lower() + if language not in LANGUAGES: + if language in TO_LANGUAGE_CODE: + language = TO_LANGUAGE_CODE[language] + else: + raise 
ValueError(f"Unsupported language: {language}") + + if multilingual: + encoding_name = "multilingual" + language = language or "en" + task = task or "transcribe" + else: + encoding_name = "gpt2" + language = None + task = None + + encoding = get_encoding(name=encoding_name, num_languages=num_languages) + + return Tokenizer(encoding=encoding, num_languages=num_languages, language=language, task=task) diff --git a/whisperplus/pipelines/lightning_whisper_mlx/torch_whisper.py b/whisperplus/pipelines/lightning_whisper_mlx/torch_whisper.py new file mode 100644 index 0000000..43262b1 --- /dev/null +++ b/whisperplus/pipelines/lightning_whisper_mlx/torch_whisper.py @@ -0,0 +1,284 @@ +# Copyright © 2023 Apple Inc. + +import base64 +import gzip +from dataclasses import dataclass +from typing import Dict, Iterable, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +@dataclass +class ModelDimensions: + n_mels: int + n_audio_ctx: int + n_audio_state: int + n_audio_head: int + n_audio_layer: int + n_vocab: int + n_text_ctx: int + n_text_state: int + n_text_head: int + n_text_layer: int + + +class LayerNorm(nn.LayerNorm): + + def forward(self, x: Tensor) -> Tensor: + return super().forward(x.float()).type(x.dtype) + + +class Linear(nn.Linear): + + def forward(self, x: Tensor) -> Tensor: + return F.linear( + x, + self.weight.to(x.dtype), + None if self.bias is None else self.bias.to(x.dtype), + ) + + +class Conv1d(nn.Conv1d): + + def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: + return super()._conv_forward(x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)) + + +def sinusoids(length, channels, max_timescale=10000): + """Returns sinusoids for positional embedding.""" + assert channels % 2 == 0 + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + + +class MultiHeadAttention(nn.Module): + + def __init__(self, n_state: int, n_head: int): + super().__init__() + self.n_head = n_head + self.query = Linear(n_state, n_state) + self.key = Linear(n_state, n_state, bias=False) + self.value = Linear(n_state, n_state) + self.out = Linear(n_state, n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + ): + q = self.query(x) + + if kv_cache is None or xa is None or self.key not in kv_cache: + # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; + # otherwise, perform key/value projections for self- or cross-attention as usual. + k = self.key(x if xa is None else xa) + v = self.value(x if xa is None else xa) + else: + # for cross-attention, calculate keys and values once and reuse in subsequent calls. 
+ k = kv_cache[self.key] + v = kv_cache[self.value] + + wv, qk = self.qkv_attention(q, k, v, mask) + return self.out(wv), qk + + def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None): + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head)**-0.25 + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + qk = q @ k + if mask is not None: + qk = qk + mask[:n_ctx, :n_ctx] + qk = qk.float() + + w = F.softmax(qk, dim=-1).to(q.dtype) + return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): + super().__init__() + + self.attn = MultiHeadAttention(n_state, n_head) + self.attn_ln = LayerNorm(n_state) + + self.cross_attn = (MultiHeadAttention(n_state, n_head) if cross_attention else None) + self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None + + n_mlp = n_state * 4 + self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)) + self.mlp_ln = LayerNorm(n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + ): + x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] + if self.cross_attn: + x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] + x = x + self.mlp(self.mlp_ln(x)) + return x + + +class AudioEncoder(nn.Module): + + def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): + super().__init__() + self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) + self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) + self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) + + self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( + [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]) + self.ln_post = LayerNorm(n_state) + + def forward(self, x: Tensor): + """ + x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) + the mel spectrogram of the audio + """ + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) + + assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" + x = (x + self.positional_embedding).to(x.dtype) + + for block in self.blocks: + x = block(x) + + x = self.ln_post(x) + return x + + +class TextDecoder(nn.Module): + + def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): + super().__init__() + + self.token_embedding = nn.Embedding(n_vocab, n_state) + self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) + + self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( + [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]) + self.ln = LayerNorm(n_state) + + mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) + self.register_buffer("mask", mask, persistent=False) + + def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): + """ + x : torch.LongTensor, shape = (batch_size, <= n_ctx) + the text tokens + xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) + the encoded audio features to be attended on + """ + offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 + x = (self.token_embedding(x) + 
self.positional_embedding[offset:offset + x.shape[-1]]) + x = x.to(xa.dtype) + + for block in self.blocks: + x = block(x, xa, mask=self.mask, kv_cache=kv_cache) + + x = self.ln(x) + logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float() + + return logits + + +class Whisper(nn.Module): + + def __init__(self, dims: ModelDimensions): + super().__init__() + self.dims = dims + self.encoder = AudioEncoder( + self.dims.n_mels, + self.dims.n_audio_ctx, + self.dims.n_audio_state, + self.dims.n_audio_head, + self.dims.n_audio_layer, + ) + self.decoder = TextDecoder( + self.dims.n_vocab, + self.dims.n_text_ctx, + self.dims.n_text_state, + self.dims.n_text_head, + self.dims.n_text_layer, + ) + # use the last half among the decoder layers for time alignment by default; + # to use a specific set of heads, see `set_alignment_heads()` below. + all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool) + all_heads[self.dims.n_text_layer // 2:] = True + self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False) + + def set_alignment_heads(self, dump: bytes): + array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy() + mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head) + self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False) + + def embed_audio(self, mel: torch.Tensor): + return self.encoder(mel) + + def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): + return self.decoder(tokens, audio_features) + + def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]: + return self.decoder(tokens, self.encoder(mel)) + + @property + def device(self): + return next(self.parameters()).device + + @property + def is_multilingual(self): + return self.dims.n_vocab >= 51865 + + @property + def num_languages(self): + return self.dims.n_vocab - 51765 - int(self.is_multilingual) + + def install_kv_cache_hooks(self, cache: Optional[dict] = None): + """ + The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value tensors + calculated for the previous positions. This method returns a dictionary that stores all caches, and + the necessary hooks for the key and value projection modules that save the intermediate tensors to be + reused during later calculations. 
+ + Returns + ------- + cache : Dict[nn.Module, torch.Tensor] + A dictionary object mapping the key/value projection modules to their caches + hooks : List[RemovableHandle] + List of PyTorch RemovableHandle objects that can be used to remove the installed hooks + """ + cache = {**cache} if cache is not None else {} + hooks = [] + + def save_to_cache(module, _, output): + if module not in cache or output.shape[1] > self.dims.n_text_ctx: + # save as-is, for the first token or cross attention + cache[module] = output + else: + cache[module] = torch.cat([cache[module], output], dim=1).detach() + return cache[module] + + def install_hooks(layer: nn.Module): + if isinstance(layer, MultiHeadAttention): + hooks.append(layer.key.register_forward_hook(save_to_cache)) + hooks.append(layer.value.register_forward_hook(save_to_cache)) + + self.decoder.apply(install_hooks) + return cache, hooks diff --git a/whisperplus/pipelines/lightning_whisper_mlx/transcribe.py b/whisperplus/pipelines/lightning_whisper_mlx/transcribe.py new file mode 100644 index 0000000..1ad8ee5 --- /dev/null +++ b/whisperplus/pipelines/lightning_whisper_mlx/transcribe.py @@ -0,0 +1,389 @@ +# Copyright © 2023 Apple Inc. + +import sys +import time +import warnings +from typing import List, Optional, Tuple, Union + +import mlx.core as mx +import numpy as np +import tqdm + +from .audio import ( + FRAMES_PER_SECOND, + HOP_LENGTH, + N_FRAMES, + N_SAMPLES, + SAMPLE_RATE, + log_mel_spectrogram, + pad_or_trim, +) +from .decoding import DecodingOptions, DecodingResult +from .load_models import load_model +from .timing import add_word_timestamps +from .tokenizer import LANGUAGES, get_tokenizer + + +def _format_timestamp(seconds: float): + assert seconds >= 0, "non-negative timestamp expected" + milliseconds = round(seconds * 1000.0) + + hours = milliseconds // 3_600_000 + milliseconds -= hours * 3_600_000 + + minutes = milliseconds // 60_000 + milliseconds -= minutes * 60_000 + + seconds = milliseconds // 1_000 + milliseconds -= seconds * 1_000 + + hours_marker = f"{hours:02d}:" if hours > 0 else "" + return f"{hours_marker}{minutes:02d}:{seconds:02d}.{milliseconds:03d}" + + +def _get_end(segments: List[dict]) -> Optional[float]: + return next( + (w["end"] for s in reversed(segments) for w in reversed(s["words"])), + segments[-1]["end"] if segments else None, + ) + + +class ModelHolder: + model = None + model_path = None + + @classmethod + def get_model(cls, model_path: str, dtype: mx.Dtype): + if cls.model is None or model_path != cls.model_path: + cls.model = load_model(model_path, dtype=dtype) + cls.model_path = model_path + return cls.model + + +def transcribe_audio( + audio: Union[str, np.ndarray, mx.array], + *, + path_or_hf_repo: str = "mlx-community/whisper-tiny", + verbose: Optional[bool] = None, + temperature: Union[float, Tuple[float, ...]] = (0.0, 1.0), + compression_ratio_threshold: Optional[float] = 2.4, + logprob_threshold: Optional[float] = -1, + no_speech_threshold: Optional[float] = 0.6, + condition_on_previous_text: bool = True, + initial_prompt: Optional[str] = None, + word_timestamps: bool = False, + prepend_punctuations: str = "\"'“¿([{-", + append_punctuations: str = "\"'.。,,!!??::”)]}、", + clip_timestamps: Union[str, List[float]] = "0", + hallucination_silence_threshold: Optional[float] = None, + batch_size: int = 6, + **decode_options, +): + """ + Transcribe an audio file using Whisper. 
+ + Parameters + ---------- + audio: Union[str, np.ndarray, mx.array] + The path to the audio file to open, or the audio waveform + + path_or_hf_repo: str + The localpath to the Whisper model or HF Hub repo with the MLX converted weights. + + verbose: bool + Whether to display the text being decoded to the console. If True, displays all the details, + If False, displays minimal details. If None, does not display anything + + temperature: Union[float, Tuple[float, ...]] + Temperature for sampling. It can be a tuple of temperatures, which will be successively used + upon failures according to either `compression_ratio_threshold` or `logprob_threshold`. + + compression_ratio_threshold: float + If the gzip compression ratio is above this value, treat as failed + + logprob_threshold: float + If the average log probability over sampled tokens is below this value, treat as failed + + no_speech_threshold: float + If the no_speech probability is higher than this value AND the average log probability + over sampled tokens is below `logprob_threshold`, consider the segment as silent + + condition_on_previous_text: bool + if True, the previous output of the model is provided as a prompt for the next window; + disabling may make the text inconsistent across windows, but the model becomes less prone to + getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. + + word_timestamps: bool + Extract word-level timestamps using the cross-attention pattern and dynamic time warping, + and include the timestamps for each word in each segment. + + prepend_punctuations: str + If word_timestamps is True, merge these punctuation symbols with the next word + + append_punctuations: str + If word_timestamps is True, merge these punctuation symbols with the previous word + + initial_prompt: Optional[str] + Optional text to provide as a prompt for the first window. This can be used to provide, or + "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns + to make it more likely to predict those word correctly. + + decode_options: dict + Keyword arguments to construct `DecodingOptions` instances + + clip_timestamps: Union[str, List[float]] + Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process. + The last end timestamp defaults to the end of the file. + + hallucination_silence_threshold: Optional[float] + When word_timestamps is True, skip silent periods longer than this threshold (in seconds) + when a possible hallucination is detected + + Returns + ------- + A dictionary containing the resulting text ("text") and segment-level details ("segments"), and + the spoken language ("language"), which is detected when `decode_options["language"]` is None. + """ + + dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32 + model = ModelHolder.get_model(path_or_hf_repo, dtype) + + # Pad 30-seconds of silence to the input audio, for slicing + mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels, padding=N_SAMPLES) + content_frames = mel.shape[-2] - N_FRAMES + + if decode_options.get("language", None) is None: + if not model.is_multilingual: + decode_options["language"] = "en" + else: + if verbose: + print( + "Detecting language using up to the first 30 seconds. 
" + "Use the `language` decoding option to specify the language") + mel_segment = pad_or_trim(mel, N_FRAMES, axis=-2).astype(dtype) + _, probs = model.detect_language(mel_segment) + decode_options["language"] = max(probs, key=probs.get) + if verbose is not None: + print(f"Detected language: {LANGUAGES[decode_options['language']].title()}") + + language: str = decode_options["language"] + task: str = decode_options.get("task", "transcribe") + tokenizer = get_tokenizer( + model.is_multilingual, + num_languages=model.num_languages, + language=language, + task=task, + ) + + if isinstance(clip_timestamps, str): + clip_timestamps = [float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])] + seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps] + if len(seek_points) == 0: + seek_points.append(0) + if len(seek_points) % 2 == 1: + seek_points.append(content_frames) + seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2])) + + punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、" + + if word_timestamps and task == "translate": + warnings.warn("Word-level timestamps on translations may not be reliable.") + + def decode_process(segment_batch, t): + kwargs = {**decode_options} + options = DecodingOptions(**kwargs, temperature=t) + decode_results = model.decode(segment_batch, options) + return decode_results + + def decode_with_fallback(segment_batch: mx.array) -> DecodingResult: + decode_results = decode_process(segment_batch, 0.0) + final_decode = [] + + for i, decode_result in enumerate(decode_results): + segment = segment_batch[i:i + 1, :, :] + needs_fallback = False + if (compression_ratio_threshold is not None and + decode_result.compression_ratio > compression_ratio_threshold): + needs_fallback = True + + if (logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold): + needs_fallback = True + + if (no_speech_threshold is not None and decode_result.no_speech_prob > no_speech_threshold): + needs_fallback = False + + if needs_fallback: + final_decode.append(decode_process(segment, 1.0)[0]) + else: + final_decode.append(decode_result) + + return final_decode + + clip_idx = 0 + seek = seek_clips[clip_idx][0] + input_stride = N_FRAMES // model.dims.n_audio_ctx # mel frames per output token: 2 + time_precision = (input_stride * HOP_LENGTH / SAMPLE_RATE) # time per output token: 0.02 (seconds) + all_tokens = [] + all_segments = [] + prompt_reset_since = 0 + + if initial_prompt is not None: + initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip()) + all_tokens.extend(initial_prompt_tokens) + else: + initial_prompt_tokens = [] + + def new_segment(*, start: float, end: float, tokens: mx.array, result: DecodingResult): + tokens = tokens.tolist() + text_tokens = [token for token in tokens if token < tokenizer.eot] + return { + "seek": seek, + "start": start, + "end": end, + "text": tokenizer.decode(text_tokens), + "tokens": tokens, + "temperature": res.temperature, + "avg_logprob": res.avg_logprob, + "compression_ratio": res.compression_ratio, + "no_speech_prob": res.no_speech_prob, + } + + def format_output(tokens, res): + seek = 0 + current_segments = [] + + if no_speech_threshold is not None: + should_skip = res.no_speech_prob > no_speech_threshold + if (logprob_threshold is not None and res.avg_logprob > logprob_threshold): + should_skip = False + + if should_skip: + seek += (segment_size) + return current_segments, seek + + def word_anomaly_score(word: dict) -> float: + probability = 
word.get("probability", 0.0) + duration = word["end"] - word["start"] + score = 0.0 + if probability < 0.15: + score += 1.0 + if duration < 0.133: + score += (0.133 - duration) * 15 + if duration > 2.0: + score += duration - 2.0 + return score + + def is_segment_anomaly(segment: Optional[dict]) -> bool: + if segment is None or not segment["words"]: + return False + words = [w for w in segment["words"] if w["word"] not in punctuation] + words = words[:8] + score = sum(word_anomaly_score(w) for w in words) + return score >= 3 or score + 0.01 >= len(words) + + def next_words_segment(segments: List[dict]) -> Optional[dict]: + return next((s for s in segments if s["words"]), None) + + timestamp_tokens = tokens >= tokenizer.timestamp_begin + single_timestamp_ending = timestamp_tokens[-2:].tolist() == [ + False, + True, + ] + + consecutive = np.where(np.logical_and(timestamp_tokens[:-1], timestamp_tokens[1:]))[0] + consecutive += 1 + if len(consecutive) > 0: + slices = consecutive.tolist() + if single_timestamp_ending: + slices.append(len(tokens)) + + last_slice = 0 + for current_slice in slices: + sliced_tokens = tokens[last_slice:current_slice] + start_timestamp_pos = (sliced_tokens[0].item() - tokenizer.timestamp_begin) + end_timestamp_pos = (sliced_tokens[-1].item() - tokenizer.timestamp_begin) + current_segments.append( + new_segment( + start=time_offset + start_timestamp_pos * time_precision, + end=time_offset + end_timestamp_pos * time_precision, + tokens=sliced_tokens, + result=res, + )) + last_slice = current_slice + + if single_timestamp_ending: + seek += segment_size + else: + last_timestamp_pos = (tokens[last_slice - 1].item() - tokenizer.timestamp_begin) + seek += last_timestamp_pos * input_stride + else: + duration = segment_duration + timestamps = tokens[timestamp_tokens.nonzero()[0]] + if (len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin): + last_timestamp_pos = (timestamps[-1].item() - tokenizer.timestamp_begin) + duration = last_timestamp_pos * time_precision + + current_segments.append( + new_segment( + start=time_offset, + end=time_offset + duration, + tokens=tokens, + result=res, + )) + seek += segment_size + + for i, segment in enumerate(current_segments): + if (segment["start"] == segment["end"] or segment["text"].strip() == ""): + segment["text"] = "" + segment["tokens"] = [] + segment["words"] = [] + + return current_segments, seek + + seek_clip_end = seek_clips[0][1] + seek = -3000 + while seek < seek_clip_end: + time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE) + + mel_segments = [] + mel_timestamps = [] + + for _ in range(batch_size): + seek += N_FRAMES + if seek > seek_clip_end: + break + segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek) + mel_segment = mel[seek:seek + segment_size] + + segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE + mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype) + mel_segments.append(mel_segment) + mel_timestamps.append((seek, seek + segment_size)) + + if not len(mel_segments): + break + + mel_segment_batch = mx.array(mx.stack(mel_segments, axis=0)) + decode_options["prompt"] = all_tokens[prompt_reset_since:] + result: DecodingResult = decode_with_fallback(mel_segment_batch) + + for index, res in enumerate(result): + start_seek, end_seek = mel_timestamps[index] + + tokens = np.array(res.tokens) + current_segments, value_seek = format_output(tokens, res) + + tokens = [token for segment in current_segments for token in segment["tokens"]] + + 
all_segments.append([start_seek, end_seek, tokenizer.decode(tokens)]) + + all_tokens.extend([token for segment in current_segments for token in segment["tokens"]]) + + if not condition_on_previous_text or res.temperature > 0.5: + prompt_reset_since = len(all_tokens) + + return dict( + text=tokenizer.decode(all_tokens[len(initial_prompt_tokens):]), + segments=all_segments, + language=language, + ) diff --git a/whisperplus/pipelines/lightning_whisper_mlx/whisper.py b/whisperplus/pipelines/lightning_whisper_mlx/whisper.py new file mode 100644 index 0000000..53561d0 --- /dev/null +++ b/whisperplus/pipelines/lightning_whisper_mlx/whisper.py @@ -0,0 +1,253 @@ +# Copyright © 2023 Apple Inc. + +import base64 +import gzip +import math +from dataclasses import dataclass +from typing import Union + +import mlx.core as mx +import mlx.nn as nn +import numpy as np + +from .decoding import decode as decode_function +from .decoding import detect_language as detect_language_function + + +@dataclass +class ModelDimensions: + n_mels: int + n_audio_ctx: int + n_audio_state: int + n_audio_head: int + n_audio_layer: int + n_vocab: int + n_text_ctx: int + n_text_state: int + n_text_head: int + n_text_layer: int + + +def sinusoids(length, channels, max_timescale=10000): + """Returns sinusoids for positional embedding.""" + assert channels % 2 == 0 + log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1) + inv_timescales = mx.exp(-log_timescale_increment * mx.arange(channels // 2)) + scaled_time = mx.arange(length)[:, None] * inv_timescales[None, :] + return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1) + + +class MultiHeadAttention(nn.Module): + + def __init__(self, n_state: int, n_head: int): + super().__init__() + self.n_head = n_head + self.query = nn.Linear(n_state, n_state) + self.key = nn.Linear(n_state, n_state, bias=False) + self.value = nn.Linear(n_state, n_state) + self.out = nn.Linear(n_state, n_state) + + def __call__( + self, + x, + xa=None, + mask=None, + kv_cache=None, + ): + q = self.query(x) + + if xa is None: + k = self.key(x) + v = self.value(x) + if kv_cache is not None: + k = mx.concatenate([kv_cache[0], k], axis=1) + v = mx.concatenate([kv_cache[1], v], axis=1) + elif kv_cache is None: + k = self.key(xa) + v = self.value(xa) + else: + k, v = kv_cache + + wv, qk = self.qkv_attention(q, k, v, mask) + return self.out(wv), (k, v), qk + + def qkv_attention(self, q, k, v, mask=None): + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head)**-0.25 + q = q.reshape(*q.shape[:2], self.n_head, -1).transpose(0, 2, 1, 3) * scale + k = k.reshape(*k.shape[:2], self.n_head, -1).transpose(0, 2, 3, 1) * scale + v = v.reshape(*v.shape[:2], self.n_head, -1).transpose(0, 2, 1, 3) + + qk = q @ k + if mask is not None: + qk = qk + mask[:n_ctx, :n_ctx] + qk = qk.astype(mx.float32) + + w = mx.softmax(qk, axis=-1).astype(q.dtype) + out = (w @ v).transpose(0, 2, 1, 3) + out = out.reshape(n_batch, n_ctx, n_state) + return out, qk + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): + super().__init__() + + self.attn = MultiHeadAttention(n_state, n_head) + self.attn_ln = nn.LayerNorm(n_state) + + self.cross_attn = (MultiHeadAttention(n_state, n_head) if cross_attention else None) + self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None + + n_mlp = n_state * 4 + self.mlp1 = nn.Linear(n_state, n_mlp) + self.mlp2 = nn.Linear(n_mlp, n_state) + self.mlp_ln = nn.LayerNorm(n_state) + + 
def __call__(self, x, xa=None, mask=None, kv_cache=None): + kv, cross_kv = kv_cache if kv_cache else (None, None) + y, kv, _ = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv) + x += y + cross_qk = None + if self.cross_attn: + y, cross_kv, cross_qk = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=cross_kv) + x += y + x = x + self.mlp2(nn.gelu(self.mlp1(self.mlp_ln(x)))) + return x, (kv, cross_kv), cross_qk + + +class AudioEncoder(nn.Module): + + def __init__( + self, + n_mels: int, + n_ctx: int, + n_state: int, + n_head: int, + n_layer: int, + dtype: mx.Dtype = mx.float16, + ): + super().__init__() + self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) + self._positional_embedding = sinusoids(n_ctx, n_state).astype(dtype) + + self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] + self.ln_post = nn.LayerNorm(n_state) + + def __call__(self, x): + x = nn.gelu(self.conv1(x)) + x = nn.gelu(self.conv2(x)) + assert x.shape[1:] == self._positional_embedding.shape, "incorrect audio shape" + x = x + self._positional_embedding + + for block in self.blocks: + x, _, _ = block(x) + + x = self.ln_post(x) + return x + + +class TextDecoder(nn.Module): + + def __init__( + self, + n_vocab: int, + n_ctx: int, + n_state: int, + n_head: int, + n_layer: int, + dtype: mx.Dtype = mx.float16, + ): + super().__init__() + + self.token_embedding = nn.Embedding(n_vocab, n_state) + self.positional_embedding = mx.zeros((n_ctx, n_state)) + + self.blocks = [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)] + self.ln = nn.LayerNorm(n_state) + self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(dtype) + + def __call__(self, x, xa, kv_cache=None): + """ + x : mx.array, shape = (batch_size, <= n_ctx) + the text tokens + xa : mx.array, shape = (batch_size, n_audio_ctx, n_audio_state) + the encoded audio features to be attended on + """ + offset = kv_cache[0][0][0].shape[1] if kv_cache else 0 + x = (self.token_embedding(x) + self.positional_embedding[offset:offset + x.shape[-1]]) + + if kv_cache is None: + kv_cache = [None] * len(self.blocks) + cross_qk = [None] * len(self.blocks) + for e, block in enumerate(self.blocks): + x, kv_cache[e], cross_qk[e] = block(x, xa, mask=self._mask, kv_cache=kv_cache[e]) + + x = self.ln(x) + return x @ self.token_embedding.weight.T, kv_cache, cross_qk + + +class Whisper(nn.Module): + + def __init__(self, dims: ModelDimensions, dtype: mx.Dtype = mx.float16): + super().__init__() + self.dims = dims + self.encoder = AudioEncoder( + self.dims.n_mels, + self.dims.n_audio_ctx, + self.dims.n_audio_state, + self.dims.n_audio_head, + self.dims.n_audio_layer, + dtype, + ) + self.decoder = TextDecoder( + self.dims.n_vocab, + self.dims.n_text_ctx, + self.dims.n_text_state, + self.dims.n_text_head, + self.dims.n_text_layer, + dtype, + ) + # use the last half among the decoder layers for time alignment by default; + # to use a specific set of heads, see `set_alignment_heads()` below. 
+ all_heads = np.zeros((self.dims.n_text_layer, self.dims.n_text_head), dtype=bool) + all_heads[self.dims.n_text_layer // 2:] = True + self.alignment_heads = mx.array(np.asarray(all_heads.nonzero()).T) + + def set_alignment_heads(self, dump: Union[bytes, np.ndarray]): + if isinstance(dump, np.ndarray): + self.alignment_heads = mx.array(dump) + elif isinstance(dump, bytes): + array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy() + mask = array.reshape(self.dims.n_text_layer, self.dims.n_text_head) + self.alignment_heads = mx.array(np.asarray(mask.nonzero()).T) + else: + raise ValueError( + f"Invalid type for `dump`: {type(dump)}. Expected a np.ndarray or base85-encoded bytes containing" + " alignment_head information") + + def embed_audio(self, mel): + return self.encoder(mel) + + def logits(self, tokens, audio_features): + return self.decoder(tokens, audio_features)[0] + + def forward_with_cross_qk(self, mel, tokens): + logits, _, cross_qk = self.decoder(tokens, self.encoder(mel)) + return logits, cross_qk + + def __call__(self, mel, tokens): + return self.decoder(tokens, self.encoder(mel))[0] + + @property + def is_multilingual(self): + return self.dims.n_vocab >= 51865 + + @property + def num_languages(self): + return self.dims.n_vocab - 51765 - int(self.is_multilingual) + + detect_language = detect_language_function + decode = decode_function diff --git a/whisperplus/pipelines/mlx_whisper/audio.py b/whisperplus/pipelines/mlx_whisper/audio.py index bf65f3c..bd19f32 100644 --- a/whisperplus/pipelines/mlx_whisper/audio.py +++ b/whisperplus/pipelines/mlx_whisper/audio.py @@ -87,7 +87,7 @@ def mel_filters(n_mels: int) -> mx.array: """ assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" - filename = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") + filename = os.path.join(os.path.dirname(__file__), "../assets", "mel_filters.npz") return mx.load(filename)[f"mel_{n_mels}"] diff --git a/whisperplus/pipelines/mlx_whisper/tokenizer.py b/whisperplus/pipelines/mlx_whisper/tokenizer.py index 5251dec..b8c0378 100644 --- a/whisperplus/pipelines/mlx_whisper/tokenizer.py +++ b/whisperplus/pipelines/mlx_whisper/tokenizer.py @@ -330,7 +330,7 @@ def split_tokens_on_spaces(self, tokens: List[int]): @lru_cache(maxsize=None) def get_encoding(name: str = "gpt2", num_languages: int = 99): - vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken") + vocab_path = os.path.join(os.path.dirname(__file__), "../assets", f"{name}.tiktoken") with open(vocab_path) as fid: ranks = {base64.b64decode(token): int(rank) for token, rank in (line.split() for line in fid if line)} n_vocab = len(ranks)
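The hunks above add a batched MLX Whisper implementation under whisperplus/pipelines/lightning_whisper_mlx (tokenizer, PyTorch reference model, batched transcribe_audio entry point, and the MLX model) and repoint the existing mlx_whisper pipeline at the shared ../assets directory. As a quick sanity check of the new entry point, the following sketch is one way to call it; it is not part of the patch: the audio path is a placeholder, ffmpeg must be on PATH for load_audio, whisperplus is assumed to be importable, and the "mlx-community/whisper-tiny" weights are assumed to be reachable on the Hugging Face Hub.

    from whisperplus.pipelines.lightning_whisper_mlx.transcribe import transcribe_audio

    # Decode a local file in 30-second mel windows, batch_size windows per model call.
    result = transcribe_audio(
        "speech.mp3",                                  # placeholder path, decoded via ffmpeg
        path_or_hf_repo="mlx-community/whisper-tiny",  # MLX-converted weights
        batch_size=6,
        language="en",                                 # forwarded to DecodingOptions via **decode_options
    )

    print(result["language"])
    # result["segments"] is a list of [start, end, text] triples; start/end are mel-frame
    # indices at 100 frames per second (HOP_LENGTH / SAMPLE_RATE = 0.01 s per frame).
    for start_frame, end_frame, text in result["segments"]:
        print(f"{start_frame * 0.01:.2f}s - {end_frame * 0.01:.2f}s: {text.strip()}")

Note that, unlike the mlx_whisper pipeline, this batched path returns coarse per-window segments rather than per-token timestamps, matching the format_output logic shown above.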