From 6111785e685d01c2e88dbee45264e1f6821e570f Mon Sep 17 00:00:00 2001 From: Sam Avery Date: Fri, 30 Aug 2024 17:17:16 -0700 Subject: [PATCH 1/9] feature: WavTokenizer starting point --- WavTokenizer.yaml | 168 ++++++++++++++++++++++++++++++++++++++++++++++ codecformer3.py | 161 ++++++++++++++++++++++++++++++++++++++++---- train_cdf.py | 2 +- 3 files changed, 318 insertions(+), 13 deletions(-) create mode 100644 WavTokenizer.yaml diff --git a/WavTokenizer.yaml b/WavTokenizer.yaml new file mode 100644 index 0000000..bdfd28e --- /dev/null +++ b/WavTokenizer.yaml @@ -0,0 +1,168 @@ +# ################################ +# Model: Codecformer for source separation +# https://arxiv.org/abs/2406.12434 +# Dataset : WSJ0-2mix and WSJ0-3mix +# ################################ +# +# Basic parameters +# Seed needs to be set at top of yaml, before objects with parameters are made +# +seed: 1234 +__set_seed: !apply:torch.manual_seed [1234] + +# Data params + +# e.g. '/yourpath/wsj0-mix/2speakers' +# end with 2speakers for wsj0-2mix or 3speakers for wsj0-3mix +data_folder: /wsj0-mix/2speakers + +# the path for wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used +# e.g. /yourpath/wsj0-processed/si_tr_s/ +# you need to convert the original wsj0 to 8k +# you can do this conversion with the script ../meta/preprocess_dynamic_mixing.py +base_folder_dm: /wsj0-mix/2speakers/si_tr_s/ + +experiment_name: codecformer/DAC_original_L4nq +output_folder: !ref results// +train_log: !ref /train_log.txt +save_folder: !ref /save +train_data: !ref /wsj_tr.csv +valid_data: !ref /wsj_cv.csv +test_data: !ref /wsj_tt.csv +skip_prep: false + + +# Experiment params +auto_mix_prec: false # Set it to True for mixed precision +test_only: false +num_spks: 2 # set to 3 for wsj0-3mix +noprogressbar: false +save_audio: true # Save estimated sources on disk +n_audio_to_save: 5 +sample_rate: 8000 +quantize_before: false +quantize_after: false + +# Training parameters +N_epochs: 20 +batch_size: 1 #3 +lr: 0.00015 #0.003 +clip_grad_norm: 5 +loss_upper_lim: 999999 # this is the upper limit for an acceptable loss +# if True, the training sequences are cut to a specified length +limit_training_signal_len: false +# this is the length of sequences if we choose to limit +# the signal length of training sequences +training_signal_len: 40000 + +# Set it to True to dynamically create mixtures at training time +dynamic_mixing: false + +# Parameters for data augmentation +use_wavedrop: false +use_speedperturb: true +use_rand_shift: false +min_shift: -8000 +max_shift: 8000 + +# Speed perturbation +speed_changes: [95, 100, 105] # List of speed changes for time-stretching + +speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb + orig_freq: !ref + speeds: !ref + +# Frequency drop: randomly drops a number of frequency bands to zero. +drop_freq_low: 0 # Min frequency band dropout probability +drop_freq_high: 1 # Max frequency band dropout probability +drop_freq_count_low: 1 # Min number of frequency bands to drop +drop_freq_count_high: 3 # Max number of frequency bands to drop +drop_freq_width: 0.05 # Width of frequency bands to drop + +drop_freq: !new:speechbrain.augment.time_domain.DropFreq + drop_freq_low: !ref + drop_freq_high: !ref + drop_freq_count_low: !ref + drop_freq_count_high: !ref + drop_freq_width: !ref + +# Time drop: randomly drops a number of temporal chunks. 
+drop_chunk_count_low: 1 # Min number of audio chunks to drop +drop_chunk_count_high: 5 # Max number of audio chunks to drop +drop_chunk_length_low: 1000 # Min length of audio chunks to drop +drop_chunk_length_high: 2000 # Max length of audio chunks to drop + +drop_chunk: !new:speechbrain.augment.time_domain.DropChunk + drop_length_low: !ref + drop_length_high: !ref + drop_count_low: !ref + drop_count_high: !ref + +# loss thresholding -- this thresholds the training loss +threshold_byloss: True +threshold: -30 + +# Dataloader options +# Set num_workers: 0 on MacOS due to behavior of the multiprocessing library +dataloader_opts: + batch_size: !ref + num_workers: 3 + +test_dataloader_opts: + batch_size: 1 + num_workers: 3 + +# Specifying the network + +# Encoder parameters +channels: 1024 +block_channels: 256 #1024 #256 + +block: !new:speechbrain.lobes.models.dual_path.SBTransformerBlock + num_layers: 16 #16 + d_model: 256 + nhead: 8 #1/8 + d_ffn: 1024 #2048? + dropout: 0.1 #0.0/0.1/0.5 + use_positional_encoding: true + norm_before: true + +dacmodel: !new:speechbrain.lobes.models.codecformer3.WavTokenizerWrapper + input_sample_rate: 8000 + model_config_path: /WavTokenizer/wavtokenizer/configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml + model_ckpt_path: /WavTokenizer_small_320_24k_4096.ckpt + tokenizer_sample_rate: 24000 + Freeze: true + +sepmodel: !new:speechbrain.lobes.models.codecformer3.simpleSeparator2 + # dacmodel: !ref + num_spks: 2 + channels: 512 ## Note needs to be 512 for WavTokenizer and 1024 for DAC + block: !ref + block_channels: 256 + +optimizer: !name:torch.optim.Adam + lr: !ref + weight_decay: 0 + +#Loss parameters +loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper + +lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau + factor: 0.5 + patience: 2 + dont_halve_until_epoch: 5 + +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +modules: + sepmodel: !ref +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + sepmodel: !ref + counter: !ref + lr_scheduler: !ref +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref diff --git a/codecformer3.py b/codecformer3.py index 5a25b61..afa0a84 100644 --- a/codecformer3.py +++ b/codecformer3.py @@ -1,21 +1,13 @@ import dac print("descript audio codec v",dac.__version__) -from audiotools import AudioSignal import torchaudio.transforms as T -import math +import torchaudio import torch import torch.nn as nn import torch.nn.functional as F -import copy -import collections -import warnings -import pyloudnorm -import random -import numpy as np from torch.nn.utils import weight_norm from dac.nn.layers import Snake1d -from speechbrain.lobes.models.transformer.Transformer import TransformerDecoder -from speechbrain.lobes.models.transformer.Transformer import PositionalEncoding +from wavtokenizer.decoder.pretrained import WavTokenizer class DACWrapper(): ''' @@ -147,7 +139,7 @@ def __init__(self, num_spks, channels, block, block_channels): ) def forward(self,x): - + x = x.clone() x = self.ch_down(x) #[B,N,L] x = x.permute(0,2,1) @@ -176,4 +168,149 @@ def forward(self,x): x = x.transpose(0,1) # [spks, B, N, L] - return x \ No newline at end of file + return x + +class WavTokenizerWrapper: + ''' + Wrapper model for WavTokenizer + ''' + def __init__(self, input_sample_rate=24000, model_config_path=None, model_ckpt_path=None, tokenizer_sample_rate=24000, Freeze=True): + ''' + 
input_sample_rate: defaults to 24000 as it's the standard for WavTokenizer + model_config_path: Path to the config file for WavTokenizer + model_ckpt_path: Path to the checkpoint file for WavTokenizer + tokenizer_sample_rate: defaults to 24000. Specify if using a model with a different sample rate. + ''' + super(WavTokenizerWrapper, self).__init__() + self.input_sample_rate = input_sample_rate + self.tokenizer_sample_rate = tokenizer_sample_rate + + if model_config_path is None or model_ckpt_path is None: + raise ValueError("Please provide both the model config and checkpoint paths.") + + self.model = WavTokenizer.from_pretrained0802(model_config_path, model_ckpt_path) + + self.dac_sampler = T.Resample(input_sample_rate, tokenizer_sample_rate) + self.org_sampler = T.Resample(tokenizer_sample_rate, input_sample_rate) + + def count_all_parameters(model): return sum(p.numel() for p in model.parameters()) + def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) + + if Freeze: + for param in self.model.parameters(): + param.requires_grad = False + + print(f'Model frozen with {count_parameters(self.model)/1000000:.2f}M trainable parameters remaining') + print(f'Model has {count_all_parameters(self.model)/1000000:.2f}M parameters in total') + else: + print(f'Model with {count_all_parameters(self.model)/1000000:.2f}M trainable parameters loaded') + + def convert_audio(self, wav: torch.Tensor, sr: int, target_sr: int, target_channels: int): + ''' + Converts audio to the desired sample rate and channels. + ''' + assert wav.dim() >= 2, "Audio tensor must have at least 2 dimensions" + assert wav.shape[-2] in [1, 2], "Audio must be mono or stereo." + *shape, channels, length = wav.shape + + if target_channels == 1: + wav = wav.mean(-2, keepdim=True) + elif target_channels == 2: + wav = wav.expand(*shape, target_channels, length) + elif channels == 1: + wav = wav.expand(target_channels, -1) + else: + raise RuntimeError(f"Impossible to convert from {channels} to {target_channels}") + + # Perform the resampling + wav = torchaudio.transforms.Resample(sr, target_sr).to(wav.device)(wav) + return wav + + def resample_audio(self, x, condition): + ''' + Resample the audio according to the condition. 
+ condition: "tokenizer" to set the sampling rate to the tokenizer's rate + "org" to set the sampling rate back to the original rate + ''' + device = x.device + + assert len(x.shape) == 3 + B, C, T = x.shape + assert C == 1 # The model should only handle single channel + + if condition == "tokenizer": + x_resamp = self.dac_sampler(x) + elif condition == "org": + x_resamp = self.org_sampler(x) + + x_resamp = x_resamp / torch.max(x_resamp.abs(), dim=2, keepdim=True)[0] + + return x_resamp.to(device) + + def get_encoded_features(self, x): + ''' + x should be a torch tensor with dimensions [Batch, Channel, Time] + ''' + original_length = x.shape[-1] + + # Ensure the tensor has the right format + # with torch.no_grad(): + x = self.convert_audio(x, self.input_sample_rate, self.tokenizer_sample_rate, 1) + + # If you want to remove batch and channel dimensions for the audio data tensor + x = x.squeeze() # Remove dimensions of size 1 + + # Generate features and discrete codes + bandwidth_id = torch.tensor([0]).to(x.device) + features, discrete_code = self.model.encode_infer(x.unsqueeze(0), bandwidth_id=bandwidth_id) + + return features, original_length + + def get_quantized_features(self, x): + ''' + Expects input [B, D, T] where D is the encoded continuous representation of input + ''' + # Ensure the tensor has 3 dimensions [Batch, Channels, Time] + if x.ndim == 2: + x = x.unsqueeze(1) # Add a channel dimension if missing + + # Directly feed the encoded features to the quantizer + # with torch.no_grad(): + q_res = self.model.feature_extractor.encodec.quantizer.infer(x, frame_rate=self.model.feature_extractor.frame_rate, bandwidth=self.model.feature_extractor.bandwidths[0]) + quantized = q_res.quantized + codes = q_res.codes + commit_loss = q_res.penalty + + # Return the outputs to match the format expected by the rest of your code + return quantized, codes, None, commit_loss, None + + def get_decoded_signal(self, features, original_length): + ''' + Decodes the features back to the audio signal. 
+ ''' + # Decode the features to get the waveform + bandwidth_id = torch.tensor([0]).to(features.device) + + # with torch.no_grad(): + x = self.model.backbone(features, bandwidth_id=bandwidth_id) + y_hat = self.model.head(x) + + # Ensure the output has three dimensions [Batch, Channels, Time] before resampling + if y_hat.ndim == 2: + y_hat = y_hat.unsqueeze(1) # Add a channel dimension if it's missing + + # Resample the decoded signal back to the original sampling rate + y_hat_resampled = self.resample_audio(y_hat, "org") + + # Ensure the output shape matches the original length + if y_hat_resampled.shape[-1] != original_length: + T_origin = original_length + T_est = y_hat_resampled.shape[-1] + + if T_origin > T_est: + y_hat_resampled = F.pad(y_hat_resampled, (0, T_origin - T_est)) + else: + y_hat_resampled = y_hat_resampled[:, :, :T_origin] + + return y_hat_resampled + diff --git a/train_cdf.py b/train_cdf.py index e6d6f0a..4090991 100644 --- a/train_cdf.py +++ b/train_cdf.py @@ -35,7 +35,6 @@ import numpy as np import torch -import torch.nn.functional as F import torchaudio from hyperpyyaml import load_hyperpyyaml from tqdm import tqdm @@ -127,6 +126,7 @@ def compute_forward(self, mix, targets, stage, noise=None): mix_w = self.hparams.dacmodel.get_quantized_features(mix_w)[0] est_mask = self.hparams.sepmodel(mix_w) + mix_w = torch.stack([mix_w] * self.hparams.num_spks) mix_s = mix_w * est_mask # mix_s = est_mask From 1070495aaeeaff9f6d72e5d6751b7e37cecd4ec7 Mon Sep 17 00:00:00 2001 From: Sam Avery Date: Tue, 3 Sep 2024 09:28:05 -0700 Subject: [PATCH 2/9] renaming for consistency --- DAC_original_L4nq.yaml => codecformer-dac.yaml | 0 WavTokenizer.yaml => codecformer-wavtokenizer.yaml | 0 codecformer3.py => codecformer.py | 0 train_cdf.py => train_codecformer.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename DAC_original_L4nq.yaml => codecformer-dac.yaml (100%) rename WavTokenizer.yaml => codecformer-wavtokenizer.yaml (100%) rename codecformer3.py => codecformer.py (100%) rename train_cdf.py => train_codecformer.py (100%) diff --git a/DAC_original_L4nq.yaml b/codecformer-dac.yaml similarity index 100% rename from DAC_original_L4nq.yaml rename to codecformer-dac.yaml diff --git a/WavTokenizer.yaml b/codecformer-wavtokenizer.yaml similarity index 100% rename from WavTokenizer.yaml rename to codecformer-wavtokenizer.yaml diff --git a/codecformer3.py b/codecformer.py similarity index 100% rename from codecformer3.py rename to codecformer.py diff --git a/train_cdf.py b/train_codecformer.py similarity index 100% rename from train_cdf.py rename to train_codecformer.py From 70b1c669653f05495e78bcd4741caf046c75faa6 Mon Sep 17 00:00:00 2001 From: Sam Avery Date: Wed, 4 Sep 2024 10:33:08 -0700 Subject: [PATCH 3/9] minor readme update --- README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index da17a86..18e6176 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,12 @@ This repository contains a number of scripts required for replicating Codecformer within the speechbrain framework. Unfortunately, they will have to be copied into the respective directories manually. 
-train_cdf.py -> recipes/WSJ02Mix/Separation +train_codecformer.py -> recipes/WSJ02Mix/Separation -DAC_original_L4nq.yaml -> recipes/WSJ02Mix/Separation/hparams +codecformer-dac.yaml -> recipes/WSJ02Mix/Separation/hparams +codecformer-wavtokenizer.yaml -> recipes/WSJ02Mix/Separation/hparams -codecformer3.py -> speechbrain/lobes/models +codecformer.py -> speechbrain/lobes/models For replication efforts, please note that the activation function of the simpleseparator2 model has a big impact on performance. Ensure that the activation function of the separator matches the activation function used in the final layer of the neural audio codec's encoder. From c5ff4b1ca3302487bd1dcd91b3874d57364e7e70 Mon Sep 17 00:00:00 2001 From: Sam Avery Date: Wed, 4 Sep 2024 10:42:49 -0700 Subject: [PATCH 4/9] add base folder to yamls --- codecformer-dac.yaml | 9 ++++++--- codecformer-wavtokenizer.yaml | 17 ++++++++++------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/codecformer-dac.yaml b/codecformer-dac.yaml index c215279..804f93f 100644 --- a/codecformer-dac.yaml +++ b/codecformer-dac.yaml @@ -11,18 +11,21 @@ seed: 1234 __set_seed: !apply:torch.manual_seed [1234] # Data params +# your base folder where this repo and other relevant files/repos are stored +base_folder: /yourpath +# experiment folder name to generate in -/separation/results +experiment_name: codecformer/DAC_original_L4nq # e.g. '/yourpath/wsj0-mix/2speakers' # end with 2speakers for wsj0-2mix or 3speakers for wsj0-3mix -data_folder: +data_folder: !ref /wsj0-mix/2speakers # the path for wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used # e.g. /yourpath/wsj0-processed/si_tr_s/ # you need to convert the original wsj0 to 8k # you can do this conversion with the script ../meta/preprocess_dynamic_mixing.py -base_folder_dm: /yourpath/wsj0-processed/si_tr_s/ +base_folder_dm: !ref /wsj0-mix/2speakers/si_tr_s/ -experiment_name: codecformer/DAC_original_L4nq output_folder: !ref results// train_log: !ref /train_log.txt save_folder: !ref /save diff --git a/codecformer-wavtokenizer.yaml b/codecformer-wavtokenizer.yaml index bdfd28e..0a334d2 100644 --- a/codecformer-wavtokenizer.yaml +++ b/codecformer-wavtokenizer.yaml @@ -11,18 +11,21 @@ seed: 1234 __set_seed: !apply:torch.manual_seed [1234] # Data params +# your base folder where this repo and other relevant files/repos are stored +base_folder: /yourpath +# experiment folder name to generate in -/separation/results +experiment_name: codecformer/wavtokenizer2 # e.g. '/yourpath/wsj0-mix/2speakers' # end with 2speakers for wsj0-2mix or 3speakers for wsj0-3mix -data_folder: /wsj0-mix/2speakers +data_folder: !ref /wsj0-mix/2speakers # the path for wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used # e.g. 
/yourpath/wsj0-processed/si_tr_s/ # you need to convert the original wsj0 to 8k # you can do this conversion with the script ../meta/preprocess_dynamic_mixing.py -base_folder_dm: /wsj0-mix/2speakers/si_tr_s/ +base_folder_dm: !ref /wsj0-mix/2speakers/si_tr_s/ -experiment_name: codecformer/DAC_original_L4nq output_folder: !ref results// train_log: !ref /train_log.txt save_folder: !ref /save @@ -127,14 +130,14 @@ block: !new:speechbrain.lobes.models.dual_path.SBTransformerBlock use_positional_encoding: true norm_before: true -dacmodel: !new:speechbrain.lobes.models.codecformer3.WavTokenizerWrapper +dacmodel: !new:speechbrain.lobes.models.codecformer.WavTokenizerWrapper input_sample_rate: 8000 - model_config_path: /WavTokenizer/wavtokenizer/configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml - model_ckpt_path: /WavTokenizer_small_320_24k_4096.ckpt + model_config_path: !ref /WavTokenizer/wavtokenizer/configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml + model_ckpt_path: !ref /WavTokenizer_small_320_24k_4096.ckpt tokenizer_sample_rate: 24000 Freeze: true -sepmodel: !new:speechbrain.lobes.models.codecformer3.simpleSeparator2 +sepmodel: !new:speechbrain.lobes.models.codecformer.simpleSeparator2 # dacmodel: !ref num_spks: 2 channels: 512 ## Note needs to be 512 for WavTokenizer and 1024 for DAC From c623f3bc0c674be402e1ea7b9f55233761b882d5 Mon Sep 17 00:00:00 2001 From: Sam Avery Date: Wed, 4 Sep 2024 14:36:15 -0700 Subject: [PATCH 5/9] add optional override of sepmodel activation function --- codecformer-wavtokenizer.yaml | 1 + codecformer.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/codecformer-wavtokenizer.yaml b/codecformer-wavtokenizer.yaml index 0a334d2..30b3f71 100644 --- a/codecformer-wavtokenizer.yaml +++ b/codecformer-wavtokenizer.yaml @@ -143,6 +143,7 @@ sepmodel: !new:speechbrain.lobes.models.codecformer.simpleSeparator2 channels: 512 ## Note needs to be 512 for WavTokenizer and 1024 for DAC block: !ref block_channels: 256 + activation: !new:torch.nn.LeakyReLU optimizer: !name:torch.optim.Adam lr: !ref diff --git a/codecformer.py b/codecformer.py index afa0a84..b996ead 100644 --- a/codecformer.py +++ b/codecformer.py @@ -119,7 +119,7 @@ def get_decoded_signal(self, x, original_length): return out class simpleSeparator2(nn.Module): - def __init__(self, num_spks, channels, block, block_channels): + def __init__(self, num_spks, channels, block, block_channels, activation=None): super(simpleSeparator2, self).__init__() self.num_spks = num_spks self.channels = channels #this is dependent on the dac model @@ -129,7 +129,10 @@ def __init__(self, num_spks, channels, block, block_channels): #self.time_mix = nn.Conv1d(channels,channels,1,bias=False) self.masker = weight_norm(nn.Conv1d(channels, channels*num_spks, 1, bias=False)) - self.activation = Snake1d(channels) #nn.Tanh() #nn.ReLU() #Snake1d(channels) + if not activation: + self.activation = Snake1d(channels) #nn.Tanh() #nn.ReLU() #Snake1d(channels) + else: + self.activation = activation # gated output layer self.output = nn.Sequential( nn.Conv1d(channels, channels, 1), Snake1d(channels) #nn.Tanh() #, Snake1d(channels)# From 84c2c575dec11986a991d658380b5e305f532f9c Mon Sep 17 00:00:00 2001 From: Sam Avery Date: Wed, 4 Sep 2024 15:53:59 -0700 Subject: [PATCH 6/9] optional file limits for train valid and test --- codecformer-dac.yaml | 6 ++++++ codecformer-wavtokenizer.yaml | 6 ++++++ train_codecformer.py | 11 +++++++++++ 3 files changed, 23 
insertions(+) diff --git a/codecformer-dac.yaml b/codecformer-dac.yaml index 804f93f..35739ae 100644 --- a/codecformer-dac.yaml +++ b/codecformer-dac.yaml @@ -33,6 +33,12 @@ train_data: !ref /wsj_tr.csv valid_data: !ref /wsj_cv.csv test_data: !ref /wsj_tt.csv skip_prep: false +# optionally specify the number of files to use for training, validation, and testing +# comment out to use all files +# file_limits: +# train: 100 +# valid: 10 +# test: 10 # Experiment params diff --git a/codecformer-wavtokenizer.yaml b/codecformer-wavtokenizer.yaml index 30b3f71..2fe00f9 100644 --- a/codecformer-wavtokenizer.yaml +++ b/codecformer-wavtokenizer.yaml @@ -33,6 +33,12 @@ train_data: !ref /wsj_tr.csv valid_data: !ref /wsj_cv.csv test_data: !ref /wsj_tt.csv skip_prep: false +# optionally specify the number of files to use for training, validation, and testing +# comment out to use all files +# file_limits: +# train: 100 +# valid: 10 +# test: 10 # Experiment params diff --git a/train_codecformer.py b/train_codecformer.py index 4090991..220185e 100644 --- a/train_codecformer.py +++ b/train_codecformer.py @@ -576,6 +576,17 @@ def dataio_prep(hparams): replacements={"data_root": hparams["data_folder"]}, ) + file_limits = hparams.get("file_limits", {"train": None, "valid": None, "test": None}) + + if file_limits.get("train") is not None: + train_data = train_data.filtered_sorted(select_n=file_limits["train"]) + + if file_limits.get("valid") is not None: + valid_data = valid_data.filtered_sorted(select_n=file_limits["valid"]) + + if file_limits.get("test") is not None: + test_data = test_data.filtered_sorted(select_n=file_limits["test"]) + datasets = [train_data, valid_data, test_data] # 2. Provide audio pipelines From 9d0011d7f674b8385f05ad44f92cf2b273958f1e Mon Sep 17 00:00:00 2001 From: Sam Avery Date: Wed, 4 Sep 2024 16:19:53 -0700 Subject: [PATCH 7/9] fix: codecformer filename --- codecformer-dac.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codecformer-dac.yaml b/codecformer-dac.yaml index 35739ae..0d971af 100644 --- a/codecformer-dac.yaml +++ b/codecformer-dac.yaml @@ -136,13 +136,13 @@ block: !new:speechbrain.lobes.models.dual_path.SBTransformerBlock use_positional_encoding: true norm_before: true -dacmodel: !new:speechbrain.lobes.models.codecformer3.DACWrapper +dacmodel: !new:speechbrain.lobes.models.codecformer.DACWrapper input_sample_rate: 8000 DAC_model_path: #if None, will download model from huggingface. Otherwise, path to checkpoint should be provided for the model to be loaded. Model has been hardcoded to download the 16khz model. please modify the code if you need another model. 
DAC_sample_rate: 16000 Freeze: true -sepmodel: !new:speechbrain.lobes.models.codecformer3.simpleSeparator2 +sepmodel: !new:speechbrain.lobes.models.codecformer.simpleSeparator2 # dacmodel: !ref num_spks: 2 channels: 1024 From d38e558c6f52aadc4e9fccc7e415031d4202866e Mon Sep 17 00:00:00 2001 From: Sam Avery Date: Wed, 4 Sep 2024 16:28:42 -0700 Subject: [PATCH 8/9] comment cleanup --- codecformer.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/codecformer.py b/codecformer.py index b996ead..99e4e8d 100644 --- a/codecformer.py +++ b/codecformer.py @@ -257,7 +257,6 @@ def get_encoded_features(self, x): original_length = x.shape[-1] # Ensure the tensor has the right format - # with torch.no_grad(): x = self.convert_audio(x, self.input_sample_rate, self.tokenizer_sample_rate, 1) # If you want to remove batch and channel dimensions for the audio data tensor @@ -275,37 +274,30 @@ def get_quantized_features(self, x): ''' # Ensure the tensor has 3 dimensions [Batch, Channels, Time] if x.ndim == 2: - x = x.unsqueeze(1) # Add a channel dimension if missing + x = x.unsqueeze(1) - # Directly feed the encoded features to the quantizer - # with torch.no_grad(): q_res = self.model.feature_extractor.encodec.quantizer.infer(x, frame_rate=self.model.feature_extractor.frame_rate, bandwidth=self.model.feature_extractor.bandwidths[0]) quantized = q_res.quantized codes = q_res.codes commit_loss = q_res.penalty - # Return the outputs to match the format expected by the rest of your code return quantized, codes, None, commit_loss, None def get_decoded_signal(self, features, original_length): ''' Decodes the features back to the audio signal. ''' - # Decode the features to get the waveform bandwidth_id = torch.tensor([0]).to(features.device) - # with torch.no_grad(): x = self.model.backbone(features, bandwidth_id=bandwidth_id) y_hat = self.model.head(x) # Ensure the output has three dimensions [Batch, Channels, Time] before resampling if y_hat.ndim == 2: - y_hat = y_hat.unsqueeze(1) # Add a channel dimension if it's missing + y_hat = y_hat.unsqueeze(1) - # Resample the decoded signal back to the original sampling rate y_hat_resampled = self.resample_audio(y_hat, "org") - # Ensure the output shape matches the original length if y_hat_resampled.shape[-1] != original_length: T_origin = original_length T_est = y_hat_resampled.shape[-1] From f18e1f750d609bcfe917c4addbf99b9d0ea97c44 Mon Sep 17 00:00:00 2001 From: Sam Avery Date: Fri, 6 Sep 2024 09:48:56 -0700 Subject: [PATCH 9/9] wavtokenizer training baseline --- codecformer-wavtokenizer.yaml | 4 +- codecformer.py | 77 +++++++++++++++++------------------ 2 files changed, 39 insertions(+), 42 deletions(-) diff --git a/codecformer-wavtokenizer.yaml b/codecformer-wavtokenizer.yaml index 2fe00f9..0a9b519 100644 --- a/codecformer-wavtokenizer.yaml +++ b/codecformer-wavtokenizer.yaml @@ -53,7 +53,7 @@ quantize_before: false quantize_after: false # Training parameters -N_epochs: 20 +N_epochs: 3 batch_size: 1 #3 lr: 0.00015 #0.003 clip_grad_norm: 5 @@ -124,7 +124,7 @@ test_dataloader_opts: # Specifying the network # Encoder parameters -channels: 1024 +channels: 512 block_channels: 256 #1024 #256 block: !new:speechbrain.lobes.models.dual_path.SBTransformerBlock diff --git a/codecformer.py b/codecformer.py index 99e4e8d..cc5dcdb 100644 --- a/codecformer.py +++ b/codecformer.py @@ -142,7 +142,6 @@ def __init__(self, num_spks, channels, block, block_channels, activation=None): ) def forward(self,x): - x = x.clone() x = self.ch_down(x) #[B,N,L] x = 
x.permute(0,2,1) @@ -177,9 +176,9 @@ class WavTokenizerWrapper: ''' Wrapper model for WavTokenizer ''' - def __init__(self, input_sample_rate=24000, model_config_path=None, model_ckpt_path=None, tokenizer_sample_rate=24000, Freeze=True): + def __init__(self, input_sample_rate=8000, model_config_path=None, model_ckpt_path=None, tokenizer_sample_rate=24000, Freeze=True): ''' - input_sample_rate: defaults to 24000 as it's the standard for WavTokenizer + input_sample_rate: defaults to 8000 as expected file input model_config_path: Path to the config file for WavTokenizer model_ckpt_path: Path to the checkpoint file for WavTokenizer tokenizer_sample_rate: defaults to 24000. Specify if using a model with a different sample rate. @@ -208,27 +207,6 @@ def count_parameters(model): return sum(p.numel() for p in model.parameters() if else: print(f'Model with {count_all_parameters(self.model)/1000000:.2f}M trainable parameters loaded') - def convert_audio(self, wav: torch.Tensor, sr: int, target_sr: int, target_channels: int): - ''' - Converts audio to the desired sample rate and channels. - ''' - assert wav.dim() >= 2, "Audio tensor must have at least 2 dimensions" - assert wav.shape[-2] in [1, 2], "Audio must be mono or stereo." - *shape, channels, length = wav.shape - - if target_channels == 1: - wav = wav.mean(-2, keepdim=True) - elif target_channels == 2: - wav = wav.expand(*shape, target_channels, length) - elif channels == 1: - wav = wav.expand(target_channels, -1) - else: - raise RuntimeError(f"Impossible to convert from {channels} to {target_channels}") - - # Perform the resampling - wav = torchaudio.transforms.Resample(sr, target_sr).to(wav.device)(wav) - return wav - def resample_audio(self, x, condition): ''' Resample the audio according to the condition. @@ -237,14 +215,16 @@ def resample_audio(self, x, condition): ''' device = x.device - assert len(x.shape) == 3 + assert len(x.shape) == 3, "Input tensor must have 3 dimensions [Batch, Channels, Time]" B, C, T = x.shape - assert C == 1 # The model should only handle single channel + assert C == 1, "Input tensor must be mono-channel [Batch, 1, Time]" if condition == "tokenizer": x_resamp = self.dac_sampler(x) elif condition == "org": x_resamp = self.org_sampler(x) + else: + raise ValueError("Unknown condition for resampling: {}".format(condition)) x_resamp = x_resamp / torch.max(x_resamp.abs(), dim=2, keepdim=True)[0] @@ -256,37 +236,53 @@ def get_encoded_features(self, x): ''' original_length = x.shape[-1] - # Ensure the tensor has the right format - x = self.convert_audio(x, self.input_sample_rate, self.tokenizer_sample_rate, 1) + # Resample the audio to the tokenizer's sample rate + x = self.resample_audio(x, "tokenizer") - # If you want to remove batch and channel dimensions for the audio data tensor - x = x.squeeze() # Remove dimensions of size 1 + # Remove channel dimensions for the audio data tensor + x = x.squeeze(1) # Generate features and discrete codes bandwidth_id = torch.tensor([0]).to(x.device) - features, discrete_code = self.model.encode_infer(x.unsqueeze(0), bandwidth_id=bandwidth_id) - + features, _, _ = self.model.feature_extractor(x, bandwidth_id=bandwidth_id) return features, original_length - def get_quantized_features(self, x): + def get_quantized_features(self, x, bandwidth_id=None): ''' - Expects input [B, D, T] where D is the encoded continuous representation of input + Expects input [B, D, T] where D is the encoded continuous representation of input. 
+ Returns quantized features, codes, latents, commitment loss, and codebook loss in the same format as DACWrapper. ''' + if bandwidth_id is None: + bandwidth_id = torch.tensor([0]).to(x.device) + # Ensure the tensor has 3 dimensions [Batch, Channels, Time] - if x.ndim == 2: - x = x.unsqueeze(1) + if x.ndim != 3: + raise ValueError(f"Expected input to have 3 dimensions [Batch, Channels, Time], but got {x.ndim} dimensions.") + + # Perform the quantization directly on the encoded features + q_res = self.model.feature_extractor.encodec.quantizer( + x, + frame_rate=self.model.feature_extractor.frame_rate, + bandwidth=self.model.feature_extractor.bandwidths[bandwidth_id] + ) - q_res = self.model.feature_extractor.encodec.quantizer.infer(x, frame_rate=self.model.feature_extractor.frame_rate, bandwidth=self.model.feature_extractor.bandwidths[0]) + # Extract necessary outputs quantized = q_res.quantized codes = q_res.codes + latents = x # The input x itself is the latent representation after encoding commit_loss = q_res.penalty - return quantized, codes, None, commit_loss, None + # Placeholder for codebook_loss (not directly available, could be None) + codebook_loss = None + + # Return the outputs in the expected format + return quantized, codes, latents, commit_loss, codebook_loss def get_decoded_signal(self, features, original_length): ''' Decodes the features back to the audio signal. ''' + # Decode the features to get the waveform bandwidth_id = torch.tensor([0]).to(features.device) x = self.model.backbone(features, bandwidth_id=bandwidth_id) @@ -294,10 +290,12 @@ def get_decoded_signal(self, features, original_length): # Ensure the output has three dimensions [Batch, Channels, Time] before resampling if y_hat.ndim == 2: - y_hat = y_hat.unsqueeze(1) + y_hat = y_hat.unsqueeze(1) # Add a channel dimension if it's missing + # Resample the decoded signal back to the original sampling rate y_hat_resampled = self.resample_audio(y_hat, "org") + # Ensure the output shape matches the original length if y_hat_resampled.shape[-1] != original_length: T_origin = original_length T_est = y_hat_resampled.shape[-1] @@ -308,4 +306,3 @@ def get_decoded_signal(self, features, original_length): y_hat_resampled = y_hat_resampled[:, :, :T_origin] return y_hat_resampled -
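A minimal usage sketch of how the pieces introduced in this series fit together, mirroring the masking flow of compute_forward in train_codecformer.py. It assumes the wavtokenizer package and SpeechBrain are installed, the config and checkpoint paths below are placeholders for a local WavTokenizer checkout, and the quantization step only applies when quantize_before is enabled in the yaml.

    import torch
    from speechbrain.lobes.models.dual_path import SBTransformerBlock
    from speechbrain.lobes.models.codecformer import WavTokenizerWrapper, simpleSeparator2

    num_spks = 2

    # Frozen codec; both paths are placeholders for a local WavTokenizer checkout/checkpoint.
    codec = WavTokenizerWrapper(
        input_sample_rate=8000,
        model_config_path="/yourpath/WavTokenizer/wavtokenizer/configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml",
        model_ckpt_path="/yourpath/WavTokenizer_small_320_24k_4096.ckpt",
        tokenizer_sample_rate=24000,
        Freeze=True,
    )

    block = SBTransformerBlock(
        num_layers=16, d_model=256, nhead=8, d_ffn=1024,
        dropout=0.1, use_positional_encoding=True, norm_before=True,
    )
    # channels must match the codec feature dimension: 512 for WavTokenizer, 1024 for DAC.
    sep = simpleSeparator2(num_spks=num_spks, channels=512, block=block, block_channels=256)

    mix = torch.randn(1, 1, 8000)                      # [Batch, Channel, Time] at 8 kHz
    feats, orig_len = codec.get_encoded_features(mix)  # continuous codec features [B, channels, T']
    # feats = codec.get_quantized_features(feats)[0]   # only when quantize_before: true in the yaml

    masks = sep(feats)                                 # [spks, B, channels, T']
    mixture = torch.stack([feats] * num_spks)          # broadcast the mixture over speakers
    masked = mixture * masks

    # Decode each masked feature stream back to an 8 kHz waveform of the original length.
    est_sources = [codec.get_decoded_signal(masked[i], orig_len) for i in range(num_spks)]

The decoded estimates can then be scored with the SI-SNR/PIT loss configured in the yaml, as the training script does.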