-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[cm] Cached Mel spec implementation (#82)
* [cm] Init implementation and test of cached mel spec * [cm] Fixing init implementation * [cm] Fixing bug * [cm] Decreasing delay * [cm] Adding documentation * [cm] Adding additional documentation
- Loading branch information
1 parent
9b2e8bc
commit a73c53a
Showing
2 changed files
with
216 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
import logging | ||
import os | ||
from typing import Optional, Callable | ||
|
||
import torch as tr | ||
from torch import Tensor | ||
from torch import nn | ||
from torchaudio.transforms import MelSpectrogram | ||
|
||
from neutone_sdk import CircularInplaceTensorQueue | ||
|
||
# Module-level logger; its verbosity is taken from the LOGLEVEL environment
# variable and falls back to INFO when the variable is not set.
logging.basicConfig()
log = logging.getLogger(__name__)
log.setLevel(os.environ.get("LOGLEVEL", "INFO"))
|
||
|
||
class CachedMelSpec(nn.Module):
    """
    Streaming wrapper around a torchaudio ``MelSpectrogram``.

    The wrapped transform is run without centering; instead, the last
    ``n_fft - hop_len`` input samples of every call are kept and prepended to
    the next block, so that consecutive calls produce one frame per hop.
    """

    def __init__(
        self,
        sr: int,
        n_ch: int,
        n_fft: int = 2048,
        hop_len: int = 512,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
        n_mels: int = 128,
        window_fn: Callable[..., Tensor] = tr.hann_window,
        power: float = 2.0,
        normalized: bool = False,
        center: bool = True,
        use_debug_mode: bool = True,
    ) -> None:
        """
        Creates a Mel spectrogram that supports streaming of a centered, non-causal
        Mel spectrogram operation that uses zero padding. Using this will result in
        audio being delayed by (n_fft / 2) - hop_len samples. When calling forward,
        the input audio block length must be a multiple of the hop length.

        Parameters:
            sr (int): Sample rate of the audio
            n_ch (int): Number of audio channels
            n_fft (int): STFT n_fft (must be even)
            hop_len (int): STFT hop length (must be positive and divide into
                n_fft // 2)
            f_min (float): Minimum frequency of the Mel filterbank
            f_max (float): Maximum frequency of the Mel filterbank
            n_mels (int): Number of mel filterbank bins
            window_fn (Callable[..., Tensor]): A function to create a window tensor
            power (float): Exponent for the magnitude spectrogram (must be > 0)
            normalized (bool): Whether to normalize the mel spectrogram or not
            center (bool): Whether to center the mel spectrogram (must be True)
            use_debug_mode (bool): Whether to use debug mode or not

        Raises:
            AssertionError: If center is False, n_fft is odd, hop_len is not
                positive, or hop_len does not divide n_fft // 2.
        """
        super().__init__()
        assert center, "center must be True, causal mode is not supported yet"
        assert n_fft % 2 == 0, "n_fft must be even"
        # Checked before the modulo below so that hop_len == 0 is reported as a
        # validation failure instead of a ZeroDivisionError, and negative hop
        # lengths (which satisfy the modulo check) are rejected as well.
        assert hop_len > 0, "hop_len must be positive"
        assert (n_fft // 2) % hop_len == 0, "n_fft // 2 must be divisible by hop_len"
        self.n_ch = n_ch
        self.n_fft = n_fft
        self.hop_len = hop_len
        # Kept so that forward can build an empty result for zero-length input.
        self.n_mels = n_mels
        self.use_debug_mode = use_debug_mode
        self.mel_spec = MelSpectrogram(
            sample_rate=sr,
            n_fft=n_fft,
            hop_length=hop_len,
            f_min=f_min,
            f_max=f_max,
            n_mels=n_mels,
            window_fn=window_fn,
            power=power,
            normalized=normalized,
            center=False,  # We use a causal STFT since we do the padding ourselves
        )
        # Number of past samples that must precede a block so that every hop of
        # new audio yields exactly one full n_fft-sized analysis window.
        self.padding_n_samples = self.n_fft - self.hop_len
        self.cache = CircularInplaceTensorQueue(
            n_ch, self.padding_n_samples, use_debug_mode
        )
        # The stream starts from silence: both the padding buffer and the queue
        # are initialised with zeros.
        self.register_buffer("padding", tr.zeros((n_ch, self.padding_n_samples)))
        self.cache.push(self.padding)

    def forward(self, x: Tensor) -> Tensor:
        """
        Computes the Mel spectrogram of the input audio tensor. Supports streaming as
        long as the input audio tensor is a multiple of the hop length.

        Parameters:
            x (Tensor): Audio block of shape (n_ch, n_samples), where n_samples
                is a multiple of hop_len.

        Returns:
            Tensor: The last n_samples // hop_len frames of the Mel spectrogram
                of the cached padding followed by x. For a zero-length block an
                empty tensor of shape (n_ch, n_mels, 0) is returned and the
                streaming state is left unchanged.
        """
        if self.use_debug_mode:
            assert x.ndim == 2, "input audio must have shape (n_ch, n_samples)"
            assert x.size(0) == self.n_ch, "input audio n_ch is incorrect"
            assert (
                x.size(1) % self.hop_len == 0
            ), "input audio n_samples must be divisible by hop_len"
        n_samples = x.size(1)
        if n_samples == 0:
            # Nothing to analyse. Without this guard the padded signal would be
            # shorter than n_fft, and the `-n_frames:` slice below would
            # degenerate to `-0:`, which selects every frame instead of none.
            return tr.zeros(
                [x.size(0), self.n_mels, 0], dtype=x.dtype, device=x.device
            )

        # Compute the Mel spec
        n_frames = n_samples // self.hop_len
        padded_x = tr.cat([self.padding, x], dim=1)
        padded_spec = self.mel_spec(padded_x)
        spec = padded_spec[:, :, -n_frames:]

        # Update the cache and padding with the tail of this block so the next
        # call continues where this one stopped.
        # NOTE(review): assumes CircularInplaceTensorQueue.push overwrites the
        # oldest samples and fill copies the queue contents into the given
        # tensor in chronological order - confirm against the queue class.
        padding_idx = min(n_samples, self.padding_n_samples)
        self.cache.push(x[:, -padding_idx:])
        self.cache.fill(self.padding)
        return spec

    def prepare_for_inference(self) -> None:
        """
        Prepares the cached Mel spectrogram for inference by disabling debug mode
        on both this module and its sample queue.
        """
        self.cache.use_debug_mode = False
        self.use_debug_mode = False

    @tr.jit.export
    def get_delay_samples(self) -> int:
        """
        Returns the number of samples of delay of the cached Mel spectrogram,
        i.e. (n_fft // 2) - hop_len.
        """
        return (self.n_fft // 2) - self.hop_len

    @tr.jit.export
    def get_delay_frames(self) -> int:
        """
        Returns the number of frames of delay of the cached Mel spectrogram,
        i.e. the sample delay divided by the hop length.
        """
        return self.get_delay_samples() // self.hop_len

    @tr.jit.export
    def reset(self) -> None:
        """
        Resets the cache and padding of the cached Mel spectrogram back to
        silence, as after construction.
        """
        self.cache.reset()
        self.padding.zero_()
        self.cache.push(self.padding)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import logging | ||
import os | ||
|
||
import torch as tr | ||
from torchaudio.transforms import MelSpectrogram | ||
|
||
from neutone_sdk.cached_mel_spec import CachedMelSpec | ||
|
||
# Module-level logger; its verbosity is taken from the LOGLEVEL environment
# variable and falls back to INFO when the variable is not set.
logging.basicConfig()
log = logging.getLogger(__name__)
log.setLevel(os.environ.get("LOGLEVEL", "INFO"))
|
||
|
||
def test_cached_mel_spec():
    """
    Checks that CachedMelSpec reproduces a centered, zero-padded MelSpectrogram
    (up to its reported delay), both when the audio is processed in one call
    and when it is streamed in randomly sized chunks.
    """
    # Setup
    tr.set_printoptions(precision=1)
    tr.random.manual_seed(42)

    sr = 44100
    n_ch = 1
    n_fft = 2048
    hop_len = 128
    n_mels = 16
    total_n_samples = 1000 * hop_len

    audio = tr.rand(n_ch, total_n_samples)
    mel_spec = MelSpectrogram(
        sample_rate=sr,
        n_fft=n_fft,
        hop_length=hop_len,
        n_mels=n_mels,
        center=True,
        pad_mode="constant",
    )
    cached_mel_spec = CachedMelSpec(
        sr=sr, n_ch=n_ch, n_fft=n_fft, hop_len=hop_len, n_mels=n_mels
    )

    # Test delay
    delay_samples = cached_mel_spec.get_delay_samples()
    assert delay_samples == n_fft // 2 - hop_len

    # Test processing all audio at once
    spec = mel_spec(audio)
    delay_frames = cached_mel_spec.get_delay_frames()
    cached_spec = cached_mel_spec(audio)
    # The first delay_frames frames only cover the initial zero padding.
    cached_spec = cached_spec[:, :, delay_frames:]
    assert tr.allclose(spec[:, :, : cached_spec.size(2)], cached_spec)
    cached_mel_spec.reset()

    # Test processing audio in chunks (random chunk size)
    chunks = []
    # Chunk sizes are drawn in units of hops and converted to samples below.
    min_chunk_size = 1
    max_chunk_size = 100
    # The loop bound must be in samples, like curr_idx; comparing against the
    # hop count only worked because slicing silently clamps at the audio end.
    max_chunk_n_samples = max_chunk_size * hop_len
    curr_idx = 0
    while curr_idx < total_n_samples - max_chunk_n_samples:
        chunk_size = (
            tr.randint(min_chunk_size, max_chunk_size + 1, (1,)).item() * hop_len
        )
        chunks.append(audio[:, curr_idx : curr_idx + chunk_size])
        curr_idx += chunk_size
    # Whatever is left over is processed as one final audio chunk.
    if curr_idx < total_n_samples:
        chunks.append(audio[:, curr_idx:])
    # Flush the delayed frames out of the cached transform with trailing zeros.
    chunks.append(
        tr.zeros(n_ch, cached_mel_spec.get_delay_samples() + cached_mel_spec.hop_len)
    )

    spec_chunks = []
    for chunk in chunks:
        spec_chunk = cached_mel_spec(chunk)
        spec_chunks.append(spec_chunk)
    chunked_spec = tr.cat(spec_chunks, dim=2)
    chunked_spec = chunked_spec[:, :, delay_frames:]
    assert tr.allclose(spec, chunked_spec)
    log.info("test_cached_mel_spec passed!")
|
||
|
||
# Allow running this test module directly as a script, without a test runner.
if __name__ == "__main__":
    test_cached_mel_spec()