diff --git a/.gitignore b/.gitignore
index fa9a082..65d704f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,13 +7,13 @@ data/*.ipynb
__pycache__
*.egg-info
-dasp-pytorch/
+
mix_KE_adv/**
.vscode/
logs/**
checkpoints/
debug
-dasp-pytorch
*.wav
*.png
-data/FXencoder_ps.pt
\ No newline at end of file
+data/FXencoder_ps.pt
+outputs/**
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..66b2013
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "dasp-pytorch"]
+ path = dasp-pytorch
+ url = https://github.com/csteinmetz1/dasp-pytorch
diff --git a/Assets/diffmst-main_modified.jpg b/Assets/diffmst-main_modified.jpg
new file mode 100644
index 0000000..705280a
Binary files /dev/null and b/Assets/diffmst-main_modified.jpg differ
diff --git a/Assets/mst_final.png b/Assets/mst_final.png
deleted file mode 100644
index eb79192..0000000
Binary files a/Assets/mst_final.png and /dev/null differ
diff --git a/Assets/mst_wbg.png b/Assets/mst_wbg.png
deleted file mode 100644
index f3c2e1e..0000000
Binary files a/Assets/mst_wbg.png and /dev/null differ
diff --git a/README.md b/README.md
index 09c3cf2..eaa47d8 100644
--- a/README.md
+++ b/README.md
@@ -2,95 +2,76 @@
# Differentiable Mixing Style Transfer
-[Paper]() | [Website]()
+[Paper](https://sai-soum.github.io/assets/pdf/diffmst.pdf) | [Website](https://sai-soum.github.io/projects/diffmst/)
-
+
-Mixing style transfer using reference mix.
+
+# Repository Structure
+1. `configs` - Contains configuration files for training and inference.
+2. `mst` - Contains the main codebase for the project.
+    - `dataloaders` - Contains the dataloaders for the project.
+    - `modules` - Contains the modules for the different components of the system.
+    - `mixing` - Contains the mixing modules for creating mixes.
+    - `loss` - Contains the loss functions for the project.
+    - `panns` - Contains the basic building blocks, such as cnn14 and resnet.
+    - `utils` - Contains utility functions for the project.
+3. `scripts` - Contains scripts for running inference.
# Usage
-
Clone the repository and install the `mst` package.
```
-git clone https://github.com/sai-soum/mix_style_transfer.git
-cd mix_style_transfer
+git clone --recursive https://github.com/sai-soum/Diff-MST.git
+cd Diff-MST
python -m venv env
source env/bin/activate
pip install -e .
```
-[dasp-pytorch](https://github.com/csteinmetz1/dasp-pytorch) is required for differentiable audio effects.
-Clone the repo into the top-level of the project directory.
+[dasp-pytorch](https://github.com/csteinmetz1/dasp-pytorch) is required for differentiable audio effects.
+It is included as a git submodule and is fetched by the `--recursive` clone above; install it in editable mode.
```
-git clone https://github.com/csteinmetz1/dasp-pytorch.git
cd dasp-pytorch
pip install -e .
```
-Since `dasp` is currently under development you need to pull changes periodically.
-To do so change to the directory and pull.
-```
-cd dasp-pytorch
-git pull
-```
-
-## Inference
-
-```
-CUDA_VISIBLE_DEVICES=5 python scripts/run.py \
-checkpoints/20230719/config.yaml \
-checkpoints/20230719/epoch=132-step=83125.ckpt \
-"/import/c4dm-02/acw639/DiffMST/song 2/Kat Wright_By My Side/" \
-output/ref_mix.wav \
-```
-
## Train
-
-First update the paths in the configuration file for both the logger and the dataset root directory.
+We use [LightningCLI](https://lightning.ai/docs/pytorch/stable/) for training and [Wandb](https://wandb.ai/site) for logging.
+First update the paths in the configuration files for the logger, the loss function, and the dataset root directory.
Then call the `main.py` script passing in the configuration file.
+
+### Method 1: Training with random mixes of the same song as reference using the multi-resolution STFT (MRSTFT) loss.
```
-# new model configuration with audio feature loss
CUDA_VISIBLE_DEVICES=0 python main.py fit \
--c configs/config_cjs.yaml \
+-c configs/config.yaml \
-c configs/optimizer.yaml \
--c configs/data/medley+cambridge+jamendo-8.yaml \
--c configs/models/gain+eq+comp-feat.yaml
+-c configs/data/medley+cambridge-8.yaml \
+-c configs/models/naive.yaml
+```
+You can change the number of tracks, the amount of training data per epoch, and the batch size in the data configuration files under `configs/data/`.
-# new model configuration with CLAP loss
+### Method 2: Training with real unpaired songs as reference using the audio feature (AF) loss.
+```
CUDA_VISIBLE_DEVICES=0 python main.py fit \
--c configs/config_cjs.yaml \
+-c configs/config.yaml \
-c configs/optimizer.yaml \
-c configs/data/medley+cambridge+jamendo-8.yaml \
--c configs/models/gain+eq+comp-clap.yaml
+-c configs/models/naive+feat.yaml
```
+## Inference
+To evaluate the model on real-world data, run the `scripts/eval_all_combo.py` script.
-# Stability (ignore)
-```
-source env/bin/activate
-cd /scratch
-mkdir medleydb
-cd medleydb
-aws s3 sync s3://stability-aws/MedleyDB ./
-tar -xvf MedleyDB_v1.tar
-tar -xvf MedleyDB_v2.tar
-python main.py fit -c configs/config.yaml -c configs/optimizer.yaml -c configs/data/medleydb_cjs.yaml -c configs/models/naive_dmc_adv.yaml
-CUDA_VISIBLE_DEVICES=7 python main.py fit -c configs/config_cjs.yaml -c configs/optimizer.yaml -c configs/data/medleydb_c4dm.yaml -c configs/models/ke_dmc_adv.yaml
-
-CUDA_VISIBLE_DEVICES=7 python main.py fit -c configs/config.yaml -c configs/optimizer.yaml -c configs/data/medley+cambridge-4.yaml -c configs/models/naive+fx_encoder_loss.yaml
-
-To run the paramloss code
-
-CUDA_VISIBLE_DEVICES=2 python main.py fit -c configs/config.yaml -c configs/optimizer.yaml -c configs/data/medley+cambridge-4.yaml -c configs/models/naive+paramloss.yaml
+Before running it, update the model checkpoint paths and the inference examples directory in the script.
-```
+`Python 3.10` was used for training.
diff --git a/configs/config.yaml b/configs/config.yaml
index cedbb7f..3dd2465 100644
--- a/configs/config.yaml
+++ b/configs/config.yaml
@@ -6,34 +6,34 @@ trainer:
init_args:
project: DiffMST
save_dir: /import/c4dm-datasets-ext/diffmst_logs_soum
-
enable_checkpointing: true
-
-
callbacks:
- class_path: mst.callbacks.audio.LogAudioCallback
- class_path: pytorch_lightning.callbacks.ModelSummary
init_args:
max_depth: 2
-
- class_path: mst.callbacks.mix.LogReferenceMix
init_args:
- root_dirs:
- - /import/c4dm-datasets-ext/diffmst-examples/song1/BenFlowers_Ecstasy_Full/
- - /import/c4dm-datasets-ext/diffmst-examples/song2/Kat Wright_By My Side/
- - /import/c4dm-datasets-ext/diffmst-examples/song3/Titanium_HauntedAge_Full/
+ root_dirs:
+ - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song1/Soren_ALittleLate_Full
+ - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song1/Soren_ALittleLate_Full
+ - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song2/MR0903_Moosmusic_Full
+ - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song2/MR0903_Moosmusic_Full
+ - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song3/SaturnSyndicate_CatchTheWave_Full
ref_mixes:
- - /import/c4dm-datasets-ext/diffmst-examples/song1/ref/_Feel it all Around_ by Washed Out (Portlandia Theme).mp3
- - /import/c4dm-datasets-ext/diffmst-examples/song2/ref/The Dip - Paddle To The Stars (Lyric Video).mp3
- - /import/c4dm-datasets-ext/diffmst-examples/song3/ref/Architects - _Doomsday_.mp3
+ - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song1/ref/Harry Styles - Late Night Talking (Official Video).wav
+ - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song1/ref/Poom - Les Voiles (Official Audio).wav
+ - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song2/ref/Justin Timberlake - Can't Stop The Feeling! [Lyrics].wav
+ - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song2/ref/Taylor Swift - Shake It Off.wav
+ - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song3/ref/Miley Cyrus - Wrecking Ball (Lyrics).wav
default_root_dir: null
gradient_clip_val: 10.0
- devices: 3
- detect_anomaly: True
-
+ devices: 1
check_val_every_n_epoch: 1
- max_epochs: 10000
- log_every_n_steps: 200
+
+ max_epochs: 800
+
+ log_every_n_steps: 50
accelerator: gpu
strategy: ddp_find_unused_parameters_true
sync_batchnorm: true
@@ -42,8 +42,5 @@ trainer:
num_sanity_val_steps: 2
benchmark: true
accumulate_grad_batches: 1
- reload_dataloaders_every_n_epochs: 1
-
+ #reload_dataloaders_every_n_epochs: 1
-# - /import/c4dm-datasets-ext/diffmst-examples/song1/BenFlowers_Ecstasy_Full/
-# - /import/c4dm-datasets-ext/diffmst_validation/listening/diffmst-examples_wavref/Feel it all Around by Washed Out (Portlandia Theme).wav
\ No newline at end of file
diff --git a/configs/config_cjs.yaml b/configs/config_cjs.yaml
deleted file mode 100644
index 3dd2465..0000000
--- a/configs/config_cjs.yaml
+++ /dev/null
@@ -1,46 +0,0 @@
-seed_everything: 42
-
-trainer:
- logger:
- class_path: pytorch_lightning.loggers.WandbLogger
- init_args:
- project: DiffMST
- save_dir: /import/c4dm-datasets-ext/diffmst_logs_soum
- enable_checkpointing: true
- callbacks:
- - class_path: mst.callbacks.audio.LogAudioCallback
- - class_path: pytorch_lightning.callbacks.ModelSummary
- init_args:
- max_depth: 2
- - class_path: mst.callbacks.mix.LogReferenceMix
- init_args:
- root_dirs:
- - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song1/Soren_ALittleLate_Full
- - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song1/Soren_ALittleLate_Full
- - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song2/MR0903_Moosmusic_Full
- - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song2/MR0903_Moosmusic_Full
- - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song3/SaturnSyndicate_CatchTheWave_Full
- ref_mixes:
- - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song1/ref/Harry Styles - Late Night Talking (Official Video).wav
- - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song1/ref/Poom - Les Voiles (Official Audio).wav
- - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song2/ref/Justin Timberlake - Can't Stop The Feeling! [Lyrics].wav
- - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song2/ref/Taylor Swift - Shake It Off.wav
- - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song3/ref/Miley Cyrus - Wrecking Ball (Lyrics).wav
- default_root_dir: null
- gradient_clip_val: 10.0
- devices: 1
- check_val_every_n_epoch: 1
-
- max_epochs: 800
-
- log_every_n_steps: 50
- accelerator: gpu
- strategy: ddp_find_unused_parameters_true
- sync_batchnorm: true
- precision: 32
- enable_model_summary: true
- num_sanity_val_steps: 2
- benchmark: true
- accumulate_grad_batches: 1
- #reload_dataloaders_every_n_epochs: 1
-
diff --git a/configs/config_param.yaml b/configs/config_param.yaml
deleted file mode 100644
index 58f88b4..0000000
--- a/configs/config_param.yaml
+++ /dev/null
@@ -1,29 +0,0 @@
-seed_everything: 42
-#ckpt_path: /import/c4dm-datasets-ext/Diff-MST/DiffMST/4bjbp29c/checkpoints/epoch=118-step=148750.ckpt
-
-trainer:
- logger:
- class_path: pytorch_lightning.loggers.WandbLogger
- init_args:
- project: DiffMST-Param
- save_dir: /import/c4dm-datasets-ext/Diff-MST
- enable_checkpointing: true
- callbacks:
- - class_path: pytorch_lightning.callbacks.ModelSummary
- init_args:
- max_depth: 2
- default_root_dir: null
- gradient_clip_val: 10.0
- devices: 1
- check_val_every_n_epoch: 1
- max_epochs: 500
- log_every_n_steps: 50
- accelerator: gpu
- strategy: ddp_find_unused_parameters_true
- sync_batchnorm: true
- precision: 32
- enable_model_summary: true
- num_sanity_val_steps: 2
- benchmark: true
- accumulate_grad_batches: 1
-
diff --git a/configs/configs_hpc.yaml b/configs/configs_hpc.yaml
deleted file mode 100644
index 547c84a..0000000
--- a/configs/configs_hpc.yaml
+++ /dev/null
@@ -1,29 +0,0 @@
-seed_everything: 42
-
-trainer:
- logger:
- class_path: pytorch_lightning.loggers.WandbLogger
- init_args:
- name: DiffMST_naive_advanced_combined_data
- project: DiffMST
- save_dir: /data/scratch/acw639/DiffMST/logs
- enable_checkpointing: true
- callbacks:
- - class_path: mst.callbacks.audio.LogAudioCallback
- - class_path: pytorch_lightning.callbacks.ModelSummary
- init_args:
- max_depth: 2
- default_root_dir: null
- gradient_clip_val: 4.0
- devices: 1
- check_val_every_n_epoch: 1
- max_steps: 1000000
- log_every_n_steps: 50
- accelerator: gpu
- sync_batchnorm: true
- precision: 32
- enable_model_summary: true
- num_sanity_val_steps: 2
- benchmark: true
- accumulate_grad_batches: 4
- reload_dataloaders_every_n_epochs: 1
diff --git a/configs/models/naive+fx_encoder_loss.yaml b/configs/models/naive+fx_encoder_loss.yaml
deleted file mode 100644
index e75f8b6..0000000
--- a/configs/models/naive+fx_encoder_loss.yaml
+++ /dev/null
@@ -1,63 +0,0 @@
-model:
- class_path: mst.system.System
- init_args:
- generate_mix: True
- active_eq_epoch: 0
- active_compressor_epoch: 0
- active_fx_bus_epoch: 1000000
- active_master_bus_epoch: 0
- mix_fn: mst.mixing.naive_random_mix
- mix_console:
- class_path: mst.modules.AdvancedMixConsole
- init_args:
- sample_rate: 44100
- input_min_gain_db: -48.0
- input_max_gain_db: 48.0
- output_min_gain_db: -48.0
- output_max_gain_db: 48.0
- eq_min_gain_db: -12.0
- eq_max_gain_db: 12.0
- min_pan: 0.0
- max_pan: 1.0
- model:
- class_path: mst.modules.MixStyleTransferModel
- init_args:
- track_encoder:
- class_path: mst.modules.SpectrogramEncoder
- init_args:
- embed_dim: 512
- n_fft: 2048
- hop_length: 512
- input_batchnorm: false
- mix_encoder:
- class_path: mst.modules.SpectrogramEncoder
- init_args:
- embed_dim: 512
- n_fft: 2048
- hop_length: 512
- input_batchnorm: false
- controller:
- class_path: mst.modules.TransformerController
- init_args:
- embed_dim: 512
- num_track_control_params: 27
- num_fx_bus_control_params: 25
- num_master_bus_control_params: 26
- num_layers: 12
- nhead: 8
- loss:
- class_path: mst.loss.FX_encoder_loss
- init_args:
- audiofeatures: False
- # weights:
- # - 0.1 # rms
- # - 0.001 # crest factor
- # - 0.1 # stereo width
- # - 0.01 # stereo imbalance
- # - 0.01 # bark spectrum
- # - 0.01 # fx_encoder
-
-
-
-
-#weights = [1.0, 0.001, 1.0, 1.0, 0.01 , 0.01]
\ No newline at end of file
diff --git a/configs/models/naive+ours.yaml b/configs/models/naive+ours.yaml
deleted file mode 100644
index 3356299..0000000
--- a/configs/models/naive+ours.yaml
+++ /dev/null
@@ -1,55 +0,0 @@
-model:
- class_path: mst.system.System
- init_args:
- generate_mix: false
- active_eq_epoch: 0
- active_compressor_epoch: 0
- active_fx_bus_epoch: 1000
- active_master_bus_epoch: 0
- mix_fn: mst.mixing.naive_random_mix
- mix_console:
- class_path: mst.modules.AdvancedMixConsole
- init_args:
- sample_rate: 44100
- input_min_gain_db: -48.0
- input_max_gain_db: 48.0
- output_min_gain_db: -48.0
- output_max_gain_db: 48.0
- eq_min_gain_db: -12.0
- eq_max_gain_db: 12.0
- min_pan: 0.0
- max_pan: 1.0
- model:
- class_path: mst.modules.MixStyleTransferModel
- init_args:
- track_encoder:
- class_path: mst.modules.SpectrogramEncoder
- init_args:
- embed_dim: 512
- n_fft: 2048
- hop_length: 512
- input_batchnorm: false
- mix_encoder:
- class_path: mst.modules.SpectrogramEncoder
- init_args:
- embed_dim: 512
- n_fft: 2048
- hop_length: 512
- input_batchnorm: false
- controller:
- class_path: mst.modules.TransformerController
- init_args:
- embed_dim: 512
- num_track_control_params: 27
- num_fx_bus_control_params: 25
- num_master_bus_control_params: 26
- num_layers: 12
- nhead: 8
-
- loss:
- class_path: mst.loss.ParameterEstimatorLoss
- init_args:
- ckpt_path: /import/c4dm-datasets-ext/Diff-MST/DiffMST-Param/0ymfi1pp/checkpoints/epoch=20-step=37947.ckpt
-
-
-
diff --git a/configs/models/naive+paramloss.yaml b/configs/models/naive+paramloss.yaml
deleted file mode 100644
index a9215e2..0000000
--- a/configs/models/naive+paramloss.yaml
+++ /dev/null
@@ -1,58 +0,0 @@
-model:
- class_path: mst.system.System
- init_args:
- generate_mix: True
- active_eq_epoch: 0
- active_compressor_epoch: 0
- active_fx_bus_epoch: 1000
- active_master_bus_epoch: 0
- use_track_loss : false
- use_mix_loss : false
- use_param_loss : true
- mix_fn: mst.mixing.naive_random_mix
- mix_console:
- class_path: mst.modules.AdvancedMixConsole
- init_args:
- sample_rate: 44100
- input_min_gain_db: -48.0
- input_max_gain_db: 48.0
- output_min_gain_db: -48.0
- output_max_gain_db: 48.0
- eq_min_gain_db: -12.0
- eq_max_gain_db: 12.0
- min_pan: 0.0
- max_pan: 1.0
- model:
- class_path: mst.modules.MixStyleTransferModel
- init_args:
- track_encoder:
- class_path: mst.modules.SpectrogramEncoder
- init_args:
- embed_dim: 512
- n_fft: 2048
- hop_length: 512
- input_batchnorm: false
- mix_encoder:
- class_path: mst.modules.SpectrogramEncoder
- init_args:
- embed_dim: 512
- n_fft: 2048
- hop_length: 512
- input_batchnorm: false
- controller:
- class_path: mst.modules.TransformerController
- init_args:
- embed_dim: 512
- num_track_control_params: 27
- num_fx_bus_control_params: 25
- num_master_bus_control_params: 26
- num_layers: 12
- nhead: 8
-
- loss:
- class_path: torch.nn.MSELoss
-
-
-
-
-
diff --git a/configs/models/naive+verb.yaml b/configs/models/naive+verb.yaml
deleted file mode 100644
index eb70265..0000000
--- a/configs/models/naive+verb.yaml
+++ /dev/null
@@ -1,65 +0,0 @@
-model:
- class_path: mst.system.System
- init_args:
- active_eq_epoch: 0
- active_compressor_epoch: 0
- active_fx_bus_epoch: 0
- active_master_bus_epoch: 0
- mix_fn: mst.mixing.naive_random_mix
- mix_console:
- class_path: mst.modules.AdvancedMixConsole
- init_args:
- sample_rate: 44100
- input_min_gain_db: -48.0
- input_max_gain_db: 48.0
- output_min_gain_db: -48.0
- output_max_gain_db: 48.0
- eq_min_gain_db: -12.0
- eq_max_gain_db: 12.0
- min_pan: 0.0
- max_pan: 1.0
- model:
- class_path: mst.modules.MixStyleTransferModel
- init_args:
- track_encoder:
- class_path: mst.modules.SpectrogramEncoder
- init_args:
- embed_dim: 512
- n_fft: 2048
- hop_length: 512
- input_batchnorm: false
- mix_encoder:
- class_path: mst.modules.SpectrogramEncoder
- init_args:
- embed_dim: 512
- n_fft: 2048
- hop_length: 512
- input_batchnorm: false
- controller:
- class_path: mst.modules.TransformerController
- init_args:
- embed_dim: 512
- num_track_control_params: 27
- num_fx_bus_control_params: 25
- num_master_bus_control_params: 26
- num_layers: 12
- nhead: 8
-
- loss:
- class_path: auraloss.freq.MultiResolutionSTFTLoss
- init_args:
- fft_sizes:
- - 512
- - 2048
- - 8192
- hop_sizes:
- - 256
- - 1024
- - 4096
- win_lengths:
- - 512
- - 2048
- - 8192
-
-
-
diff --git a/configs/models/param-estim.yaml b/configs/models/param-estim.yaml
deleted file mode 100644
index bb33019..0000000
--- a/configs/models/param-estim.yaml
+++ /dev/null
@@ -1,36 +0,0 @@
-model:
- class_path: mst.param_system.ParameterEstimationSystem
- init_args:
- mix_console:
- class_path: mst.modules.AdvancedMixConsole
- init_args:
- sample_rate: 44100
- input_min_gain_db: -48.0
- input_max_gain_db: 48.0
- output_min_gain_db: -48.0
- output_max_gain_db: 48.0
- eq_min_gain_db: -12.0
- eq_max_gain_db: 12.0
- min_pan: 0.0
- max_pan: 1.0
- remixer:
- class_path: mst.modules.Remixer
- init_args:
- sample_rate: 44100
- encoder:
- class_path: mst.modules.SpectrogramEncoder
- init_args:
- embed_dim: 1024
- n_inputs: 1
- projector:
- class_path: mst.modules.ParameterProjector
- init_args:
- embed_dim: 2048
- num_tracks: 8
- num_track_control_params: 27
- num_fx_bus_control_params: 25
- num_master_bus_control_params: 26
-
-
-
-
diff --git a/hpc_run.sh b/hpc_run.sh
deleted file mode 100644
index b97dfaa..0000000
--- a/hpc_run.sh
+++ /dev/null
@@ -1,13 +0,0 @@
-#!/bin/bash
-#$ -cwd # Set the working directory for the job to the current directory
-#$ -j y
-#$ -pe smp 8 # Request 8 core
-#$ -l h_rt=240:0:0 # Request 10 days runtime
-#$ -l h_vmem= 20 # Request 8 * 20 = 160G total RAM
-#$ -l gpu=1 # request 1 GPU
-#$ -m bea
-#$ -l cluster="andrena"
-#$ -t 1-10
-
-
-python main.py fit -c configs/configs_hpc.yaml -c configs/optimizer.yaml -c configs/data/combined_hpc.yaml -c configs/models/naive_dmc_adv.yaml
\ No newline at end of file
diff --git a/mst/utils.py b/mst/utils.py
index 117f82d..6c8b757 100644
--- a/mst/utils.py
+++ b/mst/utils.py
@@ -66,14 +66,14 @@ def run_diffmst(
     meter = pyln.Meter(44100)
     # crop the input tracks and reference mix to the analysis length
-    if tracks.shape[-1] > analysis_len:
+    if tracks.shape[-1] >= analysis_len:
         analysis_tracks = tracks[
             ..., track_start_idx : track_start_idx + analysis_len
         ].clone()
     else:
         analysis_tracks = tracks.clone()
-    if ref.shape[-1] > analysis_len:
+    if ref.shape[-1] >= analysis_len:
         analysis_ref = ref[..., ref_start_idx : ref_start_idx + analysis_len]
     else:
         analysis_ref = ref.clone()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..57d9f32
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,97 @@
+absl-py==2.1.0
+aiohttp==3.9.3
+aiosignal==1.3.1
+antlr4-python3-runtime==4.9.3
+appdirs==1.4.4
+async-timeout==4.0.3
+attrs==23.2.0
+audioread==3.0.1
+auraloss==0.4.0
+bitsandbytes==0.41.0
+certifi==2024.2.2
+cffi==1.16.0
+charset-normalizer==3.3.2
+click==8.1.7
+contourpy==1.2.0
+cycler==0.12.1
+# Editable install of the dasp-pytorch submodule (dasp-pytorch==0.0.1)
+-e ./dasp-pytorch
+decorator==5.1.1
+-e git+ssh://git@github.com/sai-soum/Diff-MST.git@249706b4aaa66c16adaf5a0d46413c0337749508#egg=DiffMST
+docker-pycreds==0.4.0
+docstring_parser==0.16
+filelock==3.13.3
+fonttools==4.50.0
+frozenlist==1.4.1
+fsspec==2024.3.1
+future==1.0.0
+gitdb==4.0.11
+GitPython==3.1.43
+grpcio==1.62.1
+hydra-core==1.3.2
+idna==3.6
+importlib_resources==6.4.0
+Jinja2==3.1.3
+joblib==1.3.2
+jsonargparse==4.27.7
+kiwisolver==1.4.5
+lazy_loader==0.3
+librosa==0.10.1
+lightning-utilities==0.11.2
+llvmlite==0.42.0
+Markdown==3.6
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+matplotlib==3.8.3
+mdurl==0.1.2
+mpmath==1.3.0
+msgpack==1.0.8
+multidict==6.0.5
+networkx==3.2.1
+numba==0.59.1
+numpy==1.26.4
+omegaconf==2.3.0
+packaging==24.0
+pandas==2.2.1
+pedalboard==0.9.3
+pillow==10.3.0
+platformdirs==4.2.0
+pooch==1.8.1
+protobuf==4.25.3
+psutil==5.9.8
+pycparser==2.22
+Pygments==2.17.2
+pyloudnorm==0.1.1
+pyparsing==3.1.2
+python-dateutil==2.9.0.post0
+pytorch-lightning==2.2.1
+pytz==2024.1
+PyYAML==6.0.1
+requests==2.31.0
+rich==13.7.1
+scikit-learn==1.4.1.post1
+scipy==1.12.0
+sentry-sdk==1.44.0
+setproctitle==1.3.3
+six==1.16.0
+smmap==5.0.1
+soundfile==0.12.1
+soxr==0.3.7
+sympy==1.12
+tabulate==0.9.0
+tensorboard==2.16.2
+tensorboard-data-server==0.7.2
+tensorboardX==2.6.2.2
+threadpoolctl==3.4.0
+torch==2.2.2
+torchaudio==2.2.2
+torchmetrics==1.3.2
+torchvision==0.17.2
+tqdm==4.66.2
+typeshed_client==2.5.1
+typing_extensions==4.10.0
+tzdata==2024.1
+urllib3==2.2.1
+wandb==0.16.5
+Werkzeug==3.0.2
+yarl==1.9.4
diff --git a/scripts/eval_ablation.py b/scripts/eval_ablation.py
new file mode 100644
index 0000000..6970cc2
--- /dev/null
+++ b/scripts/eval_ablation.py
@@ -0,0 +1,259 @@
+# run pretrained models over evaluation set to generate audio examples for the listening test
+import os
+import torch
+import torchaudio
+import pyloudnorm as pyln
+from mst.utils import load_diffmst, run_diffmst
+from mst.loss import compute_barkspectrum, compute_rms, compute_crest_factor, compute_stereo_width, compute_stereo_imbalance, AudioFeatureLoss
+import json
+import numpy as np
+import csv
+import glob
+
+
+def equal_loudness_mix(tracks: torch.Tensor, *args, **kwargs):
+    """Baseline: normalize each track to -48 LUFS, sum to a stereo mix, and peak-normalize."""
+    meter = pyln.Meter(44100)
+    target_lufs_db = -48.0
+
+    norm_tracks = []
+    for track_idx in range(tracks.shape[1]):
+        track = tracks[:, track_idx : track_idx + 1, :]
+        lufs_db = meter.integrated_loudness(track.squeeze(0).permute(1, 0).numpy())
+
+        if lufs_db < -80.0:
+            print(f"Skipping track {track_idx} with {lufs_db:.2f} LUFS.")
+            continue
+
+        lufs_delta_db = target_lufs_db - lufs_db
+        track *= 10 ** (lufs_delta_db / 20)
+        norm_tracks.append(track)
+
+    norm_tracks = torch.cat(norm_tracks, dim=1)
+    # create a sum mix with equal loudness
+    sum_mix = torch.sum(norm_tracks, dim=1, keepdim=True).repeat(1, 2, 1)
+    sum_mix /= sum_mix.abs().max()
+
+    return sum_mix, None, None, None
+
+class NumpyEncoder(json.JSONEncoder):
+    """Special JSON encoder for NumPy types."""
+    def default(self, obj):
+        if isinstance(obj, np.integer):
+            return int(obj)
+        elif isinstance(obj, np.floating):
+            return float(obj)
+        elif isinstance(obj, np.ndarray):
+            return obj.tolist()
+        return json.JSONEncoder.default(self, obj)
+
+if __name__ == "__main__":
+    meter = pyln.Meter(44100)
+    target_lufs_db = -22.0
+    output_dir = "outputs/ablation"
+    os.makedirs(output_dir, exist_ok=True)
+
+    methods = {
+        "diffmst-16": {
+            "model": load_diffmst(
+                "/Users/svanka/Downloads/b4naquji/config.yaml",
+                "/Users/svanka/Downloads/b4naquji/checkpoints/epoch=191-step=626608.ckpt",
+            ),
+            "func": run_diffmst,
+        },
+        "sum": {
+            "model": (None, None),
+            "func": equal_loudness_mix,
+        },
+    }
+
+    # get the validation examples
+    examples = {
+        "ecstasy": {
+            "tracks": "/Users/svanka/Downloads//diffmst-examples/song1/BenFlowers_Ecstasy_Full/",
+            "ref": "/Users/svanka/Codes/Diff-MST/outputs/ablation_ref_examples/_Feel it all Around_ by Washed Out (Portlandia Theme)_01/",
+        },
+        "by-my-side": {
+            "tracks": "/Users/svanka/Downloads//diffmst-examples/song2/Kat Wright_By My Side/",
+            "ref": "/Users/svanka/Codes/Diff-MST/outputs/ablation_ref_examples/The Dip - Paddle To The Stars (Lyric Video)_01/",
+        },
+        "haunted-aged": {
+            "tracks": "/Users/svanka/Downloads//diffmst-examples/song3/Titanium_HauntedAge_Full/",
+            "ref": "/Users/svanka/Codes/Diff-MST/outputs/ablation_ref_examples/Architects - _Doomsday__01/",
+        },
+    }
+
+
+    loss = AudioFeatureLoss([0.1,0.001,1.0,1.0,0.1], 44100)
+    AF = {}
+    #initialise to negative infinity
+
+    for example_name, example in examples.items():
+
+
+        AF[example_name] = {}
+        print(example_name)
+        example_dir = os.path.join(output_dir, example_name)
+        os.makedirs(example_dir, exist_ok=True)
+        json_dir = os.path.join(output_dir, "AF")
+        if not os.path.exists(json_dir):
+            os.makedirs(json_dir, exist_ok=True)
+        csv_path = os.path.join(json_dir,f"{example_name}.csv")
+        # if not os.path.exists(csv_path):
+        #     os.makedirs(csv_path)
+        with open(csv_path, 'w') as f:
+            writer = csv.writer(f)
+            writer.writerow(["method", "audio_type","ablation","start_idx", "stop_idx", "rms", "crest_factor", "stereo_width", "stereo_imbalance", "barkspectrum", "net_AF_loss"])
+            f.close()
+        ref_loudness_target = -16.0
+
+        # --------------first find all the tracks----------------
+        track_filepaths = []
+        for root, dirs, files in os.walk(example["tracks"]):
+            for filepath in files:
+                if filepath.endswith(".wav"):
+                    track_filepaths.append(os.path.join(root, filepath))
+
+        print(f"Found {len(track_filepaths)} tracks.")
+
+        # ----------------load the tracks----------------------------
+        tracks = []
+        lengths = []
+        for track_idx, track_filepath in enumerate(track_filepaths):
+            audio, sr = torchaudio.load(track_filepath, backend="soundfile")
+
+            if sr != 44100:
+                audio = torchaudio.functional.resample(audio, sr, 44100)
+
+            # measure the track loudness (normalization to -48 LUFS is disabled below)
+            lufs_db = meter.integrated_loudness(audio.permute(1, 0).numpy())
+            # lufs_delta_db = -48 - lufs_db
+            # audio = audio * 10 ** (lufs_delta_db / 20)
+
+            print(track_idx, os.path.basename(track_filepath), audio.shape, sr, lufs_db)
+
+            if audio.shape[0] == 2:
+                audio = audio.mean(dim=0, keepdim=True)
+
+            chs, seq_len = audio.shape
+
+            for ch_idx in range(chs):
+                tracks.append(audio[ch_idx : ch_idx + 1, :])
+                lengths.append(audio.shape[-1])
+
+        # find max length and pad if shorter
+        max_length = max(lengths)
+        min_length = min(lengths)
+        for track_idx in range(len(tracks)):
+            tracks[track_idx] = torch.nn.functional.pad(
+                tracks[track_idx], (0, max_length - lengths[track_idx])
+            )
+
+        # stack into a tensor
+        tracks = torch.cat(tracks, dim=0)
+        tracks = tracks.view(1, -1, max_length)
+        tracks_length = max_length
+        refs = glob.glob(os.path.join(example["ref"],"*.wav"))
+        print("found refs", len(refs))
+        for ref in refs:
+            ref_name = os.path.basename(ref).replace(".wav", "")
+            test_type = ref_name.split("_")[-2] + "_" + ref_name.split("_")[-1]
+            print(test_type)
+
+            print(ref_name)
+            AF[example_name]["ref"] = {}
+            AF[example_name]["pred_mix"] = {}
+            ref_audio, ref_sr = torchaudio.load(ref, backend="soundfile")
+            if ref_sr != 44100:
+                ref_audio = torchaudio.functional.resample(ref_audio, ref_sr, 44100)
+            print(ref_audio.shape, ref_sr)
+            ref_length = ref_audio.shape[-1]
+            ref_audio = ref_audio.view(1, 2, -1)
+
+            # loudness normalize the reference mix to -16 LUFS
+            ref_lufs_db = meter.integrated_loudness(ref_audio.squeeze().permute(1, 0).numpy())
+            lufs_delta_db = ref_loudness_target - ref_lufs_db
+            ref_audio = ref_audio * 10 ** (lufs_delta_db / 20)
+
+
+            # --------------run inference----------------
+            #print(tracks.shape)
+            track_idx = int(tracks_length / 2)
+            ref_idx = int(ref_length / 2)
+            mix_tracks = tracks[..., track_idx - 220500 : track_idx + 220500]
+            ref_analysis = ref_audio[..., ref_idx - 220500 : ref_idx + 220500]
+
+            ref_path = os.path.join(example_dir, os.path.basename(ref).replace(".wav", "-ref-16.wav"))
+            torchaudio.save(ref_path, ref_analysis.squeeze(), 44100)
+
+            for method_name, method in methods.items():
+                AF[example_name]["ref"][method_name] = {}
+                AF[example_name]["pred_mix"][method_name] = {}
+
+                print(method_name)
+                model, mix_console = method["model"]
+                func = method["func"]
+
+                with torch.no_grad():
+                    result = func(
+                        mix_tracks.clone(),
+                        ref_analysis.clone(),
+                        model,
+                        mix_console,
+                        track_start_idx=0,
+                        ref_start_idx=0,
+                    )
+
+                (
+                    pred_mix,
+                    pred_track_param_dict,
+                    pred_fx_bus_param_dict,
+                    pred_master_bus_param_dict,
+                ) = result
+
+                bs, chs, seq_len = pred_mix.shape
+                print("pred_mix shape", pred_mix.shape)
+                # loudness normalize the output mix
+                mix_lufs_db = meter.integrated_loudness(
+                    pred_mix.squeeze(0).permute(1, 0).numpy()
+                )
+                print("pred_mix_lufs_db", mix_lufs_db)
+                #print(mix_lufs_db)
+                lufs_delta_db = target_lufs_db - mix_lufs_db
+                pred_mix = pred_mix * 10 ** (lufs_delta_db / 20)
+                name = os.path.basename(ref).replace(".wav", "-pred_mix-16.wav")
+                mix_filepath = os.path.join(example_dir, f"{method_name}_{name}")
+                torchaudio.save(mix_filepath, pred_mix.view(chs, -1), 44100)
+
+ # compute audio features
+
+ AF[example_name]["pred_mix"][method_name]["mix-rms"] = 0.1*compute_rms(pred_mix, sample_rate = sr).mean().detach().cpu().numpy()
+ AF[example_name]["pred_mix"][method_name]["mix-crest_factor"] = 0.001*compute_crest_factor(pred_mix, sample_rate = sr).mean().detach().cpu().numpy()
+ AF[example_name]["pred_mix"][method_name]["mix-stereo_width"] = 1.0*compute_stereo_width(pred_mix, sample_rate = sr).detach().cpu().numpy()
+ AF[example_name]["pred_mix"][method_name]["mix-stereo_imbalance"] = 1.0*compute_stereo_imbalance(pred_mix, sample_rate = sr).detach().cpu().numpy()
+ AF[example_name]["pred_mix"][method_name]["mix-barkspectrum"] = 0.1*compute_barkspectrum(pred_mix, sample_rate = sr).mean().detach().cpu().numpy()
+
+ AF[example_name]["ref"][method_name]["mix-rms"] = 0.1*compute_rms(ref_analysis, sample_rate = sr).mean().detach().cpu().numpy()
+ AF[example_name]["ref"][method_name]["mix-crest_factor"] = 0.001*compute_crest_factor(ref_analysis, sample_rate = sr).mean().detach().cpu().numpy()
+ AF[example_name]["ref"][method_name]["mix-stereo_width"] = 1.0*compute_stereo_width(ref_analysis, sample_rate = sr).detach().cpu().numpy()
+ AF[example_name]["ref"][method_name]["mix-stereo_imbalance"] = 1.0*compute_stereo_imbalance(ref_analysis, sample_rate = sr).detach().cpu().numpy()
+ AF[example_name]["ref"][method_name]["mix-barkspectrum"] = 0.1*compute_barkspectrum(ref_analysis, sample_rate = sr).mean().detach().cpu().numpy()
+
+ AF_loss = loss(pred_mix, ref_analysis)
+ AF[example_name]["pred_mix"][method_name]["net_AF_loss"] = sum(AF_loss.values()).detach().cpu().numpy()
+ AF[example_name]["ref"][method_name]["net_AF_loss"] = AF[example_name]["pred_mix"][method_name]["net_AF_loss"]
+
+
+ # save resulting audio and parameters
+ #append to csv the method name, audio section, audio features values and net loss on different columns
+ with open(csv_path, 'a', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow([method_name, "pred_mix", test_type, track_idx - 220500, track_idx + 220500,AF[example_name]["pred_mix"][method_name]["mix-rms"], AF[example_name]["pred_mix"][method_name]["mix-crest_factor"], AF[example_name]["pred_mix"][method_name]["mix-stereo_width"], AF[example_name]["pred_mix"][method_name]["mix-stereo_imbalance"], AF[example_name]["pred_mix"][method_name]["mix-barkspectrum"], AF[example_name]["pred_mix"][method_name]["net_AF_loss"]])
+ writer.writerow([method_name, "ref", test_type, ref_idx - 220500, ref_idx + 220500,AF[example_name]["ref"][method_name]["mix-rms"], AF[example_name]["ref"][method_name]["mix-crest_factor"], AF[example_name]["ref"][method_name]["mix-stereo_width"], AF[example_name]["ref"][method_name]["mix-stereo_imbalance"], AF[example_name]["ref"][method_name]["mix-barkspectrum"], AF[example_name]["ref"][method_name]["net_AF_loss"]])
+
+
+
+
+ #write dictionary to json
+
+
diff --git a/scripts/eval_all_combo.py b/scripts/eval_all_combo.py
new file mode 100644
index 0000000..6fa141e
--- /dev/null
+++ b/scripts/eval_all_combo.py
@@ -0,0 +1,282 @@
+# run pretrained models over evaluation set to generate audio examples for the listening test
+import os
+import torch
+import torchaudio
+import pyloudnorm as pyln
+from mst.utils import load_diffmst, run_diffmst
+from mst.loss import compute_barkspectrum, compute_rms, compute_crest_factor, compute_stereo_width, compute_stereo_imbalance, AudioFeatureLoss
+import json
+import numpy as np
+import csv
+
+
+def equal_loudness_mix(tracks: torch.Tensor, *args, **kwargs):
+
+ meter = pyln.Meter(44100)
+ target_lufs_db = -48.0
+
+ norm_tracks = []
+ for track_idx in range(tracks.shape[1]):
+ track = tracks[:, track_idx : track_idx + 1, :]
+ lufs_db = meter.integrated_loudness(track.squeeze(0).permute(1, 0).numpy())
+
+ if lufs_db < -80.0:
+ print(f"Skipping track {track_idx} with {lufs_db:.2f} LUFS.")
+ continue
+
+ lufs_delta_db = target_lufs_db - lufs_db
+ track = track * 10 ** (lufs_delta_db / 20)
+ norm_tracks.append(track)
+
+ norm_tracks = torch.cat(norm_tracks, dim=1)
+ # create a sum mix with equal loudness
+ sum_mix = torch.sum(norm_tracks, dim=1, keepdim=True).repeat(1, 2, 1)
+ sum_mix /= sum_mix.abs().max()
+
+ return sum_mix, None, None, None
+
+class NumpyEncoder(json.JSONEncoder):
+ """ Special json encoder for numpy types """
+ def default(self, obj):
+ if isinstance(obj, np.integer):
+ return int(obj)
+ elif isinstance(obj, np.floating):
+ return float(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ return json.JSONEncoder.default(self, obj)
+
+if __name__ == "__main__":
+ meter = pyln.Meter(44100)
+ target_lufs_db = -22.0
+ output_dir = "outputs/listen"
+ os.makedirs(output_dir, exist_ok=True)
+
+ methods = {
+ "diffmst-16": {
+ "model": load_diffmst(
+ "/Users/svanka/Downloads/b4naquji/config.yaml",
+ "/Users/svanka/Downloads/b4naquji/checkpoints/epoch=191-step=626608.ckpt",
+ ),
+ "func": run_diffmst,
+ },
+ "sum": {
+ "model": (None, None),
+ "func": equal_loudness_mix,
+ },
+ }
+
+ # get the validation examples
+ examples = {
+ # "ecstasy": {
+ # "tracks": "/Users/svanka/Downloads//diffmst-examples/song1/BenFlowers_Ecstasy_Full/",
+ # "ref": "/Users/svanka/Downloads//diffmst-examples/song1/ref/_Feel it all Around_ by Washed Out (Portlandia Theme)_01.wav",
+ # },
+ # "by-my-side": {
+ # "tracks": "/Users/svanka/Downloads//diffmst-examples/song2/Kat Wright_By My Side/",
+ # "ref": "/Users/svanka/Downloads//diffmst-examples/song2/ref/The Dip - Paddle To The Stars (Lyric Video)_01.wav",
+ # },
+ "haunted-aged": {
+ "tracks": "/Users/svanka/Downloads//diffmst-examples/song3/Titanium_HauntedAge_Full/",
+ "ref": "/Users/svanka/Downloads//diffmst-examples/song3/ref/Architects - _Doomsday__01.wav",
+ },
+ }
+ loss = AudioFeatureLoss([0.1,0.001,1.0,1.0,0.1], 44100)
+ AF = {}
+ # the per-example minimum loss is initialised to +inf inside the loop below
+
+ for example_name, example in examples.items():
+
+ AF[example_name] = {}
+ print(example_name)
+ example_dir = os.path.join(output_dir, example_name)
+ os.makedirs(example_dir, exist_ok=True)
+ json_dir = os.path.join(output_dir, "AF")
+ if not os.path.exists(json_dir):
+ os.makedirs(json_dir, exist_ok=True)
+ csv_path = os.path.join(json_dir,f"{example_name}.csv")
+ # if not os.path.exists(csv_path):
+ # os.makedirs(csv_path)
+ with open(csv_path, 'w', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow(["method", "audio_section","track_start_idx", "track_stop_idx", "ref_start_idx", "ref_stop_idx", "rms", "crest_factor", "stereo_width", "stereo_imbalance", "barkspectrum", "net_AF_loss"])
+
+
+ # ----------load reference mix---------------
+ ref_audio, ref_sr = torchaudio.load(example["ref"], backend="soundfile")
+ if ref_sr != 44100:
+ ref_audio = torchaudio.functional.resample(ref_audio, ref_sr, 44100)
+ print(ref_audio.shape, ref_sr)
+ ref_length = ref_audio.shape[-1]
+ # --------------first find all the tracks----------------
+ track_filepaths = []
+ for root, dirs, files in os.walk(example["tracks"]):
+ for filepath in files:
+ if filepath.endswith(".wav"):
+ track_filepaths.append(os.path.join(root, filepath))
+
+ print(f"Found {len(track_filepaths)} tracks.")
+
+ # ----------------load the tracks----------------------------
+ tracks = []
+ lengths = []
+ for track_idx, track_filepath in enumerate(track_filepaths):
+ audio, sr = torchaudio.load(track_filepath, backend="soundfile")
+
+ if sr != 44100:
+ audio = torchaudio.functional.resample(audio, sr, 44100)
+
+ # measure the track loudness (normalization to -48 LUFS is currently disabled)
+ lufs_db = meter.integrated_loudness(audio.permute(1, 0).numpy())
+ # lufs_delta_db = -48 - lufs_db
+ # audio = audio * 10 ** (lufs_delta_db / 20)
+
+ print(track_idx, os.path.basename(track_filepath), audio.shape, sr, lufs_db)
+
+ if audio.shape[0] == 2:
+ audio = audio.mean(dim=0, keepdim=True)
+
+ chs, seq_len = audio.shape
+
+ for ch_idx in range(chs):
+ tracks.append(audio[ch_idx : ch_idx + 1, :])
+ lengths.append(audio.shape[-1])
+
+ # find max length and pad if shorter
+ max_length = max(lengths)
+ min_length = min(lengths)
+ for track_idx in range(len(tracks)):
+ tracks[track_idx] = torch.nn.functional.pad(
+ tracks[track_idx], (0, max_length - lengths[track_idx])
+ )
+
+ # stack into a tensor
+ tracks = torch.cat(tracks, dim=0)
+ tracks = tracks.view(1, -1, max_length)
+ ref_audio = ref_audio.view(1, 2, -1)
+
+ # crop tracks to max of 60 seconds or so
+ # tracks = tracks[..., :4194304]
+ tracks_length = max_length
+
+ #print(tracks.shape)
+ track_start_idx = int(tracks_length / 4)
+ ref_start_idx = int(ref_length / 4)
+ track_stop_idx = int(3*tracks_length / 4)
+ ref_stop_idx = int(3*ref_length / 4)
+ #find the number of sets of track samples of 10 sec duration each
+ track_num_sets = int((track_stop_idx - track_start_idx) / 441000)
+ ref_num_sets = int((ref_stop_idx - ref_start_idx) / 441000)
+ print("track_num_sets", track_num_sets)
+ print("ref_num_sets", ref_num_sets)
+ min_AF_loss = float('inf')
+ min_AF_loss_example = None
+ for i in range(track_num_sets):
+ for j in range(ref_num_sets):
+ print(f"track-{i}-ref-{j}")
+ #run inference for every combination of track and ref samples and calculate audio features.
+ # We will save the audio features to a csv and audio files in the output directory
+ mix_tracks = tracks[..., track_start_idx + i*441000 : track_start_idx + (i+1)*441000]
+ ref_analysis = ref_audio[..., ref_start_idx + j*441000 : ref_start_idx + (j+1)*441000]
+
+ # create mixes varying the loudness of the reference
+ for ref_loudness_target in [-16.0]:
+ print("Ref loudness", ref_loudness_target)
+ ref_filepath = os.path.join(
+ example_dir,
+ f"ref-analysis-track-{i}-ref-{j}-lufs-{ref_loudness_target:0.0f}.wav",
+ )
+
+ # loudness normalize the reference mix section to ref_loudness_target LUFS
+ ref_lufs_db = meter.integrated_loudness(
+ ref_analysis.squeeze().permute(1, 0).numpy()
+ )
+ print("ref_lufs_db", ref_lufs_db)
+ lufs_delta_db = ref_loudness_target - ref_lufs_db
+ ref_analysis = ref_analysis * 10 ** (lufs_delta_db / 20)
+
+ torchaudio.save(ref_filepath, ref_analysis.squeeze(), 44100)
+
+ AF_loss = 0
+ for method_name, method in methods.items():
+ AF[example_name].setdefault(method_name, {})
+ print(method_name)
+ # tracks (torch.Tensor): Set of input tracks with shape (bs, num_tracks, seq_len)
+ # ref_audio (torch.Tensor): Reference mix with shape (bs, 2, seq_len)
+
+ if method_name == "sum":
+ if ref_loudness_target != -16:
+ continue
+
+
+ model, mix_console = method["model"]
+ func = method["func"]
+
+ #print(tracks.shape, ref_audio.shape)
+ audio_section = f"track-{i}-ref-{j}-lufs-{ref_loudness_target:0.0f}"
+ AF[example_name][method_name][audio_section] = {}
+ AF[example_name][method_name][audio_section]["track_start_idx"] = track_start_idx + i*441000
+ AF[example_name][method_name][audio_section]["track_stop_idx"] = track_start_idx + (i+1)*441000
+ AF[example_name][method_name][audio_section]["ref_start_idx"] = ref_start_idx + j*441000
+ AF[example_name][method_name][audio_section]["ref_stop_idx"] = ref_start_idx + (j+1)*441000
+ with torch.no_grad():
+ result = func(
+ mix_tracks.clone(),
+ ref_analysis.clone(),
+ model,
+ mix_console,
+ track_start_idx=0,
+ ref_start_idx=0,
+ )
+
+ (
+ pred_mix,
+ pred_track_param_dict,
+ pred_fx_bus_param_dict,
+ pred_master_bus_param_dict,
+ ) = result
+
+ bs, chs, seq_len = pred_mix.shape
+ print("pred_mix shape", pred_mix.shape)
+ # loudness normalize the output mix
+ mix_lufs_db = meter.integrated_loudness(
+ pred_mix.squeeze(0).permute(1, 0).numpy()
+ )
+ print("pred_mix_lufs_db", mix_lufs_db)
+ #print(mix_lufs_db)
+ lufs_delta_db = target_lufs_db - mix_lufs_db
+ pred_mix = pred_mix * 10 ** (lufs_delta_db / 20)
+ mix_filepath = os.path.join(
+ example_dir,
+ f"{example_name}-{method_name}-tracks-{i}-ref={j}-lufs-{ref_loudness_target:0.0f}.wav",
+ )
+ torchaudio.save(mix_filepath, pred_mix.view(chs, -1), 44100)
+
+ # compute audio features
+ AF_loss = loss(pred_mix, ref_analysis)
+
+ for key, value in AF_loss.items():
+ AF[example_name][method_name][audio_section][key] = value.detach().cpu().numpy()
+ AF[example_name][method_name][audio_section]["net_AF_loss"] = sum(AF_loss.values()).detach().cpu().numpy()
+ print(AF[example_name][method_name][audio_section])
+
+ if AF[example_name][method_name][audio_section]["net_AF_loss"] < min_AF_loss:
+ min_AF_loss = AF[example_name][method_name][audio_section]["net_AF_loss"]
+ min_AF_loss_example = f"{example_name}-{method_name}-{audio_section}"
+ print("min_AF_loss", min_AF_loss)
+ print("min_AF_loss_example", min_AF_loss_example)
+ # save resulting audio and parameters
+ #append to csv the method name, audio section, audio features values and net loss on different columns
+
+ with open(csv_path, 'a', newline='') as f:
+ writer = csv.writer(f)
+ writer.writerow([method_name, audio_section, AF[example_name][method_name][audio_section]["track_start_idx"], AF[example_name][method_name][audio_section]["track_stop_idx"], AF[example_name][method_name][audio_section]["ref_start_idx"], AF[example_name][method_name][audio_section]["ref_stop_idx"], AF[example_name][method_name][audio_section]["mix-rms"], AF[example_name][method_name][audio_section]["mix-crest_factor"], AF[example_name][method_name][audio_section]["mix-stereo_width"], AF[example_name][method_name][audio_section]["mix-stereo_imbalance"], AF[example_name][method_name][audio_section]["mix-barkspectrum"], AF[example_name][method_name][audio_section]["net_AF_loss"]])
+
+
+
+ print(f"for {example_name} min loss is {min_AF_loss} corresponding to {min_AF_loss_example}")
+ print()
+
+ #write dictionary to json
+
+
diff --git a/scripts/eval_listen.py b/scripts/eval_listen.py
index 6cd67b6..01f321f 100644
--- a/scripts/eval_listen.py
+++ b/scripts/eval_listen.py
@@ -35,7 +35,7 @@ def equal_loudness_mix(tracks: torch.Tensor, *args, **kwargs):
if __name__ == "__main__":
meter = pyln.Meter(44100)
target_lufs_db = -22.0
- output_dir = "outputs/listen"
+ output_dir = "outputs/listen_1"
os.makedirs(output_dir, exist_ok=True)
methods = {
@@ -113,7 +113,7 @@ def equal_loudness_mix(tracks: torch.Tensor, *args, **kwargs):
# lufs_delta_db = -48 - lufs_db
# audio = audio * 10 ** (lufs_delta_db / 20)
- print(track_idx, os.path.basename(track_filepath), audio.shape, sr, lufs_db)
+ #print(track_idx, os.path.basename(track_filepath), audio.shape, sr, lufs_db)
if audio.shape[0] == 2:
audio = audio.mean(dim=0, keepdim=True)
@@ -123,14 +123,14 @@ def equal_loudness_mix(tracks: torch.Tensor, *args, **kwargs):
for ch_idx in range(chs):
tracks.append(audio[ch_idx : ch_idx + 1, :])
lengths.append(audio.shape[-1])
-
+ print("Loaded tracks.")
# find max length and pad if shorter
max_length = max(lengths)
for track_idx in range(len(tracks)):
tracks[track_idx] = torch.nn.functional.pad(
tracks[track_idx], (0, max_length - lengths[track_idx])
)
-
+ print("Padded tracks.")
# stack into a tensor
tracks = torch.cat(tracks, dim=0)
tracks = tracks.view(1, -1, max_length)
@@ -144,6 +144,8 @@ def equal_loudness_mix(tracks: torch.Tensor, *args, **kwargs):
# create a sum mix with equal loudness
sum_mix = torch.sum(tracks, dim=1, keepdim=True).squeeze(0)
sum_filepath = os.path.join(example_dir, f"{example_name}-sum.wav")
+ os.makedirs(os.path.dirname(sum_filepath), exist_ok=True)
+ print("sum_mix path created")
# loudness normalize the sum mix
sum_lufs_db = meter.integrated_loudness(sum_mix.permute(1, 0).numpy())
@@ -151,10 +153,12 @@ def equal_loudness_mix(tracks: torch.Tensor, *args, **kwargs):
sum_mix = sum_mix * 10 ** (lufs_delta_db / 20)
torchaudio.save(sum_filepath, sum_mix.view(1, -1), 44100)
+ print("Sum mix saved.")
# save the reference mix
ref_filepath = os.path.join(example_dir, "ref-full.wav")
torchaudio.save(ref_filepath, ref_audio.squeeze(), 44100)
+ print("Reference mix saved.")
for song_section in ["verse", "chorus"]:
print("Mixing", song_section)
diff --git a/scripts/find_lowest_loss.py b/scripts/find_lowest_loss.py
new file mode 100644
index 0000000..79bbf48
--- /dev/null
+++ b/scripts/find_lowest_loss.py
@@ -0,0 +1,19 @@
+import csv
+import sys
+import pandas as pd
+import os
+
+if __name__ == "__main__":
+ csv_paths = ["/Users/svanka/Codes/Diff-MST/outputs/listen/AF/by-my-side.csv",
+ "/Users/svanka/Codes/Diff-MST/outputs/listen/AF/ecstasy.csv"
+ ]
+
+ for csv_path in csv_paths:
+ print(os.path.basename(csv_path).replace('.csv', ''))
+ #read csv using pandas
+ df = pd.read_csv(csv_path)
+ #find the row with the lowest loss
+ lowest_loss = df['net_AF_loss'].min()
+ #find the method,audio_section with the lowest loss
+ lowest_loss_row = df.loc[df['net_AF_loss'] == lowest_loss]
+ print(lowest_loss_row)
diff --git a/scripts/gain_testing.py b/scripts/gain_testing.py
new file mode 100644
index 0000000..eddcc04
--- /dev/null
+++ b/scripts/gain_testing.py
@@ -0,0 +1,219 @@
+
+# run pretrained models over evaluation set to generate audio examples for the listening test
+import os
+import torch
+import torchaudio
+import pyloudnorm as pyln
+from mst.utils import load_diffmst, run_diffmst
+from mst.loss import compute_barkspectrum, compute_rms, compute_crest_factor, compute_stereo_width, compute_stereo_imbalance, AudioFeatureLoss
+import json
+import numpy as np
+import csv
+import glob
+import yaml
+
+
+def equal_loudness_mix(tracks: torch.Tensor, *args, **kwargs):
+
+ meter = pyln.Meter(44100)
+ target_lufs_db = -48.0
+
+ norm_tracks = []
+ for track_idx in range(tracks.shape[1]):
+ track = tracks[:, track_idx : track_idx + 1, :]
+ lufs_db = meter.integrated_loudness(track.squeeze(0).permute(1, 0).numpy())
+
+ if lufs_db < -80.0:
+ print(f"Skipping track {track_idx} with {lufs_db:.2f} LUFS.")
+ continue
+
+ lufs_delta_db = target_lufs_db - lufs_db
+ track = track * 10 ** (lufs_delta_db / 20)
+ norm_tracks.append(track)
+
+ norm_tracks = torch.cat(norm_tracks, dim=1)
+ # create a sum mix with equal loudness
+ sum_mix = torch.sum(norm_tracks, dim=1, keepdim=True).repeat(1, 2, 1)
+ sum_mix /= sum_mix.abs().max()
+
+ return sum_mix, None, None, None
+
+
+
+if __name__ == "__main__":
+ meter = pyln.Meter(44100)
+ target_mix_lufs_db = -16.0
+ target_track_lufs_db = -48.0
+ output_dir = "outputs/gain_testing_diff_song_individual_tracks"
+ os.makedirs(output_dir, exist_ok=True)
+
+ methods = {
+ "diffmst-16": {
+ "model": load_diffmst(
+ "/Users/svanka/Downloads/b4naquji/config.yaml",
+ "/Users/svanka/Downloads/b4naquji/checkpoints/epoch=191-step=626608.ckpt",
+ ),
+ "func": run_diffmst,
+ },
+ # "sum": {
+ # "model": (None, None),
+ # "func": equal_loudness_mix,
+ # },
+ }
+
+ ref_dir = "/Users/svanka/Downloads/DSD100subset/sources/Dev/055 - Angels In Amplifiers - I'm Alright"
+ #mix_dir = "/Users/svanka/Downloads/DSD100subset/sources/Dev/055 - Angels In Amplifiers - I'm Alright"
+ mix_dir = "/Users/svanka/Downloads/DSD100subset/Sources/Test/049 - Young Griffo - Facade"
+
+ ref_tracks = glob.glob(os.path.join(ref_dir, "*.wav"))
+ mix_tracks = glob.glob(os.path.join(mix_dir, "*.wav"))
+
+ print(len(ref_tracks), len(mix_tracks))
+ #order the tracks in ref_tracks to vocals, bass, other, drums
+ ref_tracks_ordered = [""] * 4
+ for track in ref_tracks:
+ if "vocals" in track:
+ ref_tracks_ordered[0] = track
+ elif "bass" in track:
+ ref_tracks_ordered[1] = track
+ elif "other" in track:
+ ref_tracks_ordered[2] = track
+ elif "drums" in track:
+ ref_tracks_ordered[3] = track
+ ref_tracks = ref_tracks_ordered
+
+ print(ref_tracks)
+ # print(mix_tracks)
+
+
+ #predict a mix using each reference stem on its own as the reference for the model (the cumulative sum of stems is disabled below)
+ # and the mix tracks as the input
+
+ tracks = []
+ #info = torchaudio.info(mix_tracks[0])
+
+
+ track_instrument = []
+ for track in mix_tracks:
+ #audio, sr = torchaudio.load(track, frame_offset = int((info.num_frames)/2 - 220500), num_frames = 441000, backend="soundfile")
+ audio, sr = torchaudio.load(track,num_frames = 441000, backend="soundfile")
+ if sr != 44100:
+ audio = torchaudio.functional.resample(audio, sr, 44100)
+
+ if audio.shape[0] == 2:
+ audio = audio.mean(dim=0, keepdim=True)
+
+ tracks.append(audio)
+ track_instrument.append(os.path.basename(track).replace(".wav", ""))
+
+
+ tracks = torch.cat(tracks, dim=0)
+ print("tracks shape", tracks.shape)
+ tracks = tracks.unsqueeze(0)
+ print("tracks shape", tracks.shape)
+
+ #create a sum mix
+ sum_mix, _, _, _ = equal_loudness_mix(tracks)
+ print("sum_mix shape", sum_mix.shape)
+ save_path = os.path.join(output_dir, f"{os.path.basename(mix_dir)}-sum_mix.wav")
+ torchaudio.save(save_path, sum_mix.view(2, -1), 44100)
+
+ ref_mix_tracks = []
+ info = torchaudio.info(ref_tracks[0])
+ name = "ref_mix-16="
+ data = {}
+ data["track_instrument"] = track_instrument
+ for i , ref_track in enumerate(ref_tracks):
+ instrument = name + "-" + os.path.basename(ref_track).replace(".wav", "")
+ print(instrument)
+ #name = instrument
+ ref_audio, sr = torchaudio.load(ref_track, frame_offset = int((info.num_frames)/2 - 220500), num_frames = 441000, backend="soundfile")
+ if sr != 44100:
+ ref_audio = torchaudio.functional.resample(ref_audio, sr, 44100)
+
+ #loudness normalize the reference track to -48 LUFS (target_track_lufs_db)
+ ref_lufs_db = meter.integrated_loudness(ref_audio.squeeze().permute(1, 0).numpy())
+ lufs_delta_db = target_track_lufs_db - ref_lufs_db
+ ref_audio = ref_audio * 10 ** (lufs_delta_db / 20)
+
+ #ref_mix_tracks.append(ref_audio)
+ ref_mix_tracks = [ref_audio]
+ ref_mix = torch.cat(ref_mix_tracks, dim=0)
+ #create a stereo sum mix
+ ref_mix = ref_mix.sum(dim=0, keepdim=True).repeat(1, 2, 1)
+ #normalise to -16 LUFS
+ ref_mix_lufs_db = meter.integrated_loudness(ref_mix.squeeze().permute(1, 0).numpy())
+ lufs_delta_db = target_mix_lufs_db - ref_mix_lufs_db
+ ref_mix = ref_mix * 10 ** (lufs_delta_db / 20)
+ ref_save_path = os.path.join(output_dir, f"{os.path.basename(ref_dir)}-{instrument}.wav")
+ torchaudio.save(ref_save_path, ref_mix.view(2, -1), 44100)
+
+ yaml_path = os.path.join(output_dir, f"{os.path.basename(ref_dir)}-{instrument}.yaml")
+ data["ref_mix"] = ref_save_path
+ data["ref_instruments"] = instrument
+ data["sum_mix"] = save_path
+ #data is written to yaml_path once the prediction is available
+ print("tracks shape", tracks.shape)
+ print("ref_mix shape", ref_mix.shape)
+
+
+
+
+ for method_name, method in methods.items():
+ model, mix_console = method["model"]
+ func = method["func"]
+ with torch.no_grad():
+ result = func(
+ tracks.clone(),
+ ref_mix.clone(),
+ model,
+ mix_console,
+ track_start_idx=0,
+ ref_start_idx=0,
+ )
+
+ (
+ pred_mix,
+ pred_track_param_dict,
+ pred_fx_bus_param_dict,
+ pred_master_bus_param_dict,
+ ) = result
+
+
+ bs, chs, seq_len = pred_mix.shape
+ print("pred_mix shape", pred_mix.shape)
+ # loudness normalize the output mix
+ mix_lufs_db = meter.integrated_loudness(
+ pred_mix.squeeze(0).permute(1, 0).numpy()
+ )
+ print("pred_mix_lufs_db", mix_lufs_db)
+ #print(mix_lufs_db)
+ lufs_delta_db = target_mix_lufs_db - mix_lufs_db
+ pred_mix = pred_mix * 10 ** (lufs_delta_db / 20)
+ pred_mix_name = os.path.basename(mix_dir) + f"-pred_mix-ref_mix-16={instrument}.wav"
+ mix_filepath = os.path.join(output_dir, pred_mix_name)
+ torchaudio.save(mix_filepath, pred_mix.view(chs, -1), 44100)
+ # record the predicted mix name and gain values for the YAML file
+
+ #print(pred_track_param_dict["input_gain"])
+
+ data["pred_mix"] = pred_mix_name
+ data["gain_values"] = pred_track_param_dict['input_fader']['gain_db'].detach().cpu().numpy().tolist()[0]
+ #print(type(pred_track_param_dict['input_fader']['gain_db']))
+
+
+ with open(yaml_path, "w") as f:
+ yaml.dump(data, f)
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/scripts/generate_ablation_examples.py b/scripts/generate_ablation_examples.py
new file mode 100644
index 0000000..942fbc7
--- /dev/null
+++ b/scripts/generate_ablation_examples.py
@@ -0,0 +1,99 @@
+import torch
+import torchaudio
+import os
+import glob
+
+def low_pass_audio(y, sr, cutoff_freq):
+ # Apply a second-order (biquad) low-pass filter.
+ # y: waveform tensor
+ # sr: sample rate in Hz
+ # cutoff_freq: cutoff frequency in Hz (e.g. 4000)
+ # Returns the filtered waveform.
+ y = torchaudio.functional.lowpass_biquad(y, sr, cutoff_freq)
+ return y
+
+def high_pass_audio(y, sr, cutoff_freq):
+ # Apply a second-order (biquad) high-pass filter.
+ # y: waveform tensor
+ # sr: sample rate in Hz
+ # cutoff_freq: cutoff frequency in Hz (e.g. 4000)
+ # Returns the filtered waveform.
+ y = torchaudio.functional.highpass_biquad(y, sr, cutoff_freq)
+ return y
+
+def band_pass_audio(y, sr, low_cutoff_freq, high_cutoff_freq):
+ # y: waveform
+ # sr: sample rate
+ # low_cutoff_freq: low cutoff frequency in Hz (e.g. 4000)
+ # high_cutoff_freq: high cutoff frequency in Hz (e.g. 8000)
+ # bandpass_biquad takes a central frequency and a Q factor, not two cutoffs
+ central_freq = (low_cutoff_freq * high_cutoff_freq) ** 0.5
+ Q = central_freq / (high_cutoff_freq - low_cutoff_freq)
+ y = torchaudio.functional.bandpass_biquad(y, sr, central_freq, Q)
+ return y
+
+def pan_left_audio(y, sr):
+ # y: waveform
+ # sr: sample rate
+ if y.shape[0] != 2:
+ raise ValueError("Audio must have 2 channels for panning.")
+
+ # Apply extreme panning to the left channel
+ panned_waveform = torch.zeros_like(y)
+ panned_waveform[0] = y[0] # Left channel remains unchanged
+ panned_waveform[1] = y[1] * 0.1 # Decrease amplitude of right channel (adjust value as needed)
+ return panned_waveform
+
+def pan_right_audio(y, sr):
+ # y: waveform
+ # sr: sample rate
+ if y.shape[0] != 2:
+ raise ValueError("Audio must have 2 channels for panning.")
+
+ # Apply extreme panning to the right channel
+ panned_waveform = torch.zeros_like(y)
+ panned_waveform[0] = y[0] * 0.1 # Decrease amplitude of left channel (adjust value as needed)
+ panned_waveform[1] = y[1] # Right channel remains unchanged
+ return panned_waveform
+
+
+
+
+if __name__ == "__main__":
+ ref_audio_paths = ["/Users/svanka/Downloads//diffmst-examples/song1/ref/_Feel it all Around_ by Washed Out (Portlandia Theme)_01.wav",
+ "/Users/svanka/Downloads//diffmst-examples/song2/ref/The Dip - Paddle To The Stars (Lyric Video)_01.wav",
+ "/Users/svanka/Downloads//diffmst-examples/song3/ref/Architects - _Doomsday__01.wav"]
+
+ ref_save_path = "outputs/ablation_ref_examples"
+ os.makedirs(ref_save_path, exist_ok=True)
+
+ for ref_audio_path in ref_audio_paths:
+ print(os.path.basename(ref_audio_path) + "...")
+
+ save_path = os.path.join(ref_save_path, os.path.basename(ref_audio_path).replace(".wav", ""))
+ os.makedirs(save_path, exist_ok=True)
+
+
+ y, sr = torchaudio.load(ref_audio_path, backend="soundfile")
+
+ # Apply low-pass filter
+ y_low_pass = low_pass_audio(y, sr, 5000)
+ torchaudio.save(os.path.join(save_path, os.path.basename(ref_audio_path).replace(".wav", "_low_pass.wav")), y_low_pass, sr)
+
+ # Apply high-pass filter
+ y_high_pass = high_pass_audio(y, sr, 4000)
+ torchaudio.save(os.path.join(save_path, os.path.basename(ref_audio_path).replace(".wav", "_high_pass.wav")), y_high_pass, sr)
+
+ # Apply band-pass filter
+ y_band_pass = band_pass_audio(y, sr, 500, 8000)
+ torchaudio.save(os.path.join(save_path, os.path.basename(ref_audio_path).replace(".wav", "_band_pass.wav")), y_band_pass, sr)
+
+ # Pan left
+ y_pan_left = pan_left_audio(y, sr)
+ torchaudio.save(os.path.join(save_path, os.path.basename(ref_audio_path).replace(".wav", "_pan_left.wav")), y_pan_left, sr)
+
+ # Pan right
+ y_pan_right = pan_right_audio(y, sr)
+ torchaudio.save(os.path.join(save_path, os.path.basename(ref_audio_path).replace(".wav", "_pan_right.wav")), y_pan_right, sr)
+
+
diff --git a/scripts/plots.ipynb b/scripts/plots.ipynb
new file mode 100644
index 0000000..d96bc41
--- /dev/null
+++ b/scripts/plots.ipynb
@@ -0,0 +1,2383 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: pandas in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (2.2.1)\n",
+ "Requirement already satisfied: numpy<2,>=1.26.0 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from pandas) (1.26.4)\n",
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from pandas) (2.9.0)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from pandas) (2024.1)\n",
+ "Requirement already satisfied: tzdata>=2022.7 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from pandas) (2024.1)\n",
+ "Requirement already satisfied: six>=1.5 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n",
+ "Requirement already satisfied: matplotlib in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (3.8.3)\n",
+ "Requirement already satisfied: contourpy>=1.0.1 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from matplotlib) (1.2.1)\n",
+ "Requirement already satisfied: cycler>=0.10 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from matplotlib) (0.12.1)\n",
+ "Requirement already satisfied: fonttools>=4.22.0 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from matplotlib) (4.50.0)\n",
+ "Requirement already satisfied: kiwisolver>=1.3.1 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from matplotlib) (1.4.5)\n",
+ "Requirement already satisfied: numpy<2,>=1.21 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from matplotlib) (1.26.4)\n",
+ "Requirement already satisfied: packaging>=20.0 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from matplotlib) (24.0)\n",
+ "Requirement already satisfied: pillow>=8 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from matplotlib) (10.3.0)\n",
+ "Requirement already satisfied: pyparsing>=2.3.1 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from matplotlib) (3.1.2)\n",
+ "Requirement already satisfied: python-dateutil>=2.7 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from matplotlib) (2.9.0)\n",
+ "Requirement already satisfied: six>=1.5 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)\n",
+ "Requirement already satisfied: seaborn in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (0.13.2)\n",
+ "Requirement already satisfied: numpy!=1.24.0,>=1.20 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from seaborn) (1.26.4)\n",
+ "Requirement already satisfied: pandas>=1.2 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from seaborn) (2.2.1)\n",
+ "Requirement already satisfied: matplotlib!=3.6.1,>=3.4 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from seaborn) (3.8.3)\n",
+ "Requirement already satisfied: contourpy>=1.0.1 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (1.2.1)\n",
+ "Requirement already satisfied: cycler>=0.10 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (0.12.1)\n",
+ "Requirement already satisfied: fonttools>=4.22.0 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (4.50.0)\n",
+ "Requirement already satisfied: kiwisolver>=1.3.1 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (1.4.5)\n",
+ "Requirement already satisfied: packaging>=20.0 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (24.0)\n",
+ "Requirement already satisfied: pillow>=8 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (10.3.0)\n",
+ "Requirement already satisfied: pyparsing>=2.3.1 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (3.1.2)\n",
+ "Requirement already satisfied: python-dateutil>=2.7 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from matplotlib!=3.6.1,>=3.4->seaborn) (2.9.0)\n",
+ "Requirement already satisfied: pytz>=2020.1 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from pandas>=1.2->seaborn) (2024.1)\n",
+ "Requirement already satisfied: tzdata>=2022.7 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from pandas>=1.2->seaborn) (2024.1)\n",
+ "Requirement already satisfied: six>=1.5 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from python-dateutil>=2.7->matplotlib!=3.6.1,>=3.4->seaborn) (1.16.0)\n",
+ "Collecting scikit-learn\n",
+ " Using cached scikit_learn-1.4.1.post1-cp312-cp312-macosx_12_0_arm64.whl.metadata (11 kB)\n",
+ "Requirement already satisfied: numpy<2.0,>=1.19.5 in /Users/svanka/miniforge3/envs/dmst/lib/python3.12/site-packages (from scikit-learn) (1.26.4)\n",
+ "Collecting scipy>=1.6.0 (from scikit-learn)\n",
+ " Downloading scipy-1.13.0-cp312-cp312-macosx_12_0_arm64.whl.metadata (60 kB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m60.6/60.6 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hCollecting joblib>=1.2.0 (from scikit-learn)\n",
+ " Using cached joblib-1.3.2-py3-none-any.whl.metadata (5.4 kB)\n",
+ "Collecting threadpoolctl>=2.0.0 (from scikit-learn)\n",
+ " Using cached threadpoolctl-3.4.0-py3-none-any.whl.metadata (13 kB)\n",
+ "Using cached scikit_learn-1.4.1.post1-cp312-cp312-macosx_12_0_arm64.whl (10.5 MB)\n",
+ "Using cached joblib-1.3.2-py3-none-any.whl (302 kB)\n",
+ "Downloading scipy-1.13.0-cp312-cp312-macosx_12_0_arm64.whl (30.4 MB)\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m30.4/30.4 MB\u001b[0m \u001b[31m21.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
+ "\u001b[?25hUsing cached threadpoolctl-3.4.0-py3-none-any.whl (17 kB)\n",
+ "Installing collected packages: threadpoolctl, scipy, joblib, scikit-learn\n",
+ "Successfully installed joblib-1.3.2 scikit-learn-1.4.1.post1 scipy-1.13.0 threadpoolctl-3.4.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install pandas\n",
+ "!pip install matplotlib\n",
+ "!pip install seaborn\n",
+ "!pip install scikit-learn"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import pandas as pd\n",
+ "import matplotlib\n",
+ "import csv\n",
+ "import glob\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/Users/svanka/Codes/Diff-MST/outputs/ablation/AF/by-my-side.csv\n",
+ "20\n",
+ "/Users/svanka/Codes/Diff-MST/outputs/ablation/AF/ecstasy.csv\n",
+ "20\n",
+ "/Users/svanka/Codes/Diff-MST/outputs/ablation/AF/haunted-aged.csv\n",
+ "20\n",
+ "60\n"
+ ]
+ }
+ ],
+ "source": [
+ "csv_path = \"/Users/svanka/Codes/Diff-MST/outputs/ablation/AF\"\n",
+ "# Load every CSV in csv_path into a single DataFrame, tagging each row with the name of its source file.\n",
+ "\n",
+ "df_list = []\n",
+ "\n",
+ "for csv_file in glob.glob(csv_path + \"/*.csv\"):\n",
+ "    print(csv_file)\n",
+ "    csv_df = pd.read_csv(csv_file)\n",
+ "    # Add a column holding the file name without the .csv extension.\n",
+ "    csv_df['file_name'] = os.path.basename(csv_file).replace('.csv','')\n",
+ "    print(len(csv_df))\n",
+ "    df_list.append(csv_df)\n",
+ "df = pd.concat(df_list, ignore_index=True)\n",
+ "print(len(df))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " method audio_type ablation start_idx stop_idx rms \\\n",
+ "0 diffmst-16 pred_mix high_pass 5022137 5463137 0.004074 \n",
+ "1 diffmst-16 ref high_pass 3223724 3664724 0.007692 \n",
+ "2 sum pred_mix high_pass 5022137 5463137 0.006213 \n",
+ "3 sum ref high_pass 3223724 3664724 0.007692 \n",
+ "4 diffmst-16 pred_mix low_pass 5022137 5463137 0.006155 \n",
+ "5 diffmst-16 ref low_pass 3223724 3664724 0.011253 \n",
+ "6 sum pred_mix low_pass 5022137 5463137 0.006213 \n",
+ "7 sum ref low_pass 3223724 3664724 0.011253 \n",
+ "8 diffmst-16 pred_mix band_pass 5022137 5463137 0.007041 \n",
+ "9 diffmst-16 ref band_pass 3223724 3664724 0.016523 \n",
+ "10 sum pred_mix band_pass 5022137 5463137 0.006213 \n",
+ "11 sum ref band_pass 3223724 3664724 0.016523 \n",
+ "12 diffmst-16 pred_mix pan_left 5022137 5463137 0.005618 \n",
+ "13 diffmst-16 ref pan_left 3223724 3664724 0.008366 \n",
+ "14 sum pred_mix pan_left 5022137 5463137 0.006213 \n",
+ "15 sum ref pan_left 3223724 3664724 0.008366 \n",
+ "16 diffmst-16 pred_mix pan_right 5022137 5463137 0.005376 \n",
+ "17 diffmst-16 ref pan_right 3223724 3664724 0.008894 \n",
+ "18 sum pred_mix pan_right 5022137 5463137 0.006213 \n",
+ "19 sum ref pan_right 3223724 3664724 0.008894 \n",
+ "20 diffmst-16 pred_mix low_pass 5376628 5817628 0.006615 \n",
+ "21 diffmst-16 ref low_pass 1340076 1781076 0.012675 \n",
+ "22 sum pred_mix low_pass 5376628 5817628 0.005921 \n",
+ "23 sum ref low_pass 1340076 1781076 0.012675 \n",
+ "24 diffmst-16 pred_mix pan_right 5376628 5817628 0.005277 \n",
+ "25 diffmst-16 ref pan_right 1340076 1781076 0.009609 \n",
+ "26 sum pred_mix pan_right 5376628 5817628 0.005921 \n",
+ "27 sum ref pan_right 1340076 1781076 0.009609 \n",
+ "28 diffmst-16 pred_mix band_pass 5376628 5817628 0.007070 \n",
+ "29 diffmst-16 ref band_pass 1340076 1781076 0.011060 \n",
+ "30 sum pred_mix band_pass 5376628 5817628 0.005921 \n",
+ "31 sum ref band_pass 1340076 1781076 0.011060 \n",
+ "32 diffmst-16 pred_mix pan_left 5376628 5817628 0.005214 \n",
+ "33 diffmst-16 ref pan_left 1340076 1781076 0.009433 \n",
+ "34 sum pred_mix pan_left 5376628 5817628 0.005921 \n",
+ "35 sum ref pan_left 1340076 1781076 0.009433 \n",
+ "36 diffmst-16 pred_mix high_pass 5376628 5817628 0.003912 \n",
+ "37 diffmst-16 ref high_pass 1340076 1781076 0.006252 \n",
+ "38 sum pred_mix high_pass 5376628 5817628 0.005921 \n",
+ "39 sum ref high_pass 1340076 1781076 0.006252 \n",
+ "40 diffmst-16 pred_mix high_pass 4989668 5430668 0.003969 \n",
+ "41 diffmst-16 ref high_pass 6420140 6861140 0.008070 \n",
+ "42 sum pred_mix high_pass 4989668 5430668 0.005992 \n",
+ "43 sum ref high_pass 6420140 6861140 0.008070 \n",
+ "44 diffmst-16 pred_mix low_pass 4989668 5430668 0.006412 \n",
+ "45 diffmst-16 ref low_pass 6420140 6861140 0.012228 \n",
+ "46 sum pred_mix low_pass 4989668 5430668 0.005992 \n",
+ "47 sum ref low_pass 6420140 6861140 0.012228 \n",
+ "48 diffmst-16 pred_mix band_pass 4989668 5430668 0.006382 \n",
+ "49 diffmst-16 ref band_pass 6420140 6861140 0.009437 \n",
+ "50 sum pred_mix band_pass 4989668 5430668 0.005992 \n",
+ "51 sum ref band_pass 6420140 6861140 0.009437 \n",
+ "52 diffmst-16 pred_mix pan_right 4989668 5430668 0.004372 \n",
+ "53 diffmst-16 ref pan_right 6420140 6861140 0.009281 \n",
+ "54 sum pred_mix pan_right 4989668 5430668 0.005992 \n",
+ "55 sum ref pan_right 6420140 6861140 0.009281 \n",
+ "56 diffmst-16 pred_mix pan_left 4989668 5430668 0.005333 \n",
+ "57 diffmst-16 ref pan_left 6420140 6861140 0.009256 \n",
+ "58 sum pred_mix pan_left 4989668 5430668 0.005992 \n",
+ "59 sum ref pan_left 6420140 6861140 0.009256 \n",
+ "\n",
+ " crest_factor stereo_width stereo_imbalance barkspectrum net_AF_loss \\\n",
+ "0 0.021787 [0.3353039] [-0.17681053] 0.653527 0.122524 \n",
+ "1 0.016933 [0.18699424] [-0.28859037] 0.698815 0.122524 \n",
+ "2 0.020306 [0.] [0.] -0.543323 31.851294 \n",
+ "3 0.016933 [0.18699424] [-0.28859037] 0.698815 31.851294 \n",
+ "4 0.015004 [0.22357161] [0.13522747] 0.654838 0.117024 \n",
+ "5 0.013086 [0.2034282] [0.03135305] 0.751761 0.117024 \n",
+ "6 0.020306 [0.] [0.] -0.543323 32.614502 \n",
+ "7 0.013086 [0.2034282] [0.03135305] 0.751761 32.614502 \n",
+ "8 0.013663 [0.15272886] [-0.10040125] 0.531103 0.946850 \n",
+ "9 0.006827 [0.02459108] [-0.16006266] 0.269596 0.946850 \n",
+ "10 0.020306 [0.] [0.] -0.543323 22.868963 \n",
+ "11 0.006827 [0.02459108] [-0.16006266] 0.269596 22.868963 \n",
+ "12 0.014758 [0.30145487] [-0.8022489] 0.679761 0.446519 \n",
+ "13 0.013010 [0.7625234] [-0.97916067] 0.806224 0.446519 \n",
+ "14 0.020306 [0.] [0.] -0.543323 36.472347 \n",
+ "15 0.013010 [0.7625234] [-0.97916067] 0.806224 36.472347 \n",
+ "16 0.015142 [0.38029966] [0.79385364] 0.685591 0.332059 \n",
+ "17 0.013010 [0.7729033] [0.9811843] 0.794389 0.332059 \n",
+ "18 0.020306 [0.] [0.] -0.543323 36.097370 \n",
+ "19 0.013010 [0.7729033] [0.9811843] 0.794389 36.097370 \n",
+ "20 0.012723 [0.13142872] [0.01397015] 0.666661 0.064587 \n",
+ "21 0.009881 [0.18987568] [0.08012921] 0.714403 0.064587 \n",
+ "22 0.016681 [0.] [0.] -0.523093 31.953974 \n",
+ "23 0.009881 [0.18987568] [0.08012921] 0.714403 31.953974 \n",
+ "24 0.012306 [1.0680251] [0.87477577] 0.709141 0.147120 \n",
+ "25 0.010431 [0.77672297] [0.9830406] 0.761654 0.147120 \n",
+ "26 0.016681 [0.] [0.] -0.523093 35.442448 \n",
+ "27 0.010431 [0.77672297] [0.9830406] 0.761654 35.442448 \n",
+ "28 0.009780 [0.09989402] [-0.23715419] 0.580188 0.801008 \n",
+ "29 0.008109 [0.54091203] [0.0680251] 0.389929 0.801008 \n",
+ "30 0.016681 [0.] [0.] -0.523093 25.474741 \n",
+ "31 0.008109 [0.54091203] [0.0680251] 0.389929 25.474741 \n",
+ "32 0.014098 [0.7043585] [-0.97385424] 0.681030 0.145156 \n",
+ "33 0.010431 [0.7444616] [-0.97688454] 0.773411 0.145156 \n",
+ "34 0.016681 [0.] [0.] -0.523093 35.710136 \n",
+ "35 0.010431 [0.7444616] [-0.97688454] 0.773411 35.710136 \n",
+ "36 0.023446 [0.64094937] [0.39073116] 0.639729 0.519041 \n",
+ "37 0.020270 [0.10621171] [-0.03977505] 0.641428 0.519041 \n",
+ "38 0.016681 [0.] [0.] -0.523093 30.196657 \n",
+ "39 0.020270 [0.10621171] [-0.03977505] 0.641428 30.196657 \n",
+ "40 0.018614 [0.53428715] [-0.13642734] 0.665139 0.095481 \n",
+ "41 0.020721 [0.3847169] [0.00603298] 0.708236 0.095481 \n",
+ "42 0.016583 [0.] [0.] -0.519098 32.275440 \n",
+ "43 0.020721 [0.3847169] [0.00603298] 0.708236 32.275440 \n",
+ "44 0.015502 [0.02812384] [-0.00424067] 0.675351 0.180971 \n",
+ "45 0.010140 [0.16704528] [-0.00012205] 0.781249 0.180971 \n",
+ "46 0.016583 [0.] [0.] -0.519098 33.682160 \n",
+ "47 0.010140 [0.16704528] [-0.00012205] 0.781249 33.682160 \n",
+ "48 0.013647 [0.34050697] [-0.26903212] 0.628616 1.383198 \n",
+ "49 0.008510 [0.02690221] [-0.0218279] 0.304304 1.383198 \n",
+ "50 0.016583 [0.] [0.] -0.519098 23.336586 \n",
+ "51 0.008510 [0.02690221] [-0.0218279] 0.304304 23.336586 \n",
+ "52 0.016404 [0.5658581] [0.931149] 0.756382 0.147527 \n",
+ "53 0.009984 [0.7547042] [0.9802046] 0.831521 0.147527 \n",
+ "54 0.016583 [0.] [0.] -0.519098 37.195953 \n",
+ "55 0.009984 [0.7547042] [0.9802046] 0.831521 37.195953 \n",
+ "56 0.016097 [0.22387566] [-0.75253445] 0.722902 0.521889 \n",
+ "57 0.009984 [0.75463414] [-0.98019147] 0.830292 0.521889 \n",
+ "58 0.016583 [0.] [0.] -0.519098 37.161100 \n",
+ "59 0.009984 [0.75463414] [-0.98019147] 0.830292 37.161100 \n",
+ "\n",
+ " file_name \n",
+ "0 by-my-side \n",
+ "1 by-my-side \n",
+ "2 by-my-side \n",
+ "3 by-my-side \n",
+ "4 by-my-side \n",
+ "5 by-my-side \n",
+ "6 by-my-side \n",
+ "7 by-my-side \n",
+ "8 by-my-side \n",
+ "9 by-my-side \n",
+ "10 by-my-side \n",
+ "11 by-my-side \n",
+ "12 by-my-side \n",
+ "13 by-my-side \n",
+ "14 by-my-side \n",
+ "15 by-my-side \n",
+ "16 by-my-side \n",
+ "17 by-my-side \n",
+ "18 by-my-side \n",
+ "19 by-my-side \n",
+ "20 ecstasy \n",
+ "21 ecstasy \n",
+ "22 ecstasy \n",
+ "23 ecstasy \n",
+ "24 ecstasy \n",
+ "25 ecstasy \n",
+ "26 ecstasy \n",
+ "27 ecstasy \n",
+ "28 ecstasy \n",
+ "29 ecstasy \n",
+ "30 ecstasy \n",
+ "31 ecstasy \n",
+ "32 ecstasy \n",
+ "33 ecstasy \n",
+ "34 ecstasy \n",
+ "35 ecstasy \n",
+ "36 ecstasy \n",
+ "37 ecstasy \n",
+ "38 ecstasy \n",
+ "39 ecstasy \n",
+ "40 haunted-aged \n",
+ "41 haunted-aged \n",
+ "42 haunted-aged \n",
+ "43 haunted-aged \n",
+ "44 haunted-aged \n",
+ "45 haunted-aged \n",
+ "46 haunted-aged \n",
+ "47 haunted-aged \n",
+ "48 haunted-aged \n",
+ "49 haunted-aged \n",
+ "50 haunted-aged \n",
+ "51 haunted-aged \n",
+ "52 haunted-aged \n",
+ "53 haunted-aged \n",
+ "54 haunted-aged \n",
+ "55 haunted-aged \n",
+ "56 haunted-aged \n",
+ "57 haunted-aged \n",
+ "58 haunted-aged \n",
+ "59 haunted-aged "
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 73,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "18\n",
+ " method audio_type ablation start_idx stop_idx rms \\\n",
+ "12 diffmst-16 pred_mix pan_left 5022137 5463137 0.005618 \n",
+ "13 diffmst-16 ref pan_left 3223724 3664724 0.008366 \n",
+ "14 sum pred_mix pan_left 5022137 5463137 0.006213 \n",
+ "16 diffmst-16 pred_mix pan_right 5022137 5463137 0.005376 \n",
+ "17 diffmst-16 ref pan_right 3223724 3664724 0.008894 \n",
+ "\n",
+ " crest_factor stereo_width stereo_imbalance barkspectrum net_AF_loss \\\n",
+ "12 0.014758 [0.30145487] [-0.8022489] 0.679761 0.446519 \n",
+ "13 0.013010 [0.7625234] [-0.97916067] 0.806224 0.446519 \n",
+ "14 0.020306 [0.] [0.] -0.543323 36.472347 \n",
+ "16 0.015142 [0.38029966] [0.79385364] 0.685591 0.332059 \n",
+ "17 0.013010 [0.7729033] [0.9811843] 0.794389 0.332059 \n",
+ "\n",
+ " file_name \n",
+ "12 by-my-side \n",
+ "13 by-my-side \n",
+ "14 by-my-side \n",
+ "16 by-my-side \n",
+ "17 by-my-side \n"
+ ]
+ }
+ ],
+ "source": [
+ "# Keep only the pan_left and pan_right ablations.\n",
+ "\n",
+ "df_pan = df[(df['ablation'] == 'pan_left') | (df['ablation'] == 'pan_right')]\n",
+ "# For the 'sum' method keep only the pred_mix rows, i.e. drop rows where method is 'sum' and audio_type is 'ref'.\n",
+ "# .copy() makes df_pan independent of df, so the later column assignments do not act on a slice.\n",
+ "df_pan = df_pan[~((df_pan['method'] == 'sum') & (df_pan['audio_type'] != 'pred_mix'))].copy()\n",
+ "print(len(df_pan))\n",
+ "print(df_pan.head())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 74,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " method audio_type ablation start_idx stop_idx rms \\\n",
+ "12 diffmst-16 pred_mix pan_left 5022137 5463137 0.005618 \n",
+ "13 diffmst-16 ref pan_left 3223724 3664724 0.008366 \n",
+ "14 sum pred_mix pan_left 5022137 5463137 0.006213 \n",
+ "16 diffmst-16 pred_mix pan_right 5022137 5463137 0.005376 \n",
+ "17 diffmst-16 ref pan_right 3223724 3664724 0.008894 \n",
+ "\n",
+ " crest_factor stereo_width stereo_imbalance barkspectrum net_AF_loss \\\n",
+ "12 0.014758 0.301455 -0.802249 0.679761 0.446519 \n",
+ "13 0.013010 0.762523 -0.979161 0.806224 0.446519 \n",
+ "14 0.020306 0.000000 0.000000 -0.543323 36.472347 \n",
+ "16 0.015142 0.380300 0.793854 0.685591 0.332059 \n",
+ "17 0.013010 0.772903 0.981184 0.794389 0.332059 \n",
+ "\n",
+ " file_name \n",
+ "12 by-my-side \n",
+ "13 by-my-side \n",
+ "14 by-my-side \n",
+ "16 by-my-side \n",
+ "17 by-my-side \n"
+ ]
+ }
+ ],
+ "source": [
+ "# stereo_imbalance and stereo_width are stored as strings holding a one-element list, e.g. '[2.34]'.\n",
+ "# Strip the brackets and convert each value to float.\n",
+ "df_pan['stereo_imbalance'] = df_pan['stereo_imbalance'].apply(lambda x: float(x[1:-1]))\n",
+ "df_pan['stereo_width'] = df_pan['stereo_width'].apply(lambda x: float(x[1:-1]))\n",
+ "print(df_pan.head())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 75,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " method audio_type ablation start_idx stop_idx rms \\\n",
+ "12 diffmst-16 pred_mix pan_left 5022137 5463137 0.005618 \n",
+ "13 diffmst-16 ref pan_left 3223724 3664724 0.008366 \n",
+ "14 sum pred_mix pan_left 5022137 5463137 0.006213 \n",
+ "16 diffmst-16 pred_mix pan_right 5022137 5463137 0.005376 \n",
+ "17 diffmst-16 ref pan_right 3223724 3664724 0.008894 \n",
+ "\n",
+ " crest_factor stereo_width stereo_imbalance barkspectrum net_AF_loss \\\n",
+ "12 0.014758 0.301455 -0.802249 0.679761 0.446519 \n",
+ "13 0.013010 0.762523 -0.979161 0.806224 0.446519 \n",
+ "14 0.020306 0.000000 0.000000 -0.543323 36.472347 \n",
+ "16 0.015142 0.380300 0.793854 0.685591 0.332059 \n",
+ "17 0.013010 0.772903 0.981184 0.794389 0.332059 \n",
+ "\n",
+ " file_name audio_type_encoding method_encoding \n",
+ "12 by-my-side 0 0 \n",
+ "13 by-my-side 1 0 \n",
+ "14 by-my-side 0 1 \n",
+ "16 by-my-side 0 0 \n",
+ "17 by-my-side 1 0 \n"
+ ]
+ }
+ ],
+ "source": [
+ "# Encode audio_type and method as integer labels (LabelEncoder assigns codes in sorted order of the class names).\n",
+ "from sklearn.preprocessing import LabelEncoder\n",
+ "le = LabelEncoder()\n",
+ "df_pan['audio_type_encoding'] = le.fit_transform(df_pan['audio_type'])\n",
+ "df_pan['method_encoding'] = le.fit_transform(df_pan['method'])\n",
+ "print(df_pan.head())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 89,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "<class 'str'>\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Check the type of the first element of audio_type.\n",
+ "print(type(df_pan['audio_type'].iloc[0]))\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 123,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "posx and posy should be finite values\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "