-
-
Notifications
You must be signed in to change notification settings - Fork 136
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #108 from kadirnar/add_lightning_mlx
π Add Lightning Mlx library
- Loading branch information
Showing
16 changed files
with
2,585 additions
and
2 deletions.
There are no files selected for viewing
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .lightning import LightningWhisperMLX |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
# Copyright Β© 2023 Apple Inc. | ||
|
||
import os | ||
from functools import lru_cache | ||
from subprocess import CalledProcessError, run | ||
from typing import Optional, Union | ||
|
||
import mlx.core as mx | ||
import numpy as np | ||
|
||
# hard-coded audio hyperparameters | ||
SAMPLE_RATE = 16000 | ||
N_FFT = 400 | ||
HOP_LENGTH = 160 | ||
CHUNK_LENGTH = 30 | ||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE | ||
N_FRAMES = N_SAMPLES // HOP_LENGTH | ||
|
||
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 | ||
FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame | ||
TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token | ||
|
||
|
||
def load_audio(file: str, sr: int = SAMPLE_RATE): | ||
""" | ||
Open an audio file and read as mono waveform, resampling as necessary. | ||
Parameters | ||
---------- | ||
file: str | ||
The audio file to open | ||
sr: int | ||
The sample rate to resample the audio if necessary | ||
Returns | ||
------- | ||
A NumPy array containing the audio waveform, in float32 dtype. | ||
""" | ||
|
||
# This launches a subprocess to decode audio while down-mixing | ||
# and resampling as necessary. Requires the ffmpeg CLI in PATH. | ||
# fmt: off | ||
cmd = [ | ||
"ffmpeg", | ||
"-nostdin", | ||
"-threads", "0", | ||
"-i", file, | ||
"-f", "s16le", | ||
"-ac", "1", | ||
"-acodec", "pcm_s16le", | ||
"-ar", str(sr), | ||
"-" | ||
] | ||
# fmt: on | ||
try: | ||
out = run(cmd, capture_output=True, check=True).stdout | ||
except CalledProcessError as e: | ||
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e | ||
|
||
return mx.array(np.frombuffer(out, np.int16)).flatten().astype(mx.float32) / 32768.0 | ||
|
||
|
||
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): | ||
"""Pad or trim the audio array to N_SAMPLES, as expected by the encoder.""" | ||
if array.shape[axis] > length: | ||
sl = [slice(None)] * array.ndim | ||
sl[axis] = slice(0, length) | ||
array = array[tuple(sl)] | ||
|
||
if array.shape[axis] < length: | ||
pad_widths = [(0, 0)] * array.ndim | ||
pad_widths[axis] = (0, length - array.shape[axis]) | ||
array = mx.pad(array, pad_widths) | ||
|
||
return array | ||
|
||
|
||
@lru_cache(maxsize=None) | ||
def mel_filters(n_mels: int) -> mx.array: | ||
""" | ||
Load the mel filterbank matrix for projecting STFT into a Mel spectrogram. Allows decoupling librosa | ||
dependency; saved using: | ||
np.savez_compressed( "mel_filters.npz", mel_80=librosa.filters.mel(sr=16000, n_fft=400, | ||
n_mels=80), mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), ) | ||
""" | ||
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" | ||
|
||
filename = os.path.join(os.path.dirname(__file__), "../assets", "mel_filters.npz") | ||
return mx.load(filename)[f"mel_{n_mels}"] | ||
|
||
|
||
@lru_cache(maxsize=None) | ||
def hanning(size): | ||
return mx.array(np.hanning(size + 1)[:-1]) | ||
|
||
|
||
def stft(x, window, nperseg=256, noverlap=None, nfft=None, axis=-1, pad_mode="reflect"): | ||
if nfft is None: | ||
nfft = nperseg | ||
if noverlap is None: | ||
noverlap = nfft // 4 | ||
|
||
def _pad(x, padding, pad_mode="constant"): | ||
if pad_mode == "constant": | ||
return mx.pad(x, [(padding, padding)]) | ||
elif pad_mode == "reflect": | ||
prefix = x[1:padding + 1][::-1] | ||
suffix = x[-(padding + 1):-1][::-1] | ||
return mx.concatenate([prefix, x, suffix]) | ||
else: | ||
raise ValueError(f"Invalid pad_mode {pad_mode}") | ||
|
||
padding = nperseg // 2 | ||
x = _pad(x, padding, pad_mode) | ||
|
||
strides = [noverlap, 1] | ||
t = (x.size - nperseg + noverlap) // noverlap | ||
shape = [t, nfft] | ||
x = mx.as_strided(x, shape=shape, strides=strides) | ||
return mx.fft.rfft(x * window) | ||
|
||
|
||
def log_mel_spectrogram( | ||
audio: Union[str, np.ndarray], | ||
n_mels: int = 80, | ||
padding: int = 0, | ||
): | ||
""" | ||
Compute the log-Mel spectrogram of. | ||
Parameters | ||
---------- | ||
audio: Union[str, np.ndarray, mx.array], shape = (*) | ||
The path to audio or either a NumPy or mlx array containing the audio waveform in 16 kHz | ||
n_mels: int | ||
The number of Mel-frequency filters, only 80 is supported | ||
padding: int | ||
Number of zero samples to pad to the right | ||
Returns | ||
------- | ||
mx.array, shape = (80, n_frames) | ||
An array that contains the Mel spectrogram | ||
""" | ||
device = mx.default_device() | ||
mx.set_default_device(mx.cpu) | ||
if isinstance(audio, str): | ||
audio = load_audio(audio) | ||
elif not isinstance(audio, mx.array): | ||
audio = mx.array(audio) | ||
|
||
if padding > 0: | ||
audio = mx.pad(audio, (0, padding)) | ||
window = hanning(N_FFT) | ||
freqs = stft(audio, window, nperseg=N_FFT, noverlap=HOP_LENGTH) | ||
magnitudes = freqs[:-1, :].abs().square() | ||
|
||
filters = mel_filters(n_mels) | ||
mel_spec = magnitudes @ filters.T | ||
|
||
log_spec = mx.maximum(mel_spec, 1e-10).log10() | ||
log_spec = mx.maximum(log_spec, log_spec.max() - 8.0) | ||
log_spec = (log_spec + 4.0) / 4.0 | ||
mx.set_default_device(device) | ||
return log_spec |
Oops, something went wrong.