diff --git a/VERSION b/VERSION
new file mode 100644
index 0000000..28af839
--- /dev/null
+++ b/VERSION
@@ -0,0 +1 @@
+0.2.5
\ No newline at end of file
diff --git a/open_musiclm/clap_quantized.py b/open_musiclm/clap_quantized.py
index 7894d6f..606ae17 100644
--- a/open_musiclm/clap_quantized.py
+++ b/open_musiclm/clap_quantized.py
@@ -10,7 +10,7 @@ from vector_quantize_pytorch import ResidualVQ
 
 from .laion_clap import CLAP_Module
-from .utils import exists, beartype_jit
+from .utils import exists, beartype_jit, int16_to_float32, float32_to_int16
 
 
 @beartype_jit
diff --git a/open_musiclm/config.py b/open_musiclm/config.py
index 2c98e73..39c7b40 100644
--- a/open_musiclm/config.py
+++ b/open_musiclm/config.py
@@ -31,12 +31,17 @@ class ClapRVQConfig:
 @dataclass
 class HubertKmeansConfig:
     model_name: str
-    normalize_embeds: bool
+    normalize_input: bool = True
+    normalize_embeds: bool = True
     embed_layer: int = 7
     target_sample_hz: int = 16000
     seq_len_multiple_of: int = 320
     codebook_size: int = 1024
     output_hz: int = 50
+    # split input into smaller context window. note: MERT generalizes to longer sequences so probably not necessary
+    context_window_seconds: Optional[float] = None
+    # number of adjacent features to average together
+    bin_size: int = 1
 
 @dataclass
 class EncodecConfig:
diff --git a/open_musiclm/data.py b/open_musiclm/data.py
index f67c78e..3391794 100644
--- a/open_musiclm/data.py
+++ b/open_musiclm/data.py
@@ -16,9 +16,7 @@
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 from torchaudio.functional import resample
 
-from .utils import (beartype_jit, curtail_to_multiple, default,
-                    float32_to_int16, int16_to_float32,
-                    zero_mean_unit_var_norm)
+from .utils import (beartype_jit, curtail_to_multiple, default)
 
 # helper functions
@@ -68,7 +66,6 @@ def __init__(
         folder,
         exts = ['flac', 'wav', 'mp3'],
         max_length_seconds: Optional[Union[FloatOrInt, Tuple[Optional[FloatOrInt], ...]]] = 1,
-        normalize: Union[bool, Tuple[bool, ...]] = False,
         target_sample_hz: OptionalIntOrTupleInt = None,
         seq_len_multiple_of: OptionalIntOrTupleInt = None,
         ignore_files: Optional[List[str]] = None,
@@ -104,12 +101,10 @@ def __init__(
         self.max_length_seconds = cast_tuple(max_length_seconds, num_outputs)
         self.max_length = tuple([int(s * hz) if exists(s) else None for s, hz in zip(self.max_length_seconds, self.target_sample_hz)])
 
-        self.normalize = cast_tuple(normalize, num_outputs)
-
         self.seq_len_multiple_of = cast_tuple(seq_len_multiple_of, num_outputs)
 
         assert len(self.max_length) == len(self.max_length_seconds) == len(
-            self.target_sample_hz) == len(self.seq_len_multiple_of) == len(self.normalize)
+            self.target_sample_hz) == len(self.seq_len_multiple_of)
 
     def __len__(self):
         return len(self.files)
@@ -134,10 +129,8 @@ def process_audio(self, data, sample_hz, pad_to_target_length=True):
 
         # recursively crop the audio at random in the order of longest to shortest max_length_seconds, padding when necessary.
        # e.g. if max_length_seconds = (10, 4), pick a 10 second crop from the original, then pick a 4 second crop from the 10 second crop
-        # also use normalized data when specified
 
         temp_data = data
-        temp_data_normalized = zero_mean_unit_var_norm(data)
 
         num_outputs = len(self.target_sample_hz)
         data = [None for _ in range(num_outputs)]
@@ -157,17 +150,14 @@ def process_audio(self, data, sample_hz, pad_to_target_length=True):
                 start = torch.randint(0, max_start, (1, )) if self.random_crop else 0
 
                 temp_data = temp_data[:, start:start + target_length]
-                temp_data_normalized = temp_data_normalized[:, start:start + target_length]
             else:
                 if pad_to_target_length:
                     temp_data = F.pad(temp_data, (0, target_length - audio_length), 'constant')
-                    temp_data_normalized = F.pad(temp_data_normalized, (0, target_length - audio_length), 'constant')
 
-            data[unsorted_i] = temp_data_normalized if self.normalize[unsorted_i] else temp_data
+            data[unsorted_i] = temp_data
+
         # resample if target_sample_hz is not None in the tuple
         data_tuple = tuple((resample(d, sample_hz, target_sample_hz) if exists(target_sample_hz) else d) for d, target_sample_hz in zip(data, self.target_sample_hz))
 
-        # quantize non-normalized audio to a valid waveform
-        data_tuple = tuple(d if self.normalize[i] else int16_to_float32(float32_to_int16(d)) for i, d in enumerate(data_tuple))
 
         output = []
@@ -350,7 +340,7 @@ def get_clap_tokens(self, clap_token_ids, start_idx):
     def crop_semantic_tokens(self, semantic_token_ids, start_idx, end_idx):
         # with start_idx = 0, end_idx = 2, semantic_steps_per_second=50
         # we return semantic_token_ids[:, 0:99]
-        return semantic_token_ids[:, start_idx * self.semantic_steps_per_second: end_idx * self.semantic_steps_per_second - 1]
+        return semantic_token_ids[:, start_idx * self.semantic_steps_per_second: end_idx * self.semantic_steps_per_second]
 
     def crop_acoustic_tokens(self, coarse_or_fine_ids, start_idx, end_idx):
         # with start_idx = 0, end_idx = 2, coarse_steps_per_second=75
diff --git a/open_musiclm/hf_hubert_kmeans.py b/open_musiclm/hf_hubert_kmeans.py
index d89b00b..7e447bb 100644
--- a/open_musiclm/hf_hubert_kmeans.py
+++ b/open_musiclm/hf_hubert_kmeans.py
@@ -3,8 +3,9 @@
 import torch
 from torch import nn
 import numpy as np
-from einops import rearrange, pack, unpack
+from einops import reduce, pack, unpack
 from beartype.typing import Optional
+import math
 from torchaudio.functional import resample
 
 from .utils import exists, curtail_to_multiple, zero_mean_unit_var_norm
@@ -20,6 +21,7 @@ class HfHubertWithKmeans(nn.Module):
     """
     Hugging Face HubertModel + a k-means layer on top. Pretrained checkpoint for music: https://huggingface.co/m-a-p/MERT-v0
    Note: MERT-v0 outputs features at 50Hz while Wav2Vec-BERT (used in the paper) outputs at 25 Hz.
+    We can reduce the number of semantic tokens by averaging adjacent features.
""" def __init__( @@ -30,20 +32,26 @@ def __init__( embed_layer: int=7, target_sample_hz=16000, seq_len_multiple_of=int(16000 / 50), + normalize_input=True, normalize_embeds=True, codebook_size: int=1024, - output_hz: int=50 + output_hz: int=50, + context_window_seconds: Optional[float]=None, + bin_size: int=1 ): super().__init__() self.target_sample_hz = target_sample_hz - self.output_hz = output_hz self.seq_len_multiple_of = seq_len_multiple_of self.codebook_size = kmeans.n_clusters if exists(kmeans) else None + self.context_window_seconds = context_window_seconds + self.bin_size = bin_size + self.output_hz = output_hz self.codebook_size = codebook_size if exists(kmeans): assert self.codebook_size == kmeans.n_clusters, "codebook_size must match kmeans.n_clusters" + self.normalize_input = normalize_input self.normalize_embeds = normalize_embeds self.embed_layer = embed_layer @@ -66,9 +74,20 @@ def forward( if exists(input_sample_hz): wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz) + if exists(self.context_window_seconds): + target_length = int(self.context_window_seconds * self.target_sample_hz) + wav_input = list(wav_input.split(target_length, dim=-1)) + wav_input[-1] = torch.nn.functional.pad(wav_input[-1], (0, target_length - wav_input[-1].shape[-1])) + wav_input, packed_wav_input_shape = pack(wav_input, '* d') + else: + packed_wav_input_shape = None + if exists(self.seq_len_multiple_of): wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of) + if self.normalize_input: + wav_input = zero_mean_unit_var_norm(wav_input) + hubert_args = { 'input_values': wav_input, 'attention_mask': torch.ones_like(wav_input, device=device), # TODO: handle padding @@ -77,6 +96,19 @@ def forward( outputs = self.hubert(**hubert_args, output_hidden_states = True) embed = outputs.hidden_states[self.embed_layer] + # pad and reduce with bin size + audio_length_seconds = wav_input.shape[-1] / self.target_sample_hz + pad_to = int(self.output_hz * self.bin_size * audio_length_seconds) + if embed.shape[1] < pad_to: + # repeat last few frames + embed = torch.cat([embed, embed[:, -1:, :].repeat(1, pad_to - embed.shape[1], 1)], dim=1) + if self.bin_size > 1: + embed = reduce(embed, '... (n n1) f -> ... n f', reduction='mean', n1=self.bin_size) + + if exists(packed_wav_input_shape): + embed = unpack(embed, packed_wav_input_shape, '* t d') + embed = torch.cat(embed, dim=1) + if self.normalize_embeds: embed = zero_mean_unit_var_norm(embed) diff --git a/open_musiclm/open_musiclm.py b/open_musiclm/open_musiclm.py index 2e76d8a..a6baf4e 100644 --- a/open_musiclm/open_musiclm.py +++ b/open_musiclm/open_musiclm.py @@ -860,7 +860,7 @@ def forward( semantic_window_seconds=10, coarse_window_seconds=4, fine_window_seconds=2, - semantic_steps_per_second=50, # Note: for MERTv0 its actually 50 * seconds - 1 + semantic_steps_per_second=50, # e.g. 
for MERTv0 its 50 / bin_size acoustic_steps_per_second=75, # 75 for encodec, 50 for soundstream return_coarse_generated_wave=False, mask_out_generated_fine_tokens=False, @@ -888,13 +888,11 @@ def forward( prime_wave, prime_wave_sample_hz, self.wav2vec.target_sample_hz, - normalize=True, target_length_seconds=semantic_window_seconds) prime_wave_encodec = prepare_audio( prime_wave, prime_wave_sample_hz, self.neural_codec.sample_rate, - normalize=False, target_length_seconds=semantic_window_seconds) condition_semantic_token_ids = get_or_compute_semantic_token_ids(None, prime_wave_wav2vec, self.wav2vec) @@ -942,7 +940,7 @@ def forward( # coarse stage - window_size = int(coarse_window_seconds * semantic_steps_per_second - 1) + window_size = int(coarse_window_seconds * semantic_steps_per_second) step_size = int(window_size * coarse_sliding_window_step_percent) all_semantic_token_ids = all_semantic_token_ids.unfold(1, window_size, step_size) all_semantic_token_ids = rearrange(all_semantic_token_ids, 'b n q w -> n b w q') @@ -1048,7 +1046,6 @@ def generate_top_match( text_latents = repeat(text_latents, 'b d -> (repeat b) d', repeat=num_samples) clap_input = resample(samples, self.neural_codec.sample_rate, self.clap.sample_rate) - clap_input = int16_to_float32(float32_to_int16(clap_input)) audio_latents = self.clap(audio_input=clap_input, return_embedding=True) sim = F.cosine_similarity(text_latents, audio_latents, dim=-1) diff --git a/open_musiclm/preprocess.py b/open_musiclm/preprocess.py index 28ab01c..3b82e1b 100644 --- a/open_musiclm/preprocess.py +++ b/open_musiclm/preprocess.py @@ -139,8 +139,6 @@ def __init__( target_sample_hz = (audio_conditioner.sample_rate, wav2vec.target_sample_hz, neural_codec.sample_rate) - normalize = (False, True, False) - seq_len_multiple_of = (None, wav2vec.seq_len_multiple_of, None) data_max_length_seconds = (max_audio_length_seconds, max_audio_length_seconds, max_audio_length_seconds) @@ -152,7 +150,6 @@ def __init__( pad_to_seconds=self.semantic_audio_length_seconds, max_length_seconds=data_max_length_seconds, random_crop=random_crop, - normalize=normalize, target_sample_hz=target_sample_hz, seq_len_multiple_of=seq_len_multiple_of, ignore_load_errors=ignore_load_errors, diff --git a/open_musiclm/trainer.py b/open_musiclm/trainer.py index 38cb3f5..3fd2e45 100644 --- a/open_musiclm/trainer.py +++ b/open_musiclm/trainer.py @@ -181,7 +181,6 @@ def __init__( else: self.ds_fields = ('raw_wave_for_clap', 'raw_wave_for_semantic') target_sample_hz = (audio_conditioner.sample_rate, wav2vec.target_sample_hz) - normalize = (False, True) seq_len_multiple_of = wav2vec.seq_len_multiple_of elif stage == 'coarse': assert self.use_preprocessed_data or (exists(wav2vec) and exists(audio_conditioner) and exists(neural_codec)) @@ -197,7 +196,6 @@ def __init__( else: self.ds_fields = ('raw_wave_for_clap', 'raw_wave_for_semantic', 'raw_wave_for_acoustic') target_sample_hz = (audio_conditioner.sample_rate, wav2vec.target_sample_hz, neural_codec.sample_rate) - normalize = (False, True, False) seq_len_multiple_of = wav2vec.seq_len_multiple_of elif stage == 'fine': assert self.use_preprocessed_data or (exists(audio_conditioner) and exists(neural_codec)) @@ -212,7 +210,6 @@ def __init__( else: self.ds_fields = ('raw_wave_for_clap', 'raw_wave_for_acoustic') target_sample_hz = (audio_conditioner.sample_rate, neural_codec.sample_rate) - normalize = (False, False) seq_len_multiple_of = None else: raise ValueError(f'invalid stage: {stage}') @@ -260,7 +257,6 @@ def __init__( self.ds = 
SoundDataset( folder, max_length_seconds=data_max_length_seconds, - normalize=normalize, target_sample_hz=target_sample_hz, seq_len_multiple_of=seq_len_multiple_of, ignore_files=default(ignore_files, []), @@ -781,7 +777,6 @@ def __init__( self.ds = SoundDataset( folder, max_length_seconds=data_max_length_seconds, - normalize=True, target_sample_hz=hubert_kmeans.target_sample_hz, seq_len_multiple_of=hubert_kmeans.seq_len_multiple_of, ignore_files=default(ignore_files, []), diff --git a/open_musiclm/utils.py b/open_musiclm/utils.py index ffb4830..edf3d7e 100644 --- a/open_musiclm/utils.py +++ b/open_musiclm/utils.py @@ -154,15 +154,12 @@ def float32_to_int16(x): def zero_mean_unit_var_norm(x): return (x - x.mean(dim=-1, keepdim=True)) / torch.sqrt(x.var(dim=-1, keepdim=True) + 1e-7) -def prepare_audio(data, sample_hz, target_sample_hz, normalize=True, target_length_seconds=None): +def prepare_audio(data, sample_hz, target_sample_hz, target_length_seconds=None): if data.shape[0] > 1: data = torch.mean(data, dim=0).unsqueeze(0) - if normalize: - data = zero_mean_unit_var_norm(data) if exists(target_length_seconds) and data.shape[1] > target_length_seconds * sample_hz: data = data[: , :int(target_length_seconds * sample_hz)] audio_for_wav2vec = resample(data, sample_hz, target_sample_hz) - audio_for_wav2vec = int16_to_float32(float32_to_int16(audio_for_wav2vec)) return audio_for_wav2vec # helper for saving config diff --git a/scripts/infer_coarse.py b/scripts/infer_coarse.py index 86749cd..690278b 100644 --- a/scripts/infer_coarse.py +++ b/scripts/infer_coarse.py @@ -94,15 +94,10 @@ data = torch.mean(data, dim=0).unsqueeze(0) target_length = int(4 * sample_hz) - normalized_data = zero_mean_unit_var_norm(data) data = data[:, :target_length] - normalized_data = normalized_data[: , :target_length] audio_for_clap = resample(data, sample_hz, clap.sample_rate) - audio_for_wav2vec = resample(normalized_data, sample_hz, wav2vec.target_sample_hz) - - audio_for_clap = int16_to_float32(float32_to_int16(audio_for_clap)) - audio_for_wav2vec = int16_to_float32(float32_to_int16(audio_for_wav2vec)) + audio_for_wav2vec = resample(data, sample_hz, wav2vec.target_sample_hz) audios_for_clap.append(audio_for_clap) audios_for_wav2vec.append(audio_for_wav2vec) diff --git a/scripts/test/test_dataloader.py b/scripts/test/test_dataloader.py index af4f8c3..49dfed3 100644 --- a/scripts/test/test_dataloader.py +++ b/scripts/test/test_dataloader.py @@ -21,7 +21,6 @@ dataset = SoundDataset( folder, max_length_seconds=(1, 5), - normalize=(True, False), target_sample_hz=(16000, 24000), seq_len_multiple_of=None, ignore_load_errors=True @@ -42,7 +41,6 @@ dataset = SoundDatasetForPreprocessing( folder, max_length_seconds=(None, 1), - normalize=(True, False), target_sample_hz=(16000, 24000), seq_len_multiple_of=None, ignore_load_errors=True diff --git a/scripts/test/test_hubert_clustering.py b/scripts/test/test_hubert_clustering.py index 6f61af8..dcaffe4 100644 --- a/scripts/test/test_hubert_clustering.py +++ b/scripts/test/test_hubert_clustering.py @@ -12,7 +12,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) from open_musiclm.config import load_model_config, load_training_config, create_hubert_kmeans_from_config, create_hubert_kmeans_trainer_from_config -from open_musiclm.utils import zero_mean_unit_var_norm, int16_to_float32, float32_to_int16, exists +from open_musiclm.utils import zero_mean_unit_var_norm, int16_to_float32, float32_to_int16, exists, prepare_audio from open_musiclm.open_musiclm import 
get_or_compute_semantic_token_ids if __name__ == '__main__': @@ -29,7 +29,7 @@ print('loading hubert...') wav2vec = create_hubert_kmeans_from_config(model_config, args.kmeans_path, device) - + path = Path(args.folder) assert path.exists(), 'folder does not exist' @@ -48,21 +48,9 @@ audios_for_wav2vec = [] for audio_path in files[start_audio: start_audio + 16]: data, sample_hz = torchaudio.load(audio_path) - - if data.shape[0] > 1: - data = torch.mean(data, dim=0).unsqueeze(0) - - target_length = int(audio_seconds * sample_hz) - normalized_data = zero_mean_unit_var_norm(data) - - normalized_data = normalized_data[: , :target_length] - - audio_for_wav2vec = resample(normalized_data, sample_hz, wav2vec.target_sample_hz) - - audio_for_wav2vec = int16_to_float32(float32_to_int16(audio_for_wav2vec)) - + audio_for_wav2vec = prepare_audio(data, sample_hz, wav2vec.target_sample_hz, audio_seconds) audios_for_wav2vec.append(audio_for_wav2vec) - + audios_for_wav2vec = torch.cat(audios_for_wav2vec, dim=0).to(device) semantic_token_ids = get_or_compute_semantic_token_ids(None, audios_for_wav2vec, wav2vec) print(semantic_token_ids.shape) @@ -123,5 +111,5 @@ # show the plot plt.savefig('./results/accuracy_matrix.png') - + diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..8078b42 --- /dev/null +++ b/setup.py @@ -0,0 +1,52 @@ +from pathlib import Path + +from setuptools import setup, find_packages + +# Load the version from file +__version__ = Path("VERSION").read_text().strip() + +setup( + name = 'open-musiclm', + packages = find_packages(exclude=[]), + version = __version__, + license='MIT', + description = 'Open MusicLM - Implementation of MusicLM, a text to music model published by Google Research, with a few modifications', + author = 'Allen Zhang', + long_description_content_type = 'text/markdown', + url = 'https://github.com/zhvng/open-musiclm', + keywords = [ + 'artificial intelligence', + 'deep learning', + 'transformers', + 'attention mechanism', + 'audio generation', + 'musiclm', + ], + install_requires=[ + 'torch', + 'torchvision', + 'torchaudio', + 'einops', + 'vector-quantize-pytorch', + 'librosa', + 'torchlibrosa', + 'ftfy', + 'tqdm', + 'transformers', + 'encodec', + 'gdown', + 'accelerate', + 'beartype', + 'joblib', + 'h5py', + 'sklearn', + 'wget', + ], + classifiers=[ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'License :: OSI Approved :: MIT License', + 'Programming Language :: Python :: 3.10', + ], +) \ No newline at end of file
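One of the behavioral changes above: HfHubertWithKmeans.forward now pads the MERT embeddings out to a whole number of bins and averages every bin_size adjacent frames, which lowers the semantic token rate (MERT-v0 emits features at 50 Hz, versus the 25 Hz w2v-BERT features used in the paper). Below is a minimal standalone sketch of just that pad-and-average step, run on a random dummy embedding tensor instead of a real MERT forward pass; the batch size, feature dimension, and the output_hz / bin_size values are illustrative assumptions, not values fixed by this patch.

import torch
from einops import reduce

# illustrative config values (assumptions, not mandated by the patch)
output_hz = 25               # desired post-binning semantic token rate
bin_size = 2                 # MERT-v0 emits ~50 Hz features; averaging pairs gives 25 Hz
audio_length_seconds = 10.0

# stand-in for outputs.hidden_states[embed_layer]: (batch, frames, features)
embed = torch.randn(1, 499, 768)   # MERT can come back a frame short of 50 * seconds

# pad by repeating the last frame so the frame count is an exact multiple of bin_size
pad_to = int(output_hz * bin_size * audio_length_seconds)   # 500 raw frames expected
if embed.shape[1] < pad_to:
    embed = torch.cat([embed, embed[:, -1:, :].repeat(1, pad_to - embed.shape[1], 1)], dim=1)

# average every bin_size adjacent frames: (b, n * bin_size, f) -> (b, n, f)
if bin_size > 1:
    embed = reduce(embed, 'b (n n1) f -> b n f', 'mean', n1=bin_size)

print(embed.shape)   # torch.Size([1, 250, 768]) -> 25 semantic tokens per second

Halving the frame rate this way halves the number of semantic tokens the downstream transformers attend over, which appears to be the motivation for the "averaging adjacent features" note added to the HfHubertWithKmeans docstring.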