fix semantic tokens #13

Open · wants to merge 7 commits into base: main
1 change: 1 addition & 0 deletions VERSION
@@ -0,0 +1 @@
0.2.5
2 changes: 1 addition & 1 deletion open_musiclm/clap_quantized.py
@@ -10,7 +10,7 @@
from vector_quantize_pytorch import ResidualVQ

from .laion_clap import CLAP_Module
from .utils import exists, beartype_jit
from .utils import exists, beartype_jit, int16_to_float32, float32_to_int16


@beartype_jit
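The new import pulls the int16 round-trip helpers into the CLAP wrapper, so waveforms get quantized to int16-representable values right where CLAP consumes them instead of in the dataloader. A minimal sketch of what that round-trip does, assuming the usual LAION-CLAP-style helper definitions (the definitions themselves are not part of this diff):

import torch

# assumed implementations, matching the common LAION-CLAP helpers
def float32_to_int16(x):
    x = torch.clamp(x, min=-1., max=1.)
    return (x * 32767.).to(torch.int16)

def int16_to_float32(x):
    return (x / 32767.).to(torch.float32)

wav = torch.randn(1, 16000).clamp(-1., 1.)
quantized = int16_to_float32(float32_to_int16(wav))  # snap samples onto the int16 grid
print((wav - quantized).abs().max() < 1 / 32767)  # tensor(True)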
7 changes: 6 additions & 1 deletion open_musiclm/config.py
@@ -31,12 +31,17 @@ class ClapRVQConfig:
@dataclass
class HubertKmeansConfig:
model_name: str
normalize_embeds: bool
normalize_input: bool = True
normalize_embeds: bool = True
embed_layer: int = 7
target_sample_hz: int = 16000
seq_len_multiple_of: int = 320
codebook_size: int = 1024
output_hz: int = 50
# split input into smaller context windows. note: MERT generalizes to longer sequences, so this is probably not necessary
context_window_seconds: Optional[float] = None
# number of adjacent features to average together
bin_size: int = 1

@dataclass
class EncodecConfig:
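For reference, a sketch of how the new config knobs compose (values below are illustrative, not from this PR). bin_size averages adjacent MERT features before k-means, so output_hz should be the post-binning rate (MERT's raw 50 Hz divided by bin_size), and context_window_seconds chunks long inputs into fixed windows:

from open_musiclm.config import HubertKmeansConfig

# hypothetical settings: 50 Hz MERT features averaged in pairs down to 25 Hz tokens
kmeans_config = HubertKmeansConfig(
    model_name='m-a-p/MERT-v0',
    normalize_input=True,          # normalize waveforms inside the model, not the dataloader
    normalize_embeds=True,
    embed_layer=7,
    output_hz=25,                  # post-binning token rate: 50 Hz / bin_size
    bin_size=2,                    # average every 2 adjacent features
    context_window_seconds=10.0,   # split longer audio into 10 s chunks
)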
20 changes: 5 additions & 15 deletions open_musiclm/data.py
@@ -16,9 +16,7 @@
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torchaudio.functional import resample

from .utils import (beartype_jit, curtail_to_multiple, default,
float32_to_int16, int16_to_float32,
zero_mean_unit_var_norm)
from .utils import (beartype_jit, curtail_to_multiple, default)

# helper functions

@@ -68,7 +66,6 @@ def __init__(
folder,
exts = ['flac', 'wav', 'mp3'],
max_length_seconds: Optional[Union[FloatOrInt, Tuple[Optional[FloatOrInt], ...]]] = 1,
normalize: Union[bool, Tuple[bool, ...]] = False,
target_sample_hz: OptionalIntOrTupleInt = None,
seq_len_multiple_of: OptionalIntOrTupleInt = None,
ignore_files: Optional[List[str]] = None,
@@ -104,12 +101,10 @@ def __init__(
self.max_length_seconds = cast_tuple(max_length_seconds, num_outputs)
self.max_length = tuple([int(s * hz) if exists(s) else None for s, hz in zip(self.max_length_seconds, self.target_sample_hz)])

self.normalize = cast_tuple(normalize, num_outputs)

self.seq_len_multiple_of = cast_tuple(seq_len_multiple_of, num_outputs)

assert len(self.max_length) == len(self.max_length_seconds) == len(
self.target_sample_hz) == len(self.seq_len_multiple_of) == len(self.normalize)
self.target_sample_hz) == len(self.seq_len_multiple_of)

def __len__(self):
return len(self.files)
@@ -134,10 +129,8 @@ def process_audio(self, data, sample_hz, pad_to_target_length=True):

# recursively crop the audio at random in the order of longest to shortest max_length_seconds, padding when necessary.
# e.g. if max_length_seconds = (10, 4), pick a 10 second crop from the original, then pick a 4 second crop from the 10 second crop
# also use normalized data when specified

temp_data = data
temp_data_normalized = zero_mean_unit_var_norm(data)

num_outputs = len(self.target_sample_hz)
data = [None for _ in range(num_outputs)]
@@ -157,17 +150,14 @@ def process_audio(self, data, sample_hz, pad_to_target_length=True):
start = torch.randint(0, max_start, (1, )) if self.random_crop else 0

temp_data = temp_data[:, start:start + target_length]
temp_data_normalized = temp_data_normalized[:, start:start + target_length]
else:
if pad_to_target_length:
temp_data = F.pad(temp_data, (0, target_length - audio_length), 'constant')
temp_data_normalized = F.pad(temp_data_normalized, (0, target_length - audio_length), 'constant')

data[unsorted_i] = temp_data_normalized if self.normalize[unsorted_i] else temp_data
data[unsorted_i] = temp_data

# resample if target_sample_hz is not None in the tuple
data_tuple = tuple((resample(d, sample_hz, target_sample_hz) if exists(target_sample_hz) else d) for d, target_sample_hz in zip(data, self.target_sample_hz))
# quantize non-normalized audio to a valid waveform
data_tuple = tuple(d if self.normalize[i] else int16_to_float32(float32_to_int16(d)) for i, d in enumerate(data_tuple))

output = []

@@ -350,7 +340,7 @@ def get_clap_tokens(self, clap_token_ids, start_idx):
def crop_semantic_tokens(self, semantic_token_ids, start_idx, end_idx):
# with start_idx = 0, end_idx = 2, semantic_steps_per_second=50
# we return semantic_token_ids[:, 0:100]
return semantic_token_ids[:, start_idx * self.semantic_steps_per_second: end_idx * self.semantic_steps_per_second - 1]
return semantic_token_ids[:, start_idx * self.semantic_steps_per_second: end_idx * self.semantic_steps_per_second]

def crop_acoustic_tokens(self, coarse_or_fine_ids, start_idx, end_idx):
# with start_idx = 0, end_idx = 2, coarse_steps_per_second=75
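A quick sanity check of the crop fix above (a standalone sketch; the tensor is illustrative): with semantic_steps_per_second=50, a crop from second 0 to second 2 should hold exactly 100 tokens, but the old trailing "- 1" cut it to 99.

import torch

semantic_steps_per_second = 50
semantic_token_ids = torch.zeros(1, 500, dtype=torch.long)  # 10 s of semantic tokens

start_idx, end_idx = 0, 2
old = semantic_token_ids[:, start_idx * semantic_steps_per_second : end_idx * semantic_steps_per_second - 1]
new = semantic_token_ids[:, start_idx * semantic_steps_per_second : end_idx * semantic_steps_per_second]
print(old.shape, new.shape)  # torch.Size([1, 99]) torch.Size([1, 100])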
38 changes: 35 additions & 3 deletions open_musiclm/hf_hubert_kmeans.py
@@ -3,8 +3,9 @@
import torch
from torch import nn
import numpy as np
from einops import rearrange, pack, unpack
from einops import reduce, pack, unpack
from beartype.typing import Optional
import math

from torchaudio.functional import resample
from .utils import exists, curtail_to_multiple, zero_mean_unit_var_norm
@@ -20,6 +21,7 @@ class HfHubertWithKmeans(nn.Module):
"""
Hugging Face HubertModel + a k-means layer on top. Pretrained checkpoint for music: https://huggingface.co/m-a-p/MERT-v0
Note: MERT-v0 outputs features at 50 Hz while Wav2Vec-BERT (used in the paper) outputs at 25 Hz.
We can reduce the number of semantic tokens by averaging adjacent features.
"""

def __init__(
@@ -30,20 +32,26 @@ def __init__(
embed_layer: int=7,
target_sample_hz=16000,
seq_len_multiple_of=int(16000 / 50),
normalize_input=True,
normalize_embeds=True,
codebook_size: int=1024,
output_hz: int=50
output_hz: int=50,
context_window_seconds: Optional[float]=None,
bin_size: int=1
):
super().__init__()
self.target_sample_hz = target_sample_hz
self.output_hz = output_hz
self.seq_len_multiple_of = seq_len_multiple_of
self.codebook_size = kmeans.n_clusters if exists(kmeans) else None
self.context_window_seconds = context_window_seconds
self.bin_size = bin_size
self.output_hz = output_hz

self.codebook_size = codebook_size
if exists(kmeans):
assert self.codebook_size == kmeans.n_clusters, "codebook_size must match kmeans.n_clusters"

self.normalize_input = normalize_input
self.normalize_embeds = normalize_embeds

self.embed_layer = embed_layer
@@ -66,9 +74,20 @@ def forward(
if exists(input_sample_hz):
wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)

if exists(self.context_window_seconds):
target_length = int(self.context_window_seconds * self.target_sample_hz)
wav_input = list(wav_input.split(target_length, dim=-1))
wav_input[-1] = torch.nn.functional.pad(wav_input[-1], (0, target_length - wav_input[-1].shape[-1]))
wav_input, packed_wav_input_shape = pack(wav_input, '* d')
else:
packed_wav_input_shape = None

if exists(self.seq_len_multiple_of):
wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)

if self.normalize_input:
wav_input = zero_mean_unit_var_norm(wav_input)

hubert_args = {
'input_values': wav_input,
'attention_mask': torch.ones_like(wav_input, device=device), # TODO: handle padding
@@ -77,6 +96,19 @@
outputs = self.hubert(**hubert_args, output_hidden_states = True)
embed = outputs.hidden_states[self.embed_layer]

# pad and reduce with bin size
audio_length_seconds = wav_input.shape[-1] / self.target_sample_hz
pad_to = int(self.output_hz * self.bin_size * audio_length_seconds)
if embed.shape[1] < pad_to:
# repeat last few frames
embed = torch.cat([embed, embed[:, -1:, :].repeat(1, pad_to - embed.shape[1], 1)], dim=1)
if self.bin_size > 1:
embed = reduce(embed, '... (n n1) f -> ... n f', reduction='mean', n1=self.bin_size)

if exists(packed_wav_input_shape):
embed = unpack(embed, packed_wav_input_shape, '* t d')
embed = torch.cat(embed, dim=1)

if self.normalize_embeds:
embed = zero_mean_unit_var_norm(embed)

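To see the pad-and-bin step in isolation, here is a minimal sketch of the reduce call above (shapes are illustrative; 768 is MERT-v0's hidden size):

import torch
from einops import reduce

bin_size = 2
embed = torch.randn(1, 499, 768)  # 10 s of 50 Hz MERT features, one frame short

# pad to a multiple of bin_size by repeating the last frame, as in forward()
pad_to = 500
embed = torch.cat([embed, embed[:, -1:, :].repeat(1, pad_to - embed.shape[1], 1)], dim=1)

# average every bin_size adjacent frames: 500 frames at 50 Hz -> 250 frames at 25 Hz
binned = reduce(embed, 'b (n n1) f -> b n f', 'mean', n1=bin_size)
print(binned.shape)  # torch.Size([1, 250, 768])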
7 changes: 2 additions & 5 deletions open_musiclm/open_musiclm.py
@@ -860,7 +860,7 @@ def forward(
semantic_window_seconds=10,
coarse_window_seconds=4,
fine_window_seconds=2,
semantic_steps_per_second=50, # Note: for MERTv0 its actually 50 * seconds - 1
semantic_steps_per_second=50, # e.g. for MERTv0 it's 50 / bin_size
acoustic_steps_per_second=75, # 75 for encodec, 50 for soundstream
return_coarse_generated_wave=False,
mask_out_generated_fine_tokens=False,
@@ -888,13 +888,11 @@
prime_wave,
prime_wave_sample_hz,
self.wav2vec.target_sample_hz,
normalize=True,
target_length_seconds=semantic_window_seconds)
prime_wave_encodec = prepare_audio(
prime_wave,
prime_wave_sample_hz,
self.neural_codec.sample_rate,
normalize=False,
target_length_seconds=semantic_window_seconds)

condition_semantic_token_ids = get_or_compute_semantic_token_ids(None, prime_wave_wav2vec, self.wav2vec)
@@ -942,7 +940,7 @@

# coarse stage

window_size = int(coarse_window_seconds * semantic_steps_per_second - 1)
window_size = int(coarse_window_seconds * semantic_steps_per_second)
step_size = int(window_size * coarse_sliding_window_step_percent)
all_semantic_token_ids = all_semantic_token_ids.unfold(1, window_size, step_size)
all_semantic_token_ids = rearrange(all_semantic_token_ids, 'b n q w -> n b w q')
@@ -1048,7 +1046,6 @@ def generate_top_match(
text_latents = repeat(text_latents, 'b d -> (repeat b) d', repeat=num_samples)

clap_input = resample(samples, self.neural_codec.sample_rate, self.clap.sample_rate)
clap_input = int16_to_float32(float32_to_int16(clap_input))
audio_latents = self.clap(audio_input=clap_input, return_embedding=True)

sim = F.cosine_similarity(text_latents, audio_latents, dim=-1)
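With the trailing "- 1" gone, the coarse stage unfolds the semantic tokens into windows of exactly coarse_window_seconds * semantic_steps_per_second steps. A sketch of the unfold/rearrange pattern (batch size, quantizer count, and step percent are illustrative):

import torch
from einops import rearrange

semantic_steps_per_second = 50
coarse_window_seconds = 4
coarse_sliding_window_step_percent = 0.5

all_semantic_token_ids = torch.zeros(2, 500, 1, dtype=torch.long)  # (batch, time, quantizers)

window_size = int(coarse_window_seconds * semantic_steps_per_second)  # 200 (was 199 with the old - 1)
step_size = int(window_size * coarse_sliding_window_step_percent)     # 100

windows = all_semantic_token_ids.unfold(1, window_size, step_size)    # (b, n, q, w)
windows = rearrange(windows, 'b n q w -> n b w q')
print(windows.shape)  # torch.Size([4, 2, 200, 1])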
3 changes: 0 additions & 3 deletions open_musiclm/preprocess.py
@@ -139,8 +139,6 @@ def __init__(

target_sample_hz = (audio_conditioner.sample_rate, wav2vec.target_sample_hz, neural_codec.sample_rate)

normalize = (False, True, False)

seq_len_multiple_of = (None, wav2vec.seq_len_multiple_of, None)

data_max_length_seconds = (max_audio_length_seconds, max_audio_length_seconds, max_audio_length_seconds)
@@ -152,7 +150,6 @@
pad_to_seconds=self.semantic_audio_length_seconds,
max_length_seconds=data_max_length_seconds,
random_crop=random_crop,
normalize=normalize,
target_sample_hz=target_sample_hz,
seq_len_multiple_of=seq_len_multiple_of,
ignore_load_errors=ignore_load_errors,
5 changes: 0 additions & 5 deletions open_musiclm/trainer.py
@@ -181,7 +181,6 @@ def __init__(
else:
self.ds_fields = ('raw_wave_for_clap', 'raw_wave_for_semantic')
target_sample_hz = (audio_conditioner.sample_rate, wav2vec.target_sample_hz)
normalize = (False, True)
seq_len_multiple_of = wav2vec.seq_len_multiple_of
elif stage == 'coarse':
assert self.use_preprocessed_data or (exists(wav2vec) and exists(audio_conditioner) and exists(neural_codec))
@@ -197,7 +196,6 @@ def __init__(
else:
self.ds_fields = ('raw_wave_for_clap', 'raw_wave_for_semantic', 'raw_wave_for_acoustic')
target_sample_hz = (audio_conditioner.sample_rate, wav2vec.target_sample_hz, neural_codec.sample_rate)
normalize = (False, True, False)
seq_len_multiple_of = wav2vec.seq_len_multiple_of
elif stage == 'fine':
assert self.use_preprocessed_data or (exists(audio_conditioner) and exists(neural_codec))
@@ -212,7 +210,6 @@ def __init__(
else:
self.ds_fields = ('raw_wave_for_clap', 'raw_wave_for_acoustic')
target_sample_hz = (audio_conditioner.sample_rate, neural_codec.sample_rate)
normalize = (False, False)
seq_len_multiple_of = None
else:
raise ValueError(f'invalid stage: {stage}')
@@ -260,7 +257,6 @@ def __init__(
self.ds = SoundDataset(
folder,
max_length_seconds=data_max_length_seconds,
normalize=normalize,
target_sample_hz=target_sample_hz,
seq_len_multiple_of=seq_len_multiple_of,
ignore_files=default(ignore_files, []),
@@ -781,7 +777,6 @@ def __init__(
self.ds = SoundDataset(
folder,
max_length_seconds=data_max_length_seconds,
normalize=True,
target_sample_hz=hubert_kmeans.target_sample_hz,
seq_len_multiple_of=hubert_kmeans.seq_len_multiple_of,
ignore_files=default(ignore_files, []),
5 changes: 1 addition & 4 deletions open_musiclm/utils.py
@@ -154,15 +154,12 @@ def float32_to_int16(x):
def zero_mean_unit_var_norm(x):
return (x - x.mean(dim=-1, keepdim=True)) / torch.sqrt(x.var(dim=-1, keepdim=True) + 1e-7)

def prepare_audio(data, sample_hz, target_sample_hz, normalize=True, target_length_seconds=None):
def prepare_audio(data, sample_hz, target_sample_hz, target_length_seconds=None):
if data.shape[0] > 1:
data = torch.mean(data, dim=0).unsqueeze(0)
if normalize:
data = zero_mean_unit_var_norm(data)
if exists(target_length_seconds) and data.shape[1] > target_length_seconds * sample_hz:
data = data[: , :int(target_length_seconds * sample_hz)]
audio_for_wav2vec = resample(data, sample_hz, target_sample_hz)
audio_for_wav2vec = int16_to_float32(float32_to_int16(audio_for_wav2vec))
return audio_for_wav2vec

# helper for saving config
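After this change prepare_audio only downmixes to mono, crops, and resamples; waveform normalization now lives inside HfHubertWithKmeans (gated by normalize_input) and the int16 round-trip moved to the CLAP path. A usage sketch (the file path is hypothetical; assumes the input is at least 10 s long):

import torchaudio
from open_musiclm.utils import prepare_audio

data, sample_hz = torchaudio.load('song.wav')  # hypothetical input file

# mono downmix + 10 s crop + resample to MERT's 16 kHz; no normalization here anymore
wav_for_semantic = prepare_audio(data, sample_hz, target_sample_hz=16000, target_length_seconds=10)
print(wav_for_semantic.shape)  # torch.Size([1, 160000])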
7 changes: 1 addition & 6 deletions scripts/infer_coarse.py
@@ -94,15 +94,10 @@
data = torch.mean(data, dim=0).unsqueeze(0)

target_length = int(4 * sample_hz)
normalized_data = zero_mean_unit_var_norm(data)

data = data[:, :target_length]
normalized_data = normalized_data[: , :target_length]
audio_for_clap = resample(data, sample_hz, clap.sample_rate)
audio_for_wav2vec = resample(normalized_data, sample_hz, wav2vec.target_sample_hz)

audio_for_clap = int16_to_float32(float32_to_int16(audio_for_clap))
audio_for_wav2vec = int16_to_float32(float32_to_int16(audio_for_wav2vec))
audio_for_wav2vec = resample(data, sample_hz, wav2vec.target_sample_hz)

audios_for_clap.append(audio_for_clap)
audios_for_wav2vec.append(audio_for_wav2vec)
2 changes: 0 additions & 2 deletions scripts/test/test_dataloader.py
@@ -21,7 +21,6 @@
dataset = SoundDataset(
folder,
max_length_seconds=(1, 5),
normalize=(True, False),
target_sample_hz=(16000, 24000),
seq_len_multiple_of=None,
ignore_load_errors=True
@@ -42,7 +41,6 @@
dataset = SoundDatasetForPreprocessing(
folder,
max_length_seconds=(None, 1),
normalize=(True, False),
target_sample_hz=(16000, 24000),
seq_len_multiple_of=None,
ignore_load_errors=True
22 changes: 5 additions & 17 deletions scripts/test/test_hubert_clustering.py
@@ -12,7 +12,7 @@
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))

from open_musiclm.config import load_model_config, load_training_config, create_hubert_kmeans_from_config, create_hubert_kmeans_trainer_from_config
from open_musiclm.utils import zero_mean_unit_var_norm, int16_to_float32, float32_to_int16, exists
from open_musiclm.utils import zero_mean_unit_var_norm, int16_to_float32, float32_to_int16, exists, prepare_audio
from open_musiclm.open_musiclm import get_or_compute_semantic_token_ids

if __name__ == '__main__':
@@ -29,7 +29,7 @@

print('loading hubert...')
wav2vec = create_hubert_kmeans_from_config(model_config, args.kmeans_path, device)

path = Path(args.folder)
assert path.exists(), 'folder does not exist'

@@ -48,21 +48,9 @@
audios_for_wav2vec = []
for audio_path in files[start_audio: start_audio + 16]:
data, sample_hz = torchaudio.load(audio_path)

if data.shape[0] > 1:
data = torch.mean(data, dim=0).unsqueeze(0)

target_length = int(audio_seconds * sample_hz)
normalized_data = zero_mean_unit_var_norm(data)

normalized_data = normalized_data[: , :target_length]

audio_for_wav2vec = resample(normalized_data, sample_hz, wav2vec.target_sample_hz)

audio_for_wav2vec = int16_to_float32(float32_to_int16(audio_for_wav2vec))

audio_for_wav2vec = prepare_audio(data, sample_hz, wav2vec.target_sample_hz, audio_seconds)
audios_for_wav2vec.append(audio_for_wav2vec)

audios_for_wav2vec = torch.cat(audios_for_wav2vec, dim=0).to(device)
semantic_token_ids = get_or_compute_semantic_token_ids(None, audios_for_wav2vec, wav2vec)
print(semantic_token_ids.shape)
@@ -123,5 +111,5 @@

# show the plot
plt.savefig('./results/accuracy_matrix.png')

