diff --git a/README.md b/README.md
index da17a86..18e6176 100644
--- a/README.md
+++ b/README.md
@@ -2,11 +2,12 @@
 
 This repository contains a number of scripts required for replicating Codecformer within the speechbrain framework. Unfortunately, they will have to be copied into the respective directories manually.
 
-train_cdf.py -> recipes/WSJ02Mix/Separation
+train_codecformer.py -> recipes/WSJ02Mix/Separation
 
-DAC_original_L4nq.yaml -> recipes/WSJ02Mix/Separation/hparams
+codecformer-dac.yaml -> recipes/WSJ02Mix/Separation/hparams
+codecformer-wavtokenizer.yaml -> recipes/WSJ02Mix/Separation/hparams
 
-codecformer3.py -> speechbrain/lobes/models
+codecformer.py -> speechbrain/lobes/models
 
 For replication efforts, please note that the activation function of the simpleseparator2 model has a big impact on performance. Ensure that the activation function of the separator matches the activation function used in the final layer of the neural audio codec's encoder.
 
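Note on the replication remark above: the separator's activation now comes in through the new `activation` argument of `simpleSeparator2`, added further down in this diff. A minimal sketch, not part of the diff itself, assuming codecformer.py has already been copied into speechbrain/lobes/models and the dac package is installed: the DAC setup keeps the Snake1d default, while the WavTokenizer hparams pass torch.nn.LeakyReLU.

```python
# Minimal sketch (not part of the diff): choosing the separator activation to match
# the codec, as the README advises. Block hyperparameters follow codecformer-wavtokenizer.yaml.
import torch.nn as nn
from speechbrain.lobes.models.dual_path import SBTransformerBlock
from speechbrain.lobes.models.codecformer import simpleSeparator2

block = SBTransformerBlock(num_layers=16, d_model=256, nhead=8, d_ffn=1024,
                           dropout=0.1, use_positional_encoding=True, norm_before=True)

# DAC configuration: leaving activation=None keeps the module's Snake1d default.
dac_sep = simpleSeparator2(num_spks=2, channels=1024, block=block, block_channels=256)

# WavTokenizer configuration: 512 channels and LeakyReLU, as in codecformer-wavtokenizer.yaml.
wt_sep = simpleSeparator2(num_spks=2, channels=512, block=block,
                          block_channels=256, activation=nn.LeakyReLU())
```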
diff --git a/DAC_original_L4nq.yaml b/codecformer-dac.yaml
similarity index 89%
rename from DAC_original_L4nq.yaml
rename to codecformer-dac.yaml
index c215279..0d971af 100644
--- a/DAC_original_L4nq.yaml
+++ b/codecformer-dac.yaml
@@ -11,18 +11,21 @@
 seed: 1234
 __set_seed: !apply:torch.manual_seed [1234]
 
 # Data params
+# your base folder where this repo and other relevant files/repos are stored
+base_folder: /yourpath
+# experiment folder name to generate in -/separation/results
+experiment_name: codecformer/DAC_original_L4nq
 # e.g. '/yourpath/wsj0-mix/2speakers'
 # end with 2speakers for wsj0-2mix or 3speakers for wsj0-3mix
-data_folder:
+data_folder: !ref <base_folder>/wsj0-mix/2speakers
 
 # the path for wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used
 # e.g. /yourpath/wsj0-processed/si_tr_s/
 # you need to convert the original wsj0 to 8k
 # you can do this conversion with the script ../meta/preprocess_dynamic_mixing.py
-base_folder_dm: /yourpath/wsj0-processed/si_tr_s/
+base_folder_dm: !ref <base_folder>/wsj0-mix/2speakers/si_tr_s/
 
-experiment_name: codecformer/DAC_original_L4nq
 output_folder: !ref results/<experiment_name>/<seed>
 train_log: !ref <output_folder>/train_log.txt
 save_folder: !ref <output_folder>/save
@@ -30,6 +33,12 @@
 train_data: !ref <save_folder>/wsj_tr.csv
 valid_data: !ref <save_folder>/wsj_cv.csv
 test_data: !ref <save_folder>/wsj_tt.csv
 skip_prep: false
+# optionally specify the number of files to use for training, validation, and testing
+# comment out to use all files
+# file_limits:
+#    train: 100
+#    valid: 10
+#    test: 10
 
 # Experiment params
@@ -127,13 +136,13 @@ block: !new:speechbrain.lobes.models.dual_path.SBTransformerBlock
     use_positional_encoding: true
     norm_before: true
 
-dacmodel: !new:speechbrain.lobes.models.codecformer3.DACWrapper
+dacmodel: !new:speechbrain.lobes.models.codecformer.DACWrapper
     input_sample_rate: 8000
     DAC_model_path: #if None, will download model from huggingface. Otherwise, path to checkpoint should be provided for the model to be loaded. Model has been hardcoded to download the 16khz model. please modify the code if you need another model.
    DAC_sample_rate: 16000
     Freeze: true
 
-sepmodel: !new:speechbrain.lobes.models.codecformer3.simpleSeparator2
+sepmodel: !new:speechbrain.lobes.models.codecformer.simpleSeparator2
     # dacmodel: !ref <dacmodel>
     num_spks: 2
     channels: 1024
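All of the derived paths above resolve through hyperpyyaml `!ref` placeholders, so in most cases only base_folder (and optionally experiment_name) needs to change. A rough sketch of resolving the file with overrides rather than editing it, assuming the placeholder layout above; note that loading also instantiates the codec and separator objects declared in the file.

```python
# Rough sketch (not part of the diff): resolving codecformer-dac.yaml with overrides.
# The path and override values below are the placeholders used in the hparams file.
from hyperpyyaml import load_hyperpyyaml

overrides = {
    "base_folder": "/yourpath",  # data_folder, base_folder_dm, etc. derive from this
    "experiment_name": "codecformer/DAC_original_L4nq",
}
with open("hparams/codecformer-dac.yaml") as fin:
    hparams = load_hyperpyyaml(fin, overrides)

print(hparams["output_folder"])  # expected: results/codecformer/DAC_original_L4nq/1234
```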
diff --git a/codecformer-wavtokenizer.yaml b/codecformer-wavtokenizer.yaml
new file mode 100644
index 0000000..0a9b519
--- /dev/null
+++ b/codecformer-wavtokenizer.yaml
@@ -0,0 +1,178 @@
+# ################################
+# Model: Codecformer for source separation
+# https://arxiv.org/abs/2406.12434
+# Dataset : WSJ0-2mix and WSJ0-3mix
+# ################################
+#
+# Basic parameters
+# Seed needs to be set at top of yaml, before objects with parameters are made
+#
+seed: 1234
+__set_seed: !apply:torch.manual_seed [1234]
+
+# Data params
+# your base folder where this repo and other relevant files/repos are stored
+base_folder: /yourpath
+# experiment folder name to generate in -/separation/results
+experiment_name: codecformer/wavtokenizer2
+
+# e.g. '/yourpath/wsj0-mix/2speakers'
+# end with 2speakers for wsj0-2mix or 3speakers for wsj0-3mix
+data_folder: !ref <base_folder>/wsj0-mix/2speakers
+
+# the path for wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used
+# e.g. /yourpath/wsj0-processed/si_tr_s/
+# you need to convert the original wsj0 to 8k
+# you can do this conversion with the script ../meta/preprocess_dynamic_mixing.py
+base_folder_dm: !ref <base_folder>/wsj0-mix/2speakers/si_tr_s/
+
+output_folder: !ref results/<experiment_name>/<seed>
+train_log: !ref <output_folder>/train_log.txt
+save_folder: !ref <output_folder>/save
+train_data: !ref <save_folder>/wsj_tr.csv
+valid_data: !ref <save_folder>/wsj_cv.csv
+test_data: !ref <save_folder>/wsj_tt.csv
+skip_prep: false
+# optionally specify the number of files to use for training, validation, and testing
+# comment out to use all files
+# file_limits:
+#    train: 100
+#    valid: 10
+#    test: 10
+
+
+# Experiment params
+auto_mix_prec: false # Set it to True for mixed precision
+test_only: false
+num_spks: 2 # set to 3 for wsj0-3mix
+noprogressbar: false
+save_audio: true # Save estimated sources on disk
+n_audio_to_save: 5
+sample_rate: 8000
+quantize_before: false
+quantize_after: false
+
+# Training parameters
+N_epochs: 3
+batch_size: 1 #3
+lr: 0.00015 #0.003
+clip_grad_norm: 5
+loss_upper_lim: 999999 # this is the upper limit for an acceptable loss
+# if True, the training sequences are cut to a specified length
+limit_training_signal_len: false
+# this is the length of sequences if we choose to limit
+# the signal length of training sequences
+training_signal_len: 40000
+
+# Set it to True to dynamically create mixtures at training time
+dynamic_mixing: false
+
+# Parameters for data augmentation
+use_wavedrop: false
+use_speedperturb: true
+use_rand_shift: false
+min_shift: -8000
+max_shift: 8000
+
+# Speed perturbation
+speed_changes: [95, 100, 105] # List of speed changes for time-stretching
+
+speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
+    orig_freq: !ref <sample_rate>
+    speeds: !ref <speed_changes>
+
+# Frequency drop: randomly drops a number of frequency bands to zero.
+drop_freq_low: 0 # Min frequency band dropout probability
+drop_freq_high: 1 # Max frequency band dropout probability
+drop_freq_count_low: 1 # Min number of frequency bands to drop
+drop_freq_count_high: 3 # Max number of frequency bands to drop
+drop_freq_width: 0.05 # Width of frequency bands to drop
+
+drop_freq: !new:speechbrain.augment.time_domain.DropFreq
+    drop_freq_low: !ref <drop_freq_low>
+    drop_freq_high: !ref <drop_freq_high>
+    drop_freq_count_low: !ref <drop_freq_count_low>
+    drop_freq_count_high: !ref <drop_freq_count_high>
+    drop_freq_width: !ref <drop_freq_width>
+
+# Time drop: randomly drops a number of temporal chunks.
+drop_chunk_count_low: 1 # Min number of audio chunks to drop
+drop_chunk_count_high: 5 # Max number of audio chunks to drop
+drop_chunk_length_low: 1000 # Min length of audio chunks to drop
+drop_chunk_length_high: 2000 # Max length of audio chunks to drop
+
+drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
+    drop_length_low: !ref <drop_chunk_length_low>
+    drop_length_high: !ref <drop_chunk_length_high>
+    drop_count_low: !ref <drop_chunk_count_low>
+    drop_count_high: !ref <drop_chunk_count_high>
+
+# loss thresholding -- this thresholds the training loss
+threshold_byloss: True
+threshold: -30
+
+# Dataloader options
+# Set num_workers: 0 on MacOS due to behavior of the multiprocessing library
+dataloader_opts:
+    batch_size: !ref <batch_size>
+    num_workers: 3
+
+test_dataloader_opts:
+    batch_size: 1
+    num_workers: 3
+
+# Specifying the network
+
+# Encoder parameters
+channels: 512
+block_channels: 256 #1024 #256
+
+block: !new:speechbrain.lobes.models.dual_path.SBTransformerBlock
+    num_layers: 16 #16
+    d_model: 256
+    nhead: 8 #1/8
+    d_ffn: 1024 #2048?
+    dropout: 0.1 #0.0/0.1/0.5
+    use_positional_encoding: true
+    norm_before: true
+
+dacmodel: !new:speechbrain.lobes.models.codecformer.WavTokenizerWrapper
+    input_sample_rate: 8000
+    model_config_path: !ref <base_folder>/WavTokenizer/wavtokenizer/configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml
+    model_ckpt_path: !ref <base_folder>/WavTokenizer_small_320_24k_4096.ckpt
+    tokenizer_sample_rate: 24000
+    Freeze: true
+
+sepmodel: !new:speechbrain.lobes.models.codecformer.simpleSeparator2
+    # dacmodel: !ref <dacmodel>
+    num_spks: 2
+    channels: 512 ## Note needs to be 512 for WavTokenizer and 1024 for DAC
+    block: !ref <block>
+    block_channels: 256
+    activation: !new:torch.nn.LeakyReLU
+
+optimizer: !name:torch.optim.Adam
+    lr: !ref <lr>
+    weight_decay: 0
+
+#Loss parameters
+loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper
+
+lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau
+    factor: 0.5
+    patience: 2
+    dont_halve_until_epoch: 5
+
+epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
+    limit: !ref <N_epochs>
+
+modules:
+    sepmodel: !ref <sepmodel>
+checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
+    checkpoints_dir: !ref <save_folder>
+    recoverables:
+        sepmodel: !ref <sepmodel>
+        counter: !ref <epoch_counter>
+        lr_scheduler: !ref <lr_scheduler>
+train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
+    save_file: !ref <train_log>
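One practical assumption behind this config: codecformer.py (below) imports `wavtokenizer.decoder.pretrained`, so the WavTokenizer checkout that model_config_path points into has to be importable. A sketch of the simplest way to arrange that when the package is not installed:

```python
# Assumption, not part of the diff: make the WavTokenizer checkout importable before
# training. "/yourpath/WavTokenizer" is the repo root implied by model_config_path above.
import sys

sys.path.append("/yourpath/WavTokenizer")

from wavtokenizer.decoder.pretrained import WavTokenizer  # should now resolve
```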
diff --git a/codecformer3.py b/codecformer.py
similarity index 51%
rename from codecformer3.py
rename to codecformer.py
index 5a25b61..cc5dcdb 100644
--- a/codecformer3.py
+++ b/codecformer.py
@@ -1,21 +1,13 @@
 import dac
 print("descript audio codec v",dac.__version__)
-from audiotools import AudioSignal
 import torchaudio.transforms as T
-import math
+import torchaudio
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import copy
-import collections
-import warnings
-import pyloudnorm
-import random
-import numpy as np
 from torch.nn.utils import weight_norm
 from dac.nn.layers import Snake1d
-from speechbrain.lobes.models.transformer.Transformer import TransformerDecoder
-from speechbrain.lobes.models.transformer.Transformer import PositionalEncoding
+from wavtokenizer.decoder.pretrained import WavTokenizer
 
 class DACWrapper():
     '''
@@ -127,7 +119,7 @@ def get_decoded_signal(self, x, original_length):
         return out
 
 class simpleSeparator2(nn.Module):
-    def __init__(self, num_spks, channels, block, block_channels):
+    def __init__(self, num_spks, channels, block, block_channels, activation=None):
         super(simpleSeparator2, self).__init__()
         self.num_spks = num_spks
         self.channels = channels #this is dependent on the dac model
@@ -137,7 +129,10 @@ def __init__(self, num_spks, channels, block, block_channels):
         #self.time_mix = nn.Conv1d(channels,channels,1,bias=False)
         self.masker = weight_norm(nn.Conv1d(channels, channels*num_spks, 1, bias=False))
-        self.activation = Snake1d(channels) #nn.Tanh() #nn.ReLU() #Snake1d(channels)
+        if not activation:
+            self.activation = Snake1d(channels) #nn.Tanh() #nn.ReLU() #Snake1d(channels)
+        else:
+            self.activation = activation
 
         # gated output layer
         self.output = nn.Sequential(
             nn.Conv1d(channels, channels, 1), Snake1d(channels) #nn.Tanh() #, Snake1d(channels)#
@@ -147,7 +142,6 @@ def __init__(self, num_spks, channels, block, block_channels):
         )
 
     def forward(self,x):
-
         x = self.ch_down(x)
         #[B,N,L]
         x = x.permute(0,2,1)
@@ -176,4 +170,139 @@ def forward(self,x):
         x = x.transpose(0,1)
         # [spks, B, N, L]
 
-        return x
\ No newline at end of file
+        return x
+
+class WavTokenizerWrapper:
+    '''
+    Wrapper model for WavTokenizer
+    '''
+    def __init__(self, input_sample_rate=8000, model_config_path=None, model_ckpt_path=None, tokenizer_sample_rate=24000, Freeze=True):
+        '''
+        input_sample_rate: defaults to 8000 as expected file input
+        model_config_path: Path to the config file for WavTokenizer
+        model_ckpt_path: Path to the checkpoint file for WavTokenizer
+        tokenizer_sample_rate: defaults to 24000. Specify if using a model with a different sample rate.
+        '''
+        super(WavTokenizerWrapper, self).__init__()
+        self.input_sample_rate = input_sample_rate
+        self.tokenizer_sample_rate = tokenizer_sample_rate
+
+        if model_config_path is None or model_ckpt_path is None:
+            raise ValueError("Please provide both the model config and checkpoint paths.")
+
+        self.model = WavTokenizer.from_pretrained0802(model_config_path, model_ckpt_path)
+
+        self.dac_sampler = T.Resample(input_sample_rate, tokenizer_sample_rate)
+        self.org_sampler = T.Resample(tokenizer_sample_rate, input_sample_rate)
+
+        def count_all_parameters(model): return sum(p.numel() for p in model.parameters())
+        def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+        if Freeze:
+            for param in self.model.parameters():
+                param.requires_grad = False
+
+            print(f'Model frozen with {count_parameters(self.model)/1000000:.2f}M trainable parameters remaining')
+            print(f'Model has {count_all_parameters(self.model)/1000000:.2f}M parameters in total')
+        else:
+            print(f'Model with {count_all_parameters(self.model)/1000000:.2f}M trainable parameters loaded')
+
+    def resample_audio(self, x, condition):
+        '''
+        Resample the audio according to the condition.
+        condition: "tokenizer" to set the sampling rate to the tokenizer's rate
+                   "org" to set the sampling rate back to the original rate
+        '''
+        device = x.device
+
+        assert len(x.shape) == 3, "Input tensor must have 3 dimensions [Batch, Channels, Time]"
+        B, C, T = x.shape
+        assert C == 1, "Input tensor must be mono-channel [Batch, 1, Time]"
+
+        if condition == "tokenizer":
+            x_resamp = self.dac_sampler(x)
+        elif condition == "org":
+            x_resamp = self.org_sampler(x)
+        else:
+            raise ValueError("Unknown condition for resampling: {}".format(condition))
+
+        x_resamp = x_resamp / torch.max(x_resamp.abs(), dim=2, keepdim=True)[0]
+
+        return x_resamp.to(device)
+
+    def get_encoded_features(self, x):
+        '''
+        x should be a torch tensor with dimensions [Batch, Channel, Time]
+        '''
+        original_length = x.shape[-1]
+
+        # Resample the audio to the tokenizer's sample rate
+        x = self.resample_audio(x, "tokenizer")
+
+        # Remove channel dimensions for the audio data tensor
+        x = x.squeeze(1)
+
+        # Generate features and discrete codes
+        bandwidth_id = torch.tensor([0]).to(x.device)
+        features, _, _ = self.model.feature_extractor(x, bandwidth_id=bandwidth_id)
+        return features, original_length
+
+    def get_quantized_features(self, x, bandwidth_id=None):
+        '''
+        Expects input [B, D, T] where D is the encoded continuous representation of input.
+        Returns quantized features, codes, latents, commitment loss, and codebook loss in the same format as DACWrapper.
+        '''
+        if bandwidth_id is None:
+            bandwidth_id = torch.tensor([0]).to(x.device)
+
+        # Ensure the tensor has 3 dimensions [Batch, Channels, Time]
+        if x.ndim != 3:
+            raise ValueError(f"Expected input to have 3 dimensions [Batch, Channels, Time], but got {x.ndim} dimensions.")
+
+        # Perform the quantization directly on the encoded features
+        q_res = self.model.feature_extractor.encodec.quantizer(
+            x,
+            frame_rate=self.model.feature_extractor.frame_rate,
+            bandwidth=self.model.feature_extractor.bandwidths[bandwidth_id]
+        )
+
+        # Extract necessary outputs
+        quantized = q_res.quantized
+        codes = q_res.codes
+        latents = x # The input x itself is the latent representation after encoding
+        commit_loss = q_res.penalty
+
+        # Placeholder for codebook_loss (not directly available, could be None)
+        codebook_loss = None
+
+        # Return the outputs in the expected format
+        return quantized, codes, latents, commit_loss, codebook_loss
+
+    def get_decoded_signal(self, features, original_length):
+        '''
+        Decodes the features back to the audio signal.
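Taken together, the wrapper and separator above support the forward pass that the training script performs: encode the mixture, optionally quantize, predict per-speaker masks, apply them, and decode each source. A standalone sketch using the WavTokenizer configuration; the paths are the placeholders from the hparams file, and the final list comprehension only illustrates per-source decoding rather than the recipe's exact code.

```python
# Standalone sketch (not part of the diff) of the masking-based separation pass,
# using the WavTokenizer configuration from codecformer-wavtokenizer.yaml.
import torch
import torch.nn as nn
from speechbrain.lobes.models.dual_path import SBTransformerBlock
from speechbrain.lobes.models.codecformer import WavTokenizerWrapper, simpleSeparator2

num_spks = 2
codec = WavTokenizerWrapper(
    input_sample_rate=8000,
    model_config_path="/yourpath/WavTokenizer/wavtokenizer/configs/"
                      "wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml",
    model_ckpt_path="/yourpath/WavTokenizer_small_320_24k_4096.ckpt",
)
block = SBTransformerBlock(num_layers=16, d_model=256, nhead=8, d_ffn=1024,
                           dropout=0.1, use_positional_encoding=True, norm_before=True)
separator = simpleSeparator2(num_spks=num_spks, channels=512, block=block,
                             block_channels=256, activation=nn.LeakyReLU())

mix = torch.randn(1, 1, 8000)                             # [B, 1, T] mixture at 8 kHz
mix_w, original_length = codec.get_encoded_features(mix)  # [B, D, L] continuous features
# mix_w = codec.get_quantized_features(mix_w)[0]          # if quantize_before is enabled

est_mask = separator(mix_w)                               # [spks, B, D, L]
mix_w = torch.stack([mix_w] * num_spks)                   # replicate features per speaker
mix_s = mix_w * est_mask                                  # masked features per source

sources = [codec.get_decoded_signal(mix_s[i], original_length) for i in range(num_spks)]
```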
+        '''
+        # Decode the features to get the waveform
+        bandwidth_id = torch.tensor([0]).to(features.device)
+
+        x = self.model.backbone(features, bandwidth_id=bandwidth_id)
+        y_hat = self.model.head(x)
+
+        # Ensure the output has three dimensions [Batch, Channels, Time] before resampling
+        if y_hat.ndim == 2:
+            y_hat = y_hat.unsqueeze(1) # Add a channel dimension if it's missing
+
+        # Resample the decoded signal back to the original sampling rate
+        y_hat_resampled = self.resample_audio(y_hat, "org")
+
+        # Ensure the output shape matches the original length
+        if y_hat_resampled.shape[-1] != original_length:
+            T_origin = original_length
+            T_est = y_hat_resampled.shape[-1]
+
+            if T_origin > T_est:
+                y_hat_resampled = F.pad(y_hat_resampled, (0, T_origin - T_est))
+            else:
+                y_hat_resampled = y_hat_resampled[:, :, :T_origin]
+
+        return y_hat_resampled
diff --git a/train_cdf.py b/train_codecformer.py
similarity index 98%
rename from train_cdf.py
rename to train_codecformer.py
index e6d6f0a..220185e 100644
--- a/train_cdf.py
+++ b/train_codecformer.py
@@ -35,7 +35,6 @@
 
 import numpy as np
 import torch
-import torch.nn.functional as F
 import torchaudio
 from hyperpyyaml import load_hyperpyyaml
 from tqdm import tqdm
@@ -127,6 +126,7 @@ def compute_forward(self, mix, targets, stage, noise=None):
             mix_w = self.hparams.dacmodel.get_quantized_features(mix_w)[0]
 
         est_mask = self.hparams.sepmodel(mix_w)
+        mix_w = torch.stack([mix_w] * self.hparams.num_spks)
         mix_s = mix_w * est_mask
         # mix_s = est_mask
 
@@ -576,6 +576,17 @@ def dataio_prep(hparams):
         replacements={"data_root": hparams["data_folder"]},
     )
 
+    file_limits = hparams.get("file_limits", {"train": None, "valid": None, "test": None})
+
+    if file_limits.get("train") is not None:
+        train_data = train_data.filtered_sorted(select_n=file_limits["train"])
+
+    if file_limits.get("valid") is not None:
+        valid_data = valid_data.filtered_sorted(select_n=file_limits["valid"])
+
+    if file_limits.get("test") is not None:
+        test_data = test_data.filtered_sorted(select_n=file_limits["test"])
+
     datasets = [train_data, valid_data, test_data]
 
     # 2. Provide audio pipelines
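For quick experiments, the new file_limits option caps how many CSV rows each split uses; once the block in the hparams file is uncommented, dataio_prep applies it through filtered_sorted as shown above. A small illustration of the same mechanism on a single split (the CSV path is a placeholder):

```python
# Sketch (not part of the diff): what the file_limits handling in dataio_prep does,
# shown on one split. The CSV path stands in for <save_folder>/wsj_tr.csv.
from speechbrain.dataio.dataset import DynamicItemDataset

hparams = {"file_limits": {"train": 100, "valid": 10, "test": 10}}

train_data = DynamicItemDataset.from_csv("path/to/save/wsj_tr.csv")
limit = hparams.get("file_limits", {}).get("train")
if limit is not None:
    train_data = train_data.filtered_sorted(select_n=limit)  # keep only the first 100 items
```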