diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000..e2a10ab --- /dev/null +++ b/.gitignore @@ -0,0 +1,18 @@ +lightning_logs +plots/ +test/ +*.ckpt +scratch/ +outputs/ +.DS_Store +*.vscode +env/ +*.egg-info +__pycache__ +.ipynb_checkpoints +examples/ +med/ +audio/ +outputs/ +env_test/ +checkpoints/ \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100755 index 0000000..de324e1 --- /dev/null +++ b/LICENSE @@ -0,0 +1,25 @@ +ADOBE RESEARCH LICENSE + +This license (the "License") between Adobe Inc., having a place of business at 345 Park Avenue, San Jose, California 95110-2704 ("Adobe"), and you, the individual or entity exercising rights under this License ("you" or "your"), sets forth the terms for your use of certain research materials that are owned by Adobe (the "Licensed Materials"). By exercising rights under this License, you accept and agree to be bound by its terms. If you are exercising rights under this license on behalf of an entity, then "you" means you and such entity, and you (personally) represent and warrant that you (personally) have all necessary authority to bind that entity to the terms of this License. + +1. GRANT OF LICENSE. + +1.1 Adobe grants you a nonexclusive, worldwide, royalty-free, fully paid license to (A) reproduce, use, modify, and publicly display and perform the Licensed Materials for noncommercial research purposes only; and (B) redistribute the Licensed Materials, and modifications or derivative works thereof, for noncommercial research purposes only, provided that you give recipients a copy of this License. + +1.2 You may add your own copyright statement to your modifications and may provide additional or different license terms for use, reproduction, modification, public display and performance, and redistribution of your modifications and derivative works, provided that such license terms limit the use, reproduction, modification, public display and performance, and redistribution of such modifications and derivative works to noncommercial research purposes only. + +1.3 For purposes of this License, noncommercial research purposes include academic research, teaching, and testing, but do not include commercial licensing or distribution, development of commercial products, or any other activity which results in commercial gain. + +2. OWNERSHIP AND ATTRIBUTION. Adobe and its licensors own all right, title, and interest in the Licensed Materials. You must keep intact any copyright or other notices or disclaimers in the Licensed Materials. + +3. DISCLAIMER OF WARRANTIES. THE LICENSED MATERIALS ARE PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND. THE ENTIRE RISK AS TO THE RESULTS AND PERFORMANCE OF THE LICENSED MATERIALS IS ASSUMED BY YOU. ADOBE DISCLAIMS ALL WARRANTIES, EXPRESS, IMPLIED OR STATUTORY, WITH REGARD TO ANY LICENSED MATERIALS PROVIDED UNDER THIS LICENSE, INCLUDING, BUT NOT LIMITED TO, ANY IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NONINFRINGEMENT OF THIRD-PARTY RIGHTS. + +4. LIMITATION OF LIABILITY. IN NO EVENT WILL ADOBE BE LIABLE FOR ANY ACTUAL, INCIDENTAL, SPECIAL OR CONSEQUENTIAL DAMAGES OF ANY NATURE WHATSOEVER, INCLUDING WITHOUT LIMITATION, LOSS OF PROFITS OR OTHER COMMERCIAL LOSS, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF ANY LICENSED MATERIALS PROVIDED UNDER THIS LICENSE, EVEN IF ADOBE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +5. TERM AND TERMINATION. + +5.1 The License is effective upon acceptance by you and will remain in effect unless terminated earlier as permitted under this License. + +5.2 If you breach any material provision of this License, then your rights will terminate immediately. + +5.3 All clauses which by their nature should survive the termination of this License will survive such termination. In addition, and without limiting the generality of the preceding sentence, Sections 2 (Ownership and Attribution), 3 (Disclaimer of Warranties), 4 (Limitation of Liability) will survive termination of this License. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100755 index 0000000..c287670 --- /dev/null +++ b/README.md @@ -0,0 +1,290 @@ +
+ +# DeepAFx-ST + +Style transfer of audio effects with differentiable signal processing + + +[![Demo](https://img.shields.io/badge/Web-Demo-blue)](https://csteinmetz1.github.io/DeepAFx-ST) +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](TBD) +[![arXiv](https://img.shields.io/badge/arXiv-2010.04237-b31b1b.svg)](TBD) + + + +[Christian J. Steinmetz](http://Christiansteinmetz.com)1*, [Nicholas J. Bryan](https://ccrma.stanford.edu/~njb/)2, and [Joshua D. Reiss](http://www.eecs.qmul.ac.uk/~josh/)1 + +1 Centre for Digital Music, Queen Mary University of London
+2 Adobe Research
+*Work performed in-part while an intern at Adobe Research. + + +[![Demo Video](docs/deepafx-st-headline.png)](https://youtu.be/IZp455wiMk4) + + +
+ + + + + + +- [Abstract](#abstract) +- [Install & Usage](#install--usage) +- [Inference](#inference) +- [Training](#training) +- [Style evaluation](#style-evaluation) +- [License](#license) + + + +## Abstract +We present a framework that can impose the audio effects and production style from one recording to another by example with the goal of simplifying the audio production process. We train a deep neural network to analyze an input recording and a style reference recording, and predict the control parameters of audio effects used to render the output. +In contrast to past work, we integrate audio effects as differentiable operators in our framework, perform backpropagation through audio effects, and optimize end-to-end using an audio-domain loss. We use a self-supervised training strategy enabling automatic control of audio effects without the use of any labeled or paired training data. We survey a range of existing and new approaches for differentiable signal processing, showing how each can be integrated into our framework while discussing their trade-offs. We evaluate our approach on both speech and music tasks, demonstrating that our approach generalizes both to unseen recordings and even to sample rates different than those seen during training. Our approach produces convincing production style transfer results with the ability to transform input recordings to produced recordings, yielding audio effect control parameters that enable interpretability and user interaction. + +For more details, please see: +"[Style Transfer of Audio Effects with Differentiable Signal Processing](TBD)", [Christian J. Steinmetz](http://Christiansteinmetz.com), [Nicholas J. Bryan](https://ccrma.stanford.edu/~njb/), [Joshua D. Reiss](http://www.eecs.qmul.ac.uk/~josh/). arXiv, 2022. If you use ideas or code from this work, pleace cite our paper: + +```BibTex +@article{Steinmetz2022DeepAFxST, + title={Style Transfer of Audio Effects with Differentiable Signal Processing}, + author={Christian J. Steinmetz and Nicholas J. Bryan and Joshua D. Reiss}, + year={2022}, + archivePrefix={arXiv}, + primaryClass={cs.SD} +} +``` + + + + + + +## Install & Usage + +Clone the repo, create a virtual environment, and then install the `deepafx_st` package. + +``` +cd + +# Option 1: Using virtual envs +python -m venv env/ +source env/bin/activate + +# Option 2: Using conda +conda create -n deepafx-st python=3.8 -y +conda activate deepafx-st + + +# Update pip and install +pip install --upgrade pip +pip install --pre -e . + +# Optional if using AWS for data +pip install awscli + +# Linux install +apt-get install libsndfile1 +apt-get install sox +apt-get install ffmpeg +apt-get install wget +``` + +Download pretrained models and example files and untar in one shot +```Bash +cd +wget https://github.com/adobe-research/DeepAFx-ST/releases/download/v0.1.0/checkpoints_and_examples.tar.gz -O - | tar -xz + +``` +Note, you can also find our pretrained checkpoints via the Github UI stored in a tagged release at [https://github.com/adobe-research/DeepAFx-ST/tags](https://github.com/adobe-research/DeepAFx-ST/tags). + +After download and untar'ing, your `checkpoint` and `examples` folder structures should be the following: +```Bash +/checkpoints/README.md +/checkpoints/cdpam/ +/checkpoints/probes/ +/checkpoints/proxies/ +/checkpoints/style/ +/examples/voice_raw.wav +/examples/voice_produced.wav +``` + + +## Inference + +Apply pre-trained models to your own audio examples with the `process.py` script. +Simply call the `scripts/process.py` passing your input audio `-i` along with your reference `-r` and the path to a pretrained model checkpoint `-c. + +``` +cd +python scripts/process.py -i .wav -r .wav -c + +# Speech models + +# Autodiff speech model +python scripts/process.py -i examples/voice_raw.wav -r examples/voice_produced.wav -c ./checkpoints/style/libritts/autodiff/lightning_logs/version_1/checkpoints/epoch=367-step=1226911-val-libritts-autodiff.ckpt + +# Proxy0 speech model +python scripts/process.py -i examples/voice_raw.wav -r examples/voice_produced.wav -c ./checkpoints/style/libritts/proxy0/lightning_logs/version_0/checkpoints/epoch\=327-step\=1093551-val-libritts-proxy0.ckpt + +# Proxy2 speech model +python scripts/process.py -i examples/voice_raw.wav -r examples/voice_produced.wav -c ./checkpoints/style/libritts/proxy2/lightning_logs/version_0/checkpoints/epoch\=84-step\=283389-val-libritts-proxy2.ckpt + +# SPSA speech model +python scripts/process.py -i examples/voice_raw.wav -r examples/voice_produced.wav -c checkpoints/style/libritts/spsa/lightning_logs/version_2/checkpoints/epoch\=367-step\=1226911-val-libritts-spsa.ckpt + +# TCN1 speech model +python scripts/process.py -i examples/voice_raw.wav -r examples/voice_produced.wav -c checkpoints/style/libritts/tcn1/lightning_logs/version_1/checkpoints/epoch\=367-step\=1226911-val-libritts-tcn1.ckpt + +# TCN2 speech model +python scripts/process.py -i examples/voice_raw.wav -r examples/voice_produced.wav -c checkpoints/style/libritts/tcn2/lightning_logs/version_1/checkpoints/epoch\=396-step\=1323597-val-libritts-tcn2.ckpt + +# Music models + +# Autodiff music model +python scripts/process.py -i examples/voice_raw.wav -r examples/voice_produced.wav -c checkpoints/style/jamendo/autodiff/lightning_logs/version_0/checkpoints/epoch\=362-step\=1210241-val-jamendo-autodiff.ckpt + +# Proxy0 music model +python scripts/process.py -i examples/voice_raw.wav -r examples/voice_produced.wav -c checkpoints/style/jamendo/proxy0/lightning_logs/version_0/checkpoints/epoch\=362-step\=1210241-val-jamendo-proxy0.ckpt + +# proxy0m music model +python scripts/process.py -i examples/voice_raw.wav -r examples/voice_produced.wav -c checkpoints/style/jamendo/proxy0m/lightning_logs/version_0/checkpoints/epoch\=331-step\=276887-val-jamendo-proxy0.ckpt + +# Proxy2 music model +python scripts/process.py -i examples/voice_raw.wav -r examples/voice_produced.wav -c checkpoints/style/jamendo/proxy2/lightning_logs/version_0/checkpoints/epoch\=8-step\=30005-val-jamendo-proxy2.ckpt + +# Proxy2m music model +python scripts/process.py -i examples/voice_raw.wav -r examples/voice_produced.wav -c checkpoints/style/jamendo/proxy2m/lightning_logs/version_0/checkpoints/epoch\=341-step\=285227-val-jamendo-proxy2.ckpt + +# SPSA music model +python scripts/process.py -i examples/voice_raw.wav -r examples/voice_produced.wav -c checkpoints/style/jamendo/spsa/lightning_logs/version_0/checkpoints/epoch\=362-step\=1210241-val-jamendo-spsa.ckpt + +# TCN1 music model +python scripts/process.py -i examples/voice_raw.wav -r examples/voice_produced.wav -c checkpoints/style/jamendo/tcn1/lightning_logs/version_0/checkpoints/epoch\=362-step\=1210241-val-jamendo-tcn1.ckpt + +# TCN2 music model +python scripts/process.py -i examples/voice_raw.wav -r examples/voice_produced.wav -c checkpoints/style/jamendo/tcn2/lightning_logs/version_0/checkpoints/epoch\=286-step\=956857-val-jamendo-tcn2.ckpt + +``` + +## Training + +### Datasets + +Training and evaluating the models will require one or more of the datasets. Download all the datasets with the following. + +``` +python scripts/download.py --datasets daps vctk jamendo libritts musdb --output /path/to/output --download --process +``` + +You can download individual datasets if desired. + +``` +python scripts/download.py --datasets daps --output /path/to/output --download --process +``` +Note, data download can take several days due to the dataset server speeds. We recommend downloading once and making your own storage setup. You will need approx. 1TB of local storage space to download and pre-process all datasets. + +For the style classifcation task we need to render the synthetic style datasets. This can be done for DAPS and MUSDB18 using the [`scripts/run_generate_styles.sh`](scripts/run_generate_styles.sh) script. +You will need to update the paths in this script to reflect your local file system and then call the script. + +``` +./script/run_generate_styles.sh` +`` + +### Style transfer + +A number of predefined training configurations are defined in bash scripts in the `configs/` directory. +We perform experiments on speech using [LibriTTS] and on music using the [MTG-Jamendo dataset](). +By default, this will train 6 different model configurations using a different method for differentiable . +This will place one job on each GPU, and assumes at least 6 GPUs, each with at least 16 GB of VRAM. +You can launch training by calling the appropriate + +``` +./configs/train_all_libritts_style.sh +./configs/train_all_jamendo_style.sh +``` +Note, you will need to modify the data paths in the scripts above to fit your setup. + +There are four main configurations for training the style transfer models. +This is specified by the `--processor_model` flag when launching the training script, and +must be one of the following: + +1. `tcn1` - End-to-end audio processing neural network (1 network) with control parameters. +2. `tcn2` - End-to-end audio processing neural network (2 networks) with control parameters. +3. `proxy0` - Neural network proxies with control parameters for audio processing. +4. `proxy2` - Neural network proxies with control parameters for audio processing. +5. `spsa` - Gradient estimation with DSP audio effects using SPSA methods. +6. `autodiff` - DSP audio effects implemnted directly in PyTorch. + + +### Proxy training + +If desired you can also re-train the neural network audio effect proxies. +Training the controller with neural proxies requires first pretrianing the proxy networks to emulate a parametric EQ and dynamic range compressor. +Calling the proxy pre-training scripts in the `configs/` directly will train these models using the training set from VCTK. + +``` +./configs/train_all_proxies.sh +``` + +### Probe training + +Linear probes for the production style classification task can be trained with the following script. + +``` +./configs/train_all_probes.sh +``` +Note, some additional data pre-processing is required and needs updating. + +## Style evaluation + +Evaluating the pretrained models along with the baselines is carried out with the `eval.py` script. +A predefined evaluation configuration to reproduce our results can be called as follows. +Be sure to update the paths at the top of the script to reflect the location of the datasets. + +``` +./configs/eval_style.sh +``` + +To evaluate a set of models call the script passing the directory containing the pretrained checkpoints. +The following example demonstrates evaluating models trained on daps. + +### Probe evaluation + +Pretrained linear probes can be evaluated separately with their own script. + +``` +./configs/eval_probes.sh +``` + +### Timing + +We compute timings on both CPU and GPU for the different approaches using `python scripts/timing.py`. + +``` +rb_infer : sec/step 0.0186 0.0037 RTF +dsp_infer : sec/step 0.0172 0.0034 RTF +autodiff_cpu_infer : sec/step 0.0295 0.0059 RTF +autodiff_gpu_infer : sec/step 0.0049 0.0010 RTF +tcn1_cpu_infer : sec/step 0.6580 0.1316 RTF +tcn2_cpu_infer : sec/step 1.3409 0.2682 RTF +tcn1_gpu_infer : sec/step 0.0114 0.0023 RTF +tcn2_gpu_infer : sec/step 0.0223 0.0045 RTF +autodiff_gpu_grad : sec/step 0.3086 0.0617 RTF +np_norm_gpu_grad : sec/step 0.4346 0.0869 RTF +np_hh_gpu_grad : sec/step 0.4379 0.0876 RTF +np_fh_gpu_grad : sec/step 0.4339 0.0868 RTF +tcn1_gpu_grad : sec/step 0.4382 0.0876 RTF +tcn2_gpu_grad : sec/step 0.6424 0.1285 RTF +spsa_gpu_grad : sec/step 0.4132 0.0826 RTF +``` + +The above results were from a machine with the following configuration. + +``` +Intel(R) Xeon(R) CPU E5-2623 v3 @ 3.00GHz (16 core) +GeForce GTX 1080 Ti +``` + + +## License +Unless otherwise specified via local comments per file, all code and models are licensed via the [Adobe Research License](LICENSE). Copyright (c) Adobe Systems Incorporated. All rights reserved. diff --git a/configs/eval_probes.sh b/configs/eval_probes.sh new file mode 100755 index 0000000..1dbd8ca --- /dev/null +++ b/configs/eval_probes.sh @@ -0,0 +1,19 @@ +checkpoint_dir="./checkpoints" +root_dir="/path/to/data" # path to audio datasets +output_dir="/path/to/data/eval" # path to store audio utputs + +CUDA_VISIBLE_DEVICES=0 python scripts/eval_probes.py \ +--ckpt_dir "$checkpoint_dir/probes/speech" \ +--eval_dataset "$root_dir/daps_24000_styles_100/" \ +--subset test \ +--audio_type speech \ +--output_dir probes \ +--gpu \ + +CUDA_VISIBLE_DEVICES=0 python scripts/eval_probes.py \ +--ckpt_dir "$checkpoint_dir/probes/music" \ +--eval_dataset "$root_dir/musdb18_44100_styles_100/" \ +--audio_type music \ +--subset test \ +--output_dir probes \ +--gpu \ \ No newline at end of file diff --git a/configs/eval_style.sh b/configs/eval_style.sh new file mode 100755 index 0000000..c090e64 --- /dev/null +++ b/configs/eval_style.sh @@ -0,0 +1,149 @@ +gpu_id=0 +num_examples=1000 # number of evaluation examples per dataset +checkpoint_dir="./checkpoints" +root_dir="/path/to/data" # path to audio datasets +output_dir="/path/to/data/eval" # path to store audio outputs + +# ----------------------- LibriTTS ----------------------- +CUDA_VISIBLE_DEVICES="$gpu_id" python scripts/eval_style.py \ +"$checkpoint_dir/style/libritts/" \ +--root_dir "$root_dir" \ +--gpu \ +--dataset libritts \ +--dataset_dir "LibriTTS/train_clean_360_24000c" \ +--spsa_version 2 \ +--tcn1_version 1 \ +--autodiff_version 1 \ +--tcn2_version 1 \ +--subset "test" \ +--output "$output_dir" \ +--examples "$num_examples" \ +--save \ + +# ----------------------- DAPS ----------------------- +CUDA_VISIBLE_DEVICES="$gpu_id" python scripts/eval_style.py \ +"$checkpoint_dir/style/libritts/" \ +--root_dir "$root_dir" \ +--gpu \ +--dataset daps \ +--dataset_dir "daps_24000/cleanraw" \ +--spsa_version 2 \ +--tcn1_version 1 \ +--autodiff_version 1 \ +--tcn2_version 1 \ +--subset "train" \ +--output "$output_dir" \ +--examples "$num_examples" \ +--save \ + +# ----------------------- VCTK ----------------------- +CUDA_VISIBLE_DEVICES="$gpu_id" python scripts/eval_style.py \ +"$checkpoint_dir/style/libritts/" \ +--root_dir "$root_dir" \ +--gpu \ +--dataset vctk \ +--dataset_dir "vctk_24000" \ +--spsa_version 2 \ +--tcn1_version 1 \ +--autodiff_version 1 \ +--tcn2_version 1 \ +--subset "train" \ +--examples "$num_examples" \ +--output "$output_dir" \ + +# ----------------------- Jamendo @ 24kHz (test) ----------------------- +CUDA_VISIBLE_DEVICES="$gpu_id" python scripts/eval_style.py \ +"$checkpoint_dir/style/jamendo/" \ +--root_dir "$root_dir" \ +--gpu \ +--dataset jamendo \ +--dataset_dir "mtg-jamendo_24000/" \ +--spsa_version 0 \ +--tcn1_version 0 \ +--autodiff_version 0 \ +--tcn2_version 0 \ +--subset test \ +--save \ +--ext flac \ +--examples "$num_examples" \ +--output "$output_dir" \ + +# ----------------------- Jamendo @ 24kHz (test) ----------------------- +CUDA_VISIBLE_DEVICES="$gpu_id" python scripts/eval_style.py \ +"$checkpoint_dir/style/jamendo/" \ +--root_dir "$root_dir" \ +--gpu \ +--dataset jamendo_44100 \ +--dataset_dir "mtg-jamendo_44100/" \ +--spsa_version 0 \ +--tcn1_version 0 \ +--autodiff_version 0 \ +--tcn2_version 0 \ +--subset test \ +--length 262144 \ +--save \ +--ext wav \ +--examples "$num_examples" \ +--output "$output_dir" \ + +# ----------------------- MUSDB18 @ 24kHz (train) ----------------------- +CUDA_VISIBLE_DEVICES="$gpu_id" python scripts/eval_style.py \ +"$checkpoint_dir/style/jamendo/" \ +--root_dir "$root_dir" \ +--gpu \ +--dataset musdb18_24000 \ +--dataset_dir "musdb18_24000/" \ +--spsa_version 0 \ +--tcn1_version 0 \ +--autodiff_version 0 \ +--tcn2_version 0 \ +--subset train \ +--length 131072 \ +--save \ +--ext wav \ +--examples "$num_examples" \ +--output "$output_dir" \ + +# ----------------------- MUSDB18 @ 44.1kHz (train) ----------------------- +CUDA_VISIBLE_DEVICES="$gpu_id" python scripts/eval_style.py \ +"$checkpoint_dir/style/jamendo/" \ +--root_dir "$root_dir" \ +--gpu \ +--dataset musdb18_44100 \ +--dataset_dir "musdb18_44100/" \ +--spsa_version 0 \ +--tcn1_version 0 \ +--autodiff_version 0 \ +--tcn2_version 0 \ +--subset train \ +--length 262144 \ +--save \ +--ext wav \ +--examples "$num_examples" \ +--output "$output_dir" \ + +# ----------------------- Style case study (SPSA) ----------------------- +## Style case study on DAPS +CUDA_VISIBLE_DEVICES="$gpu_id" python scripts/style_case_study.py \ +--ckpt_paths \ +"$checkpoint_dir/style/libritts/spsa/lightning_logs/version_2/checkpoints/epoch=367-step=1226911-val-libritts-spsa.ckpt" \ +"$checkpoint_dir/style/libritts/autodiff/lightning_logs/version_1/checkpoints/epoch=367-step=1226911-val-libritts-autodiff.ckpt" \ +"$checkpoint_dir/style/libritts/proxy0/lightning_logs/version_0/checkpoints/epoch=327-step=1093551-val-libritts-proxy0.ckpt" \ +--style_audio "$root_dir/daps_24000_styles_1000_diverse/train" \ +--output_dir "$root_dir/style_case_study" \ +--gpu \ +--save \ +--plot \ + +## Style case study on MUSDB18 @ 44.1 kHz +CUDA_VISIBLE_DEVICES="$gpu_id" python scripts/style_case_study.py \ +--ckpt_paths \ +"$checkpoint_dir/style/jamendo/autodiff/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-autodiff.ckpt" \ +"$checkpoint_dir/style/jamendo/spsa/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-spsa.ckpt" \ +"$checkpoint_dir/style/jamendo/proxy0/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-proxy0.ckpt" \ +--style_audio "$root_dir/musdb18_44100_styles_100/train" \ +--output_dir "$root_dir/style_case_study_musdb18" \ +--sample_rate 44100 \ +--gpu \ +--save \ +--plot \ diff --git a/configs/train_all_jamendo_style.sh b/configs/train_all_jamendo_style.sh new file mode 100755 index 0000000..4135439 --- /dev/null +++ b/configs/train_all_jamendo_style.sh @@ -0,0 +1,82 @@ +#!/bin/bash + +root_data_dir=/path/to/data +multi_gpu=0 # set to 1 to launch on sequential GPUs +gpu_id=0 # starting GPU id +# by default start on GPU #1 (id=0) +checkpoint_dir="./checkpoints" + +for processor_model in tcn1 tcn2 spsa proxy0 proxy2 autodiff +do + + if [ "$processor_model" = "tcn1" ]; then + lr=1e-4 + elif [ "$processor_model" = "tcn2" ]; then + lr=1e-4 + elif [ "$processor_model" = "spsa" ]; then + lr=1e-5 + elif [ "$processor_model" = "proxy0" ]; then + lr=1e-4 + elif [ "$processor_model" = "proxy1" ]; then + lr=1e-4 + elif [ "$processor_model" = "proxy2" ]; then + lr=1e-4 + elif [ "$processor_model" = "autodiff" ]; then + lr=1e-4 + else + lr=1e-4 + fi + + + echo "Training $processor_model on GPU $gpu_id with learning rate = $lr" + + CUDA_VISIBLE_DEVICES="$gpu_id" python scripts/train_style.py \ + --processor_model $processor_model \ + --gpus 1 \ + --audio_dir "$root_data_dir" \ + --ext wav \ + --input_dirs "mtg-jamendo_24000/" \ + --style_transfer \ + --buffer_size_gb 1.0 \ + --buffer_reload_rate 2000 \ + --train_frac 0.9 \ + --freq_corrupt \ + --drc_corrupt \ + --sample_rate 24000 \ + --train_length 131072 \ + --train_examples_per_epoch 20000 \ + --val_length 131072 \ + --val_examples_per_epoch 200 \ + --random_scale_input \ + --encoder_model efficient_net \ + --encoder_embed_dim 1024 \ + --encoder_width_mult 1 \ + --recon_losses mrstft l1 \ + --recon_loss_weight 1.0 100.0 \ + --tcn_causal \ + --tcn_nblocks 4 \ + --tcn_dilation_growth 8 \ + --tcn_channel_width 64 \ + --tcn_kernel_size 13 \ + --spsa_epsilon 0.0005 \ + --spsa_verbose \ + --spsa_parallel \ + --proxy_ckpts \ + "$checkpoint_dir/proxies/libritts/peq/lightning_logs/version_1/checkpoints/epoch=111-step=139999-val-libritts-peq.ckpt" \ + "$checkpoint_dir/proxies/libritts/comp/lightning_logs/version_1/checkpoints/epoch=255-step=319999-val-libritts-comp.ckpt" \ + --freeze_proxies \ + --lr "$lr" \ + --num_workers 8 \ + --batch_size 4 \ + --gradient_clip_val 4.0 \ + --max_epochs 400 \ + --accelerator ddp \ + --default_root_dir "$root_data_dir/logs_debug/style/jamendo/$processor_model" + + # set the GPU ID + if [ $multi_gpu -eq 1 ]; then + ((gpu_id=gpu_id+1)) + fi +done + +# --batch_size 6 diff --git a/configs/train_all_libritts_style.sh b/configs/train_all_libritts_style.sh new file mode 100755 index 0000000..2d61b91 --- /dev/null +++ b/configs/train_all_libritts_style.sh @@ -0,0 +1,78 @@ +#!/bin/bash + +root_data_dir=/path/to/data +multi_gpu=0 # set to 1 to launch on sequential GPUs +gpu_id=0 # starting GPU id +# by default start on GPU #1 (id=0) +checkpoint_dir="./checkpoints" + +for processor_model in tcn1 tcn2 spsa proxy0 proxy2 autodiff +do + + if [ "$processor_model" = "tcn1" ]; then + lr=1e-4 + elif [ "$processor_model" = "tcn2" ]; then + lr=1e-4 + elif [ "$processor_model" = "spsa" ]; then + lr=1e-5 # lower learning rate + elif [ "$processor_model" = "proxy0" ]; then + lr=1e-4 + elif [ "$processor_model" = "proxy1" ]; then + lr=1e-4 + elif [ "$processor_model" = "proxy2" ]; then + lr=1e-4 + elif [ "$processor_model" = "autodiff" ]; then + lr=1e-4 + else + lr=1e-4 + fi + + echo "Training $processor_model on GPU $gpu_id with learning rate = $lr" + + CUDA_VISIBLE_DEVICES="$gpu_id" python scripts/train_style.py \ + --processor_model $processor_model \ + --gpus 1 \ + --audio_dir "$root_data_dir/LibriTTS" \ + --input_dirs "train_clean_360_24000c" \ + --style_transfer \ + --buffer_size_gb 1.0 \ + --buffer_reload_rate 2000 \ + --train_frac 0.9 \ + --freq_corrupt \ + --drc_corrupt \ + --sample_rate 24000 \ + --train_length 131072 \ + --train_examples_per_epoch 20000 \ + --val_length 131072 \ + --val_examples_per_epoch 200 \ + --random_scale_input \ + --encoder_model efficient_net \ + --encoder_embed_dim 1024 \ + --encoder_width_mult 1 \ + --recon_losses mrstft l1 \ + --recon_loss_weight 1.0 100.0 \ + --tcn_causal \ + --tcn_nblocks 4 \ + --tcn_dilation_growth 8 \ + --tcn_channel_width 64 \ + --tcn_kernel_size 13 \ + --spsa_epsilon 0.0005 \ + --spsa_verbose \ + --spsa_parallel \ + --proxy_ckpts \ + "$checkpoint_dir/proxies/libritts/peq/lightning_logs/version_1/checkpoints/epoch=111-step=139999-val-libritts-peq.ckpt" \ + "$checkpoint_dir/proxies/libritts/comp/lightning_logs/version_1/checkpoints/epoch=255-step=319999-val-libritts-comp.ckpt" \ + --freeze_proxies \ + --lr "$lr" \ + --num_workers 8 \ + --batch_size 6 \ + --gradient_clip_val 4.0 \ + --max_epochs 400 \ + --accelerator ddp \ + --default_root_dir "$root_data_dir/logs/style/libritts/$processor_model" \ + + # set the GPU ID + if [ $multi_gpu -eq 1 ]; then + ((gpu_id=gpu_id+1)) + fi +done \ No newline at end of file diff --git a/configs/train_all_probes.sh b/configs/train_all_probes.sh new file mode 100755 index 0000000..d31b46c --- /dev/null +++ b/configs/train_all_probes.sh @@ -0,0 +1,79 @@ +#!/bin/bash + +root_data_dir=/path/to/data +multi_gpu=0 # set to 1 to launch on sequential GPUs +gpu_id=0 # starting GPU id +checkpoint_dir="./checkpoints" + +# random_mel openl3 deepafx_st_spsa deepafx_st_proxy0 deepafx_st_autodiff cdpam + +probe_type=linear # always use linear probe + +for audio_type in speech +do + if [ "$audio_type" = "speech" ]; then + audio_dir="daps_24000_styles_100" + deepafx_st_autodiff_ckpt="$checkpoint_dir/style/libritts/autodiff/lightning_logs/version_1/checkpoints/epoch=367-step=1226911-val-libritts-autodiff.ckpt" + deepafx_st_spsa_ckpt="$checkpoint_dir/style/libritts/spsa/lightning_logs/version_2/checkpoints/epoch=367-step=1226911-val-libritts-spsa.ckpt" + deepafx_st_proxy0_ckpt="$checkpoint_dir/style/libritts/proxy0/lightning_logs/version_0/checkpoints/epoch=327-step=1093551-val-libritts-proxy0.ckpt" + elif [ "$audio_type" = "music" ]; then + audio_dir="musdb18_44100_styles_100" + deepafx_st_autodiff_ckpt="$checkpoint_dir/jamendo/style/jamendo/autodiff/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-autodiff.ckpt" + deepafx_st_spsa_ckpt="$checkpoint_dir/jamendo/style/jamendo/spsa/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-spsa.ckpt" + deepafx_st_proxy0_ckpt="$checkpoint_dir/jamendo/style/jamendo/proxy0/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-proxy0.ckpt" + fi + + for encoder_type in random_mel openl3 deepafx_st_spsa deepafx_st_proxy0 deepafx_st_autodiff cdpam + do + + if [ "$encoder_type" = "deepafx_st_autodiff" ]; then + lr=1e-3 + encoder_sample_rate=24000 + elif [ "$encoder_type" = "deepafx_st_spsa" ]; then + lr=1e-3 + encoder_sample_rate=24000 + elif [ "$encoder_type" = "deepafx_st_proxy0" ]; then + lr=1e-3 + encoder_sample_rate=24000 + elif [ "$encoder_type" = "random_mel" ]; then + lr=1e-3 + encoder_sample_rate=24000 + elif [ "$encoder_type" = "openl3" ]; then + lr=1e-3 + encoder_sample_rate=48000 + elif [ "$encoder_type" = "cdpam" ]; then + lr=1e-3 + encoder_sample_rate=22050 + else + lr=1e-3 + fi + + echo "Training $audio_type $encoder_type encoder with $probe_type probe on GPU $gpu_id" + + CUDA_VISIBLE_DEVICES="$gpu_id" python scripts/train_probe.py \ + --gpus 1 \ + --task "style" \ + --audio_dir "$root_data_dir/$audio_dir" \ + --sample_rate 24000 \ + --encoder_sample_rate "$encoder_sample_rate" \ + --encoder_type $encoder_type \ + --deepafx_st_autodiff_ckpt "$deepafx_st_autodiff_ckpt" \ + --deepafx_st_spsa_ckpt "$deepafx_st_spsa_ckpt" \ + --deepafx_st_proxy0_ckpt "$deepafx_st_proxy0_ckpt" \ + --cdpam_ckpt "$checkpoint_dir/cdpam/scratchJNDdefault_best_model.pth" \ + --probe_type $probe_type \ + --lr "$lr" \ + --num_workers 4 \ + --batch_size 16 \ + --gradient_clip_val 200.0 \ + --max_epochs 400 \ + --accelerator ddp \ + --default_root_dir "$root_data_dir/probes/$audio_type/$encoder_type-$probe_type" \ + + # set the GPU ID + if [ $multi_gpu -eq 1 ]; then + ((gpu_id=gpu_id+1)) + fi + + done +done diff --git a/configs/train_all_proxies.sh b/configs/train_all_proxies.sh new file mode 100755 index 0000000..6f35169 --- /dev/null +++ b/configs/train_all_proxies.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +root_data_dir=/path/to/data +multi_gpu=0 # set to 1 to launch on sequential GPUs +gpu_id=0 # starting GPU id + +for processor in peq comp channel +do + + echo "Training $processor proxy on GPU $gpu_id" + + # Single Parametric EQ + CUDA_VISIBLE_DEVICES="$gpu_id" python scripts/train_proxy.py \ + --gpus 1 \ + --input_dir "$root_data_dir/LibriTTS/train_clean_360_24000c" \ + --sample_rate 24000 \ + --train_length 65536 \ + --train_examples_per_epoch 20000 \ + --val_length 65536 \ + --val_examples_per_epoch 200 \ + --buffer_size_gb 1.0 \ + --buffer_reload_rate 2000 \ + --processor $processor \ + --causal \ + --output_gain \ + --nblocks 4 \ + --dilation_growth 8 \ + --channel_width 64 \ + --kernel_size 13 \ + --lr 3e-4 \ + --lr_patience 10 \ + --num_workers 8 \ + --batch_size 16 \ + --gradient_clip_val 10.0 \ + --max_epochs 400 \ + --accelerator ddp \ + --default_root_dir "$root_data_dir/logs/proxies/libritts/$processor" + + # set the GPU ID + if [ $multi_gpu -eq 1 ]; then + ((gpu_id=gpu_id+1)) + fi + +done \ No newline at end of file diff --git a/deepafx_st/__init__.py b/deepafx_st/__init__.py new file mode 100755 index 0000000..cdaf0a5 --- /dev/null +++ b/deepafx_st/__init__.py @@ -0,0 +1,4 @@ +#!/usr/bin/env python +"""Top-level module for deepafx_st""" + +from .version import version as __version__ diff --git a/deepafx_st/callbacks/audio.py b/deepafx_st/callbacks/audio.py new file mode 100755 index 0000000..39b1433 --- /dev/null +++ b/deepafx_st/callbacks/audio.py @@ -0,0 +1,184 @@ +import auraloss +import numpy as np +import pytorch_lightning as pl + +from deepafx_st.callbacks.plotting import plot_multi_spectrum +from deepafx_st.metrics import ( + LoudnessError, + SpectralCentroidError, + CrestFactorError, + PESQ, + MelSpectralDistance, +) + + +class LogAudioCallback(pl.callbacks.Callback): + def __init__(self, num_examples=4, peak_normalize=True, sample_rate=22050): + super().__init__() + self.num_examples = 4 + self.peak_normalize = peak_normalize + + self.metrics = { + "PESQ": PESQ(sample_rate), + "MRSTFT": auraloss.freq.MultiResolutionSTFTLoss( + fft_sizes=[32, 128, 512, 2048, 8192, 32768], + hop_sizes=[16, 64, 256, 1024, 4096, 16384], + win_lengths=[32, 128, 512, 2048, 8192, 32768], + w_sc=0.0, + w_phs=0.0, + w_lin_mag=1.0, + w_log_mag=1.0, + ), + "MSD": MelSpectralDistance(sample_rate), + "SCE": SpectralCentroidError(sample_rate), + "CFE": CrestFactorError(), + "LUFS": LoudnessError(sample_rate), + } + + self.outputs = [] + + def on_validation_batch_end( + self, + trainer, + pl_module, + outputs, + batch, + batch_idx, + dataloader_idx, + ): + """Called when the validation batch ends.""" + + if outputs is not None: + examples = np.min([self.num_examples, outputs["x"].shape[0]]) + self.outputs.append(outputs) + + if batch_idx == 0: + for n in range(examples): + if batch_idx == 0: + self.log_audio( + outputs, + n, + pl_module.hparams.sample_rate, + pl_module.hparams.val_length, + trainer.global_step, + trainer.logger, + ) + + def on_validation_end(self, trainer, pl_module): + metrics = { + "PESQ": [], + "MRSTFT": [], + "MSD": [], + "SCE": [], + "CFE": [], + "LUFS": [], + } + for output in self.outputs: + for metric_name, metric in self.metrics.items(): + try: + val = metric(output["y_hat"], output["y"]) + metrics[metric_name].append(val) + except: + pass + + # log final mean metrics + for metric_name, metric in metrics.items(): + val = np.mean(metric) + trainer.logger.experiment.add_scalar( + f"metrics/{metric_name}", val, trainer.global_step + ) + + # clear outputs + self.outputs = [] + + def compute_metrics(self, metrics_dict, outputs, batch_idx, global_step): + # extract audio + y = outputs["y"][batch_idx, ...].float() + y_hat = outputs["y_hat"][batch_idx, ...].float() + + # compute all metrics + for metric_name, metric in self.metrics.items(): + try: + val = metric(y_hat.view(1, 1, -1), y.view(1, 1, -1)) + metrics_dict[metric_name].append(val) + except: + pass + + def log_audio(self, outputs, batch_idx, sample_rate, n_fft, global_step, logger): + x = outputs["x"][batch_idx, ...].float() + y = outputs["y"][batch_idx, ...].float() + y_hat = outputs["y_hat"][batch_idx, ...].float() + + if self.peak_normalize: + x /= x.abs().max() + y /= y.abs().max() + y_hat /= y_hat.abs().max() + + logger.experiment.add_audio( + f"x/{batch_idx+1}", + x[0:1, :], + global_step, + sample_rate=sample_rate, + ) + + logger.experiment.add_audio( + f"y/{batch_idx+1}", + y[0:1, :], + global_step, + sample_rate=sample_rate, + ) + + logger.experiment.add_audio( + f"y_hat/{batch_idx+1}", + y_hat[0:1, :], + global_step, + sample_rate=sample_rate, + ) + + if "y_ref" in outputs: + y_ref = outputs["y_ref"][batch_idx, ...].float() + + if self.peak_normalize: + y_ref /= y_ref.abs().max() + + logger.experiment.add_audio( + f"y_ref/{batch_idx+1}", + y_ref[0:1, :], + global_step, + sample_rate=sample_rate, + ) + logger.experiment.add_image( + f"spec/{batch_idx+1}", + compare_spectra( + y_hat[0:1, :], + y[0:1, :], + x[0:1, :], + sample_rate=sample_rate, + n_fft=n_fft, + ), + global_step, + ) + + +def compare_spectra( + deepafx_y_hat, y, x, baseline_y_hat=None, sample_rate=44100, n_fft=16384 +): + legend = ["Corrupted"] + signals = [x] + if baseline_y_hat is not None: + legend.append("Baseline") + signals.append(baseline_y_hat) + + legend.append("DeepAFx") + signals.append(deepafx_y_hat) + legend.append("Target") + signals.append(y) + + image = plot_multi_spectrum( + ys=signals, + legend=legend, + sample_rate=sample_rate, + n_fft=n_fft, + ) + + return image diff --git a/deepafx_st/callbacks/ckpt.py b/deepafx_st/callbacks/ckpt.py new file mode 100755 index 0000000..e01d53e --- /dev/null +++ b/deepafx_st/callbacks/ckpt.py @@ -0,0 +1,33 @@ +import os +import sys +import shutil +import pytorch_lightning as pl + + +class CopyPretrainedCheckpoints(pl.callbacks.Callback): + def __init__(self): + super().__init__() + + def on_fit_start(self, trainer, pl_module): + """Before training, move the pre-trained checkpoints + to the current checkpoint directory. + + """ + # copy any pre-trained checkpoints to new directory + if pl_module.hparams.processor_model == "proxy": + pretrained_ckpt_dir = os.path.join( + pl_module.logger.experiment.log_dir, "pretrained_checkpoints" + ) + if not os.path.isdir(pretrained_ckpt_dir): + os.makedirs(pretrained_ckpt_dir) + cp_proxy_ckpts = [] + for proxy_ckpt in pl_module.hparams.proxy_ckpts: + new_ckpt = shutil.copy( + proxy_ckpt, + pretrained_ckpt_dir, + ) + cp_proxy_ckpts.append(new_ckpt) + print(f"Moved checkpoint to {new_ckpt}.") + # overwrite to the paths in current experiment logs + pl_module.hparams.proxy_ckpts = cp_proxy_ckpts + print(pl_module.hparams.proxy_ckpts) diff --git a/deepafx_st/callbacks/params.py b/deepafx_st/callbacks/params.py new file mode 100755 index 0000000..e327671 --- /dev/null +++ b/deepafx_st/callbacks/params.py @@ -0,0 +1,87 @@ +import numpy as np +import pytorch_lightning as pl +import matplotlib.pyplot as plt + +import deepafx_st.utils as utils + + +class LogParametersCallback(pl.callbacks.Callback): + def __init__(self, num_examples=4): + super().__init__() + self.num_examples = 4 + + def on_validation_epoch_start(self, trainer, pl_module): + """At the start of validation init storage for parameters.""" + self.params = [] + + def on_validation_batch_end( + self, + trainer, + pl_module, + outputs, + batch, + batch_idx, + dataloader_idx, + ): + """Called when the validation batch ends. + + Here we log the parameters only from the first batch. + + """ + if outputs is not None and batch_idx == 0: + examples = np.min([self.num_examples, outputs["x"].shape[0]]) + for n in range(examples): + self.log_parameters( + outputs, + n, + pl_module.processor.ports, + trainer.global_step, + trainer.logger, + True if batch_idx == 0 else False, + ) + + def on_validation_epoch_end(self, trainer, pl_module): + pass + + def log_parameters(self, outputs, batch_idx, ports, global_step, logger, log=True): + p = outputs["p"][batch_idx, ...] + + table = "" + + # table += f"""## {plugin["name"]}\n""" + table += "| Index| Name | Value | Units | Min | Max | Default | Raw Value | \n" + table += "|------|------|------:|:------|----:|----:|--------:| ---------:| \n" + + start_idx = 0 + # set plugin parameters based on provided normalized parameters + for port_list in ports: + for pidx, port in enumerate(port_list): + param_max = port["max"] + param_min = port["min"] + param_name = port["name"] + param_default = port["default"] + param_units = port["units"] + + param_val = p[start_idx] + denorm_val = utils.denormalize(param_val, param_max, param_min) + + # add values to table in row + table += f"| {start_idx + 1} | {param_name} " + if np.abs(denorm_val) > 10: + table += f"| {denorm_val:0.1f} " + table += f"| {param_units} " + table += f"| {param_min:0.1f} | {param_max:0.1f} " + table += f"| {param_default:0.1f} " + else: + table += f"| {denorm_val:0.3f} " + table += f"| {param_units} " + table += f"| {param_min:0.3f} | {param_max:0.3f} " + table += f"| {param_default:0.3f} " + + table += f"| {np.squeeze(param_val):0.2f} | \n" + start_idx += 1 + + table += "\n\n" + + if log: + logger.experiment.add_text(f"params/{batch_idx+1}", table, global_step) diff --git a/deepafx_st/callbacks/plotting.py b/deepafx_st/callbacks/plotting.py new file mode 100755 index 0000000..1dc90a0 --- /dev/null +++ b/deepafx_st/callbacks/plotting.py @@ -0,0 +1,126 @@ +import io +import torch +import PIL.Image +import numpy as np +import scipy.signal +import librosa.display +import matplotlib.pyplot as plt + +from torch.functional import Tensor +from torchvision.transforms import ToTensor + + +def compute_comparison_spectrogram( + x: np.ndarray, + y: np.ndarray, + sample_rate: float = 44100, + n_fft: int = 2048, + hop_length: int = 1024, +) -> Tensor: + X = librosa.stft(x, n_fft=n_fft, hop_length=hop_length) + X_db = librosa.amplitude_to_db(np.abs(X), ref=np.max) + + Y = librosa.stft(y, n_fft=n_fft, hop_length=hop_length) + Y_db = librosa.amplitude_to_db(np.abs(Y), ref=np.max) + + fig, axs = plt.subplots(figsize=(9, 6), nrows=2) + img = librosa.display.specshow( + X_db, + ax=axs[0], + hop_length=hop_length, + x_axis="time", + y_axis="log", + sr=sample_rate, + ) + # fig.colorbar(img, ax=axs[0]) + img = librosa.display.specshow( + Y_db, + ax=axs[1], + hop_length=hop_length, + x_axis="time", + y_axis="log", + sr=sample_rate, + ) + # fig.colorbar(img, ax=axs[1]) + + plt.tight_layout() + + buf = io.BytesIO() + plt.savefig(buf, format="jpeg") + buf.seek(0) + image = PIL.Image.open(buf) + image = ToTensor()(image) + plt.close("all") + + return image + + +def plot_multi_spectrum( + ys=None, + Hs=None, + legend=[], + title="Spectrum", + filename=None, + sample_rate=44100, + n_fft=1024, + zero_mean=False, +): + + if Hs is None: + Hs = [] + for y in ys: + X = get_average_spectrum(y, n_fft) + X_sm = smooth_spectrum(X) + Hs.append(X_sm) + + bin_width = (sample_rate / 2) / (n_fft // 2) + freqs = np.arange(0, (sample_rate / 2) + bin_width, step=bin_width) + + fig, ax1 = plt.subplots() + + for idx, H in enumerate(Hs): + H = np.nan_to_num(H) + H = np.clip(H, 0, np.max(H)) + H_dB = 20 * np.log10(H + 1e-8) + if zero_mean: + H_dB -= np.mean(H_dB) + if "Target" in legend[idx]: + ax1.plot(freqs, H_dB, linestyle="--", color="k") + else: + ax1.plot(freqs, H_dB) + + plt.legend(legend) + + ax1.set_xscale("log") + ax1.set_ylim([-80, 0]) + ax1.set_xlim([100, 11000]) + plt.title(title) + plt.ylabel("Magnitude (dB)") + plt.xlabel("Frequency (Hz)") + plt.grid(c="lightgray", which="both") + + if filename is not None: + plt.savefig(f"{filename}.png", dpi=300) + + plt.tight_layout() + + buf = io.BytesIO() + plt.savefig(buf, format="jpeg") + buf.seek(0) + image = PIL.Image.open(buf) + image = ToTensor()(image) + plt.close("all") + + return image + + +def smooth_spectrum(H): + # apply Savgol filter for smoothed target curve + return scipy.signal.savgol_filter(H, 1025, 2) + + +def get_average_spectrum(x, n_fft): + X = torch.stft(x, n_fft, return_complex=True, normalized=True) + X = X.abs() # convert to magnitude + X = X.mean(dim=-1).view(-1) # average across frames + return X diff --git a/deepafx_st/data/audio.py b/deepafx_st/data/audio.py new file mode 100755 index 0000000..1b81607 --- /dev/null +++ b/deepafx_st/data/audio.py @@ -0,0 +1,177 @@ +import os +import glob +import torch +import warnings +import torchaudio +import pyloudnorm as pyln + + +class AudioFile(object): + def __init__(self, filepath, preload=False, half=False, target_loudness=None): + """Base class for audio files to handle metadata and loading. + + Args: + filepath (str): Path to audio file to load from disk. + preload (bool, optional): If set, load audio data into RAM. Default: False + half (bool, optional): If set, store audio data as float16 to save space. Default: False + target_loudness (float, optional): Loudness normalize to dB LUFS value. Default: + """ + super().__init__() + + self.filepath = filepath + self.half = half + self.target_loudness = target_loudness + self.loaded = False + + if preload: + self.load() + num_frames = self.audio.shape[-1] + num_channels = self.audio.shape[0] + else: + metadata = torchaudio.info(filepath) + audio = None + self.sample_rate = metadata.sample_rate + num_frames = metadata.num_frames + num_channels = metadata.num_channels + + self.num_frames = num_frames + self.num_channels = num_channels + + def load(self): + audio, sr = torchaudio.load(self.filepath, normalize=True) + self.audio = audio + self.sample_rate = sr + + if self.target_loudness is not None: + self.loudness_normalize() + + if self.half: + self.audio = audio.half() + + self.loaded = True + + def loudness_normalize(self): + meter = pyln.Meter(self.sample_rate) + + # conver mono to stereo + if self.audio.shape[0] == 1: + tmp_audio = self.audio.repeat(2, 1) + else: + tmp_audio = self.audio + + # measure integrated loudness + input_loudness = meter.integrated_loudness(tmp_audio.numpy().T) + + # compute and apply gain + gain_dB = self.target_loudness - input_loudness + gain_ln = 10 ** (gain_dB / 20.0) + self.audio *= gain_ln + + # check for potentially clipped samples + if self.audio.abs().max() >= 1.0: + warnings.warn("Possible clipped samples in output.") + + +class AudioFileDataset(torch.utils.data.Dataset): + """Base class for audio file datasets loaded from disk. + + Datasets can be either paired or unpaired. A paired dataset requires passing the `target_dir` path. + + Args: + input_dir (List[str]): List of paths to the directories containing input audio files. + target_dir (List[str], optional): List of paths to the directories containing correponding audio files. Default: [] + subset (str, optional): Dataset subset. One of ["train", "val", "test"]. Default: "train" + length (int, optional): Number of samples to load for each example. Default: 65536 + normalize (bool, optional): Normalize audio amplitiude to -1 to 1. Default: True + train_frac (float, optional): Fraction of the files to use for training subset. Default: 0.8 + val_frac (float, optional): Fraction of the files to use for validation subset. Default: 0.1 + preload (bool, optional): Read audio files into RAM at the start of training. Default: False + num_examples_per_epoch (int, optional): Define an epoch as certain number of audio examples. Default: 10000 + ext (str, optional): Expected audio file extension. Default: "wav" + """ + + def __init__( + self, + input_dirs, + target_dirs=[], + subset="train", + length=65536, + normalize=True, + train_per=0.8, + val_per=0.1, + preload=False, + num_examples_per_epoch=10000, + ext="wav", + ): + super().__init__() + self.input_dirs = input_dirs + self.target_dirs = target_dirs + self.subset = subset + self.length = length + self.normalize = normalize + self.train_per = train_per + self.val_per = val_per + self.preload = preload + self.num_examples_per_epoch = num_examples_per_epoch + self.ext = ext + + self.input_filepaths = [] + for input_dir in input_dirs: + search_path = os.path.join(input_dir, f"*.{ext}") + self.input_filepaths += glob.glob(search_path) + self.input_filepaths = sorted(self.input_filepaths) + + self.target_filepaths = [] + for target_dir in target_dirs: + search_path = os.path.join(target_dir, f"*.{ext}") + self.target_filepaths += glob.glob(search_path) + self.target_filepaths = sorted(self.target_filepaths) + + # both sets must have same number of files in paired dataset + assert len(self.target_filepaths) == len(self.input_filepaths) + + # get details about audio files + self.input_files = [] + for input_filepath in self.input_filepaths: + self.input_files.append( + AudioFile(input_filepath, preload=preload, normalize=normalize) + ) + + self.target_files = [] + if target_dir is not None: + for target_filepath in self.target_filepaths: + self.target_files.append( + AudioFile(target_filepath, preload=preload, normalize=normalize) + ) + + def __len__(self): + return self.num_examples_per_epoch + + def __getitem__(self, idx): + """ """ + + # index the current audio file + input_file = self.input_files[idx] + + # load the audio data if needed + if not input_file.loaded: + input_file.load() + + # get a random patch of size `self.length` + start_idx = int(torch.rand() * (input_file.num_frames - self.length)) + stop_idx = start_idx + self.length + input_audio = input_file.audio[:, start_idx:stop_idx] + + # if there is a target file, get it (and load) + if len(self.target_files) > 0: + target_file = self.target_files[idx] + + if not target_file.loaded: + target_file.load() + + # use the same cropping indices + target_audio = target_file.audio[:, start_idx:stop_idx] + + return input_audio, target_audio + else: + return input_audio diff --git a/deepafx_st/data/augmentations.py b/deepafx_st/data/augmentations.py new file mode 100755 index 0000000..93f3fda --- /dev/null +++ b/deepafx_st/data/augmentations.py @@ -0,0 +1,235 @@ +import torch +import torchaudio +import numpy as np + + +def gain(xs, min_dB=-12, max_dB=12): + + gain_dB = (torch.rand(1) * (max_dB - min_dB)) + min_dB + gain_ln = 10 ** (gain_dB / 20) + + for idx, x in enumerate(xs): + xs[idx] = x * gain_ln + + return xs + + +def peaking_filter(xs, sr=44100, frequency=1000, width_q=0.707, gain_db=12): + + # gain_db = ((torch.rand(1) * 6) + 6).numpy().squeeze() + # width_q = (torch.rand(1) * 4).numpy().squeeze() + # frequency = ((torch.rand(1) * 9960) + 40).numpy().squeeze() + + # if torch.rand(1) > 0.5: + # gain_db = -gain_db + + effects = [["equalizer", f"{frequency}", f"{width_q}", f"{gain_db}"]] + + for idx, x in enumerate(xs): + y, sr = torchaudio.sox_effects.apply_effects_tensor( + x, sr, effects, channels_first=True + ) + xs[idx] = y + + return xs + + +def pitch_shift(xs, min_shift=-200, max_shift=200, sr=44100): + + shift = min_shift + (torch.rand(1)).numpy().squeeze() * (max_shift - min_shift) + + effects = [["pitch", f"{shift}"]] + + for idx, x in enumerate(xs): + y, sr = torchaudio.sox_effects.apply_effects_tensor( + x, sr, effects, channels_first=True + ) + xs[idx] = y + + return xs + + +def time_stretch(xs, min_stretch=0.8, max_stretch=1.2, sr=44100): + + stretch = min_stretch + (torch.rand(1)).numpy().squeeze() * ( + max_stretch - min_stretch + ) + + effects = [["tempo", f"{stretch}"]] + for idx, x in enumerate(xs): + y, sr = torchaudio.sox_effects.apply_effects_tensor( + x, sr, effects, channels_first=True + ) + xs[idx] = y + + return xs + + +def frequency_corruption(xs, sr=44100): + + effects = [] + + # apply a random number of peaking bands from 0 to 4s + bands = [[200, 2000], [800, 4000], [2000, 8000], [4000, int((sr // 2) * 0.9)]] + total_gain_db = 0.0 + for band in bands: + if torch.rand(1).sum() > 0.2: + frequency = (torch.randint(band[0], band[1], [1])).numpy().squeeze() + width_q = ((torch.rand(1) * 10) + 0.1).numpy().squeeze() + gain_db = ((torch.rand(1) * 48)).numpy().squeeze() + + if torch.rand(1).sum() > 0.5: + gain_db = -gain_db + + total_gain_db += gain_db + + if np.abs(total_gain_db) >= 24: + continue + + cmd = ["equalizer", f"{frequency}", f"{width_q}", f"{gain_db}"] + effects.append(cmd) + + # low shelf (bass) + if torch.rand(1).sum() > 0.2: + gain_db = ((torch.rand(1) * 24)).numpy().squeeze() + frequency = (torch.randint(20, 200, [1])).numpy().squeeze() + if torch.rand(1).sum() > 0.5: + gain_db = -gain_db + effects.append(["bass", f"{gain_db}", f"{frequency}"]) + + # high shelf (treble) + if torch.rand(1).sum() > 0.2: + gain_db = ((torch.rand(1) * 24)).numpy().squeeze() + frequency = (torch.randint(4000, int((sr // 2) * 0.9), [1])).numpy().squeeze() + if torch.rand(1).sum() > 0.5: + gain_db = -gain_db + effects.append(["treble", f"{gain_db}", f"{frequency}"]) + + for idx, x in enumerate(xs): + y, sr = torchaudio.sox_effects.apply_effects_tensor( + x.view(1, -1) * 10 ** (-48 / 20), sr, effects, channels_first=True + ) + # apply gain back + y *= 10 ** (48 / 20) + + xs[idx] = y + + return xs + + +def dynamic_range_corruption(xs, sr=44100): + """Apply an expander.""" + + attack = (torch.rand([1]).numpy()[0] * 0.05) + 0.001 + release = (torch.rand([1]).numpy()[0] * 0.2) + attack + knee = (torch.rand([1]).numpy()[0] * 12) + 0.0 + + # design the compressor transfer function + start = -100.0 + threshold = -( + (torch.rand([1]).numpy()[0] * 20) + 10 + ) # threshold from -30 to -10 dB + ratio = (torch.rand([1]).numpy()[0] * 4.0) + 1 # ratio from 1:1 to 5:1 + + # compute the transfer curve + point = -((-threshold / -ratio) + (-start / ratio) + -threshold) + + # apply some makeup gain + makeup = torch.rand([1]).numpy()[0] * 6 + + effects = [ + [ + "compand", + f"{attack},{release}", + f"{knee}:{point},{start},{threshold},{threshold}", + f"{makeup}", + f"{start}", + ] + ] + + for idx, x in enumerate(xs): + # if the input is clipping normalize it + if x.abs().max() >= 1.0: + x /= x.abs().max() + gain_db = -((torch.rand(1) * 24)).numpy().squeeze() + x *= 10 ** (gain_db / 20.0) + + y, sr = torchaudio.sox_effects.apply_effects_tensor( + x.view(1, -1), sr, effects, channels_first=True + ) + xs[idx] = y + + return xs + + +def dynamic_range_compression(xs, sr=44100): + """Apply a compressor.""" + + attack = (torch.rand([1]).numpy()[0] * 0.05) + 0.0005 + release = (torch.rand([1]).numpy()[0] * 0.2) + attack + knee = (torch.rand([1]).numpy()[0] * 12) + 0.0 + + # design the compressor transfer function + start = -100.0 + threshold = -((torch.rand([1]).numpy()[0] * 52) + 12) + # threshold from -64 to -12 dB + ratio = (torch.rand([1]).numpy()[0] * 10.0) + 1 # ratio from 1:1 to 10:1 + + # compute the transfer curve + point = threshold * (1 - (1 / ratio)) + + # apply some makeup gain + makeup = torch.rand([1]).numpy()[0] * 6 + + effects = [ + [ + "compand", + f"{attack},{release}", + f"{knee}:{start},{threshold},{threshold},0,{point}", + f"{makeup}", + f"{start}", + f"{attack}", + ] + ] + + for idx, x in enumerate(xs): + y, sr = torchaudio.sox_effects.apply_effects_tensor( + x.view(1, -1), sr, effects, channels_first=True + ) + xs[idx] = y + + return xs + + +def lowpass_filter(xs, sr=44100, frequency=4000): + effects = [["lowpass", f"{frequency}"]] + + for idx, x in enumerate(xs): + y, sr = torchaudio.sox_effects.apply_effects_tensor( + x, sr, effects, channels_first=True + ) + xs[idx] = y + + return xs + + +def apply(xs, sr, augmentations): + + # iterate over augmentation dict + for aug, params in augmentations.items(): + if aug == "gain": + xs = gain(xs, **params) + elif aug == "peak": + xs = peaking_filter(xs, **params) + elif aug == "lowpass": + xs = lowpass_filter(xs, **params) + elif aug == "pitch": + xs = pitch_shift(xs, **params) + elif aug == "tempo": + xs = time_stretch(xs, **params) + elif aug == "freq_corrupt": + xs = frequency_corruption(xs, **params) + else: + raise RuntimeError("Invalid augmentation: {aug}") + + return xs diff --git a/deepafx_st/data/dataset.py b/deepafx_st/data/dataset.py new file mode 100755 index 0000000..41ebff6 --- /dev/null +++ b/deepafx_st/data/dataset.py @@ -0,0 +1,344 @@ +import os +import sys +import csv +import glob +import torch +import random +from tqdm import tqdm +from typing import List, Any + +from deepafx_st.data.audio import AudioFile +import deepafx_st.utils as utils +import deepafx_st.data.augmentations as augmentations + + +class AudioDataset(torch.utils.data.Dataset): + """Audio dataset which returns an input and target file. + + Args: + audio_dir (str): Path to the top level of the audio dataset. + input_dir (List[str], optional): List of paths to the directories containing input audio files. Default: ["clean"] + subset (str, optional): Dataset subset. One of ["train", "val", "test"]. Default: "train" + length (int, optional): Number of samples to load for each example. Default: 65536 + train_frac (float, optional): Fraction of the files to use for training subset. Default: 0.8 + val_frac (float, optional): Fraction of the files to use for validation subset. Default: 0.1 + buffer_size_gb (float, optional): Size of audio to read into RAM in GB at any given time. Default: 10.0 + Note: This is the buffer size PER DataLoader worker. So total RAM = buffer_size_gb * num_workers + buffer_reload_rate (int, optional): Number of items to generate before loading next chunk of dataset. Default: 10000 + half (bool, optional): Sotre audio samples as float 16. Default: False + num_examples_per_epoch (int, optional): Define an epoch as certain number of audio examples. Default: 10000 + random_scale_input (bool, optional): Apply random gain scaling to input utterances. Default: False + random_scale_target (bool, optional): Apply same random gain scaling to target utterances. Default: False + augmentations (dict, optional): List of augmentation types to apply to inputs. Default: [] + freq_corrupt (bool, optional): Apply bad EQ filters. Default: False + drc_corrupt (bool, optional): Apply an expander to corrupt dynamic range. Default: False + ext (str, optional): Expected audio file extension. Default: "wav" + """ + + def __init__( + self, + audio_dir, + input_dirs: List[str] = ["cleanraw"], + subset: str = "train", + length: int = 65536, + train_frac: float = 0.8, + val_per: float = 0.1, + buffer_size_gb: float = 1.0, + buffer_reload_rate: float = 1000, + half: bool = False, + num_examples_per_epoch: int = 10000, + random_scale_input: bool = False, + random_scale_target: bool = False, + augmentations: dict = {}, + freq_corrupt: bool = False, + drc_corrupt: bool = False, + ext: str = "wav", + ): + super().__init__() + self.audio_dir = audio_dir + self.dataset_name = os.path.basename(audio_dir) + self.input_dirs = input_dirs + self.subset = subset + self.length = length + self.train_frac = train_frac + self.val_per = val_per + self.buffer_size_gb = buffer_size_gb + self.buffer_reload_rate = buffer_reload_rate + self.half = half + self.num_examples_per_epoch = num_examples_per_epoch + self.random_scale_input = random_scale_input + self.random_scale_target = random_scale_target + self.augmentations = augmentations + self.freq_corrupt = freq_corrupt + self.drc_corrupt = drc_corrupt + self.ext = ext + + self.input_filepaths = [] + for input_dir in input_dirs: + search_path = os.path.join(audio_dir, input_dir, f"*.{ext}") + self.input_filepaths += glob.glob(search_path) + self.input_filepaths = sorted(self.input_filepaths) + + # create dataset split based on subset + self.input_filepaths = utils.split_dataset( + self.input_filepaths, + subset, + train_frac, + ) + + # get details about input audio files + input_files = {} + input_dur_frames = 0 + for input_filepath in tqdm(self.input_filepaths, ncols=80): + file_id = os.path.basename(input_filepath) + audio_file = AudioFile( + input_filepath, + preload=False, + half=half, + ) + if audio_file.num_frames < (self.length * 2): + continue + input_files[file_id] = audio_file + input_dur_frames += input_files[file_id].num_frames + + if len(list(input_files.items())) < 1: + raise RuntimeError(f"No files found in {search_path}.") + + input_dur_hr = (input_dur_frames / input_files[file_id].sample_rate) / 3600 + print( + f"\nLoaded {len(input_files)} files for {subset} = {input_dur_hr:0.2f} hours." + ) + + self.sample_rate = input_files[file_id].sample_rate + + # save a csv file with details about the train and test split + splits_dir = os.path.join("configs", "splits") + if not os.path.isdir(splits_dir): + os.makedirs(splits_dir) + csv_filepath = os.path.join(splits_dir, f"{self.dataset_name}_{self.subset}_set.csv") + + with open(csv_filepath, "w") as fp: + dw = csv.DictWriter(fp, ["file_id", "filepath", "type", "subset"]) + dw.writeheader() + for input_filepath in self.input_filepaths: + dw.writerow( + { + "file_id": self.get_file_id(input_filepath), + "filepath": input_filepath, + "type": "input", + "subset": self.subset, + } + ) + + # some setup for iteratble loading of the dataset into RAM + self.items_since_load = self.buffer_reload_rate + + def __len__(self): + return self.num_examples_per_epoch + + def load_audio_buffer(self): + self.input_files_loaded = {} # clear audio buffer + self.items_since_load = 0 # reset iteration counter + nbytes_loaded = 0 # counter for data in RAM + + # different subset in each + random.shuffle(self.input_filepaths) + + # load files into RAM + for input_filepath in self.input_filepaths: + file_id = os.path.basename(input_filepath) + audio_file = AudioFile( + input_filepath, + preload=True, + half=self.half, + ) + + if audio_file.num_frames < (self.length * 2): + continue + + self.input_files_loaded[file_id] = audio_file + + nbytes = audio_file.audio.element_size() * audio_file.audio.nelement() + nbytes_loaded += nbytes + + # check the size of loaded data + if nbytes_loaded > self.buffer_size_gb * 1e9: + break + + def generate_pair(self): + # ------------------------ Input audio ---------------------- + rand_input_file_id = None + input_file = None + start_idx = None + stop_idx = None + while True: + rand_input_file_id = self.get_random_file_id(self.input_files_loaded.keys()) + + # use this random key to retrieve an input file + input_file = self.input_files_loaded[rand_input_file_id] + + # load the audio data if needed + if not input_file.loaded: + raise RuntimeError("Audio not loaded.") + + # get a random patch of size `self.length` x 2 + start_idx, stop_idx = self.get_random_patch( + input_file, int(self.length * 2) + ) + if start_idx >= 0: + break + + input_audio = input_file.audio[:, start_idx:stop_idx].clone().detach() + input_audio = input_audio.view(1, -1) + + if self.half: + input_audio = input_audio.float() + + # peak normalize to -12 dBFS + input_audio /= input_audio.abs().max() + input_audio *= 10 ** (-12.0 / 20) # with min 3 dBFS headroom + + if len(list(self.augmentations.items())) > 0: + if torch.rand(1).sum() < 0.5: + input_audio_aug = augmentations.apply( + [input_audio], + self.sample_rate, + self.augmentations, + )[0] + else: + input_audio_aug = input_audio.clone() + else: + input_audio_aug = input_audio.clone() + + input_audio_corrupt = input_audio_aug.clone() + # apply frequency and dynamic range corrpution (expander) + if self.freq_corrupt and torch.rand(1).sum() < 0.75: + input_audio_corrupt = augmentations.frequency_corruption( + [input_audio_corrupt], self.sample_rate + )[0] + + # peak normalize again before passing through dynamic range expander + input_audio_corrupt /= input_audio_corrupt.abs().max() + input_audio_corrupt *= 10 ** (-12.0 / 20) # with min 3 dBFS headroom + + if self.drc_corrupt and torch.rand(1).sum() < 0.10: + input_audio_corrupt = augmentations.dynamic_range_corruption( + [input_audio_corrupt], self.sample_rate + )[0] + + # ------------------------ Target audio ---------------------- + # use the same augmented audio clip, add different random EQ and compressor + + target_audio_corrupt = input_audio_aug.clone() + # apply frequency and dynamic range corrpution (expander) + if self.freq_corrupt and torch.rand(1).sum() < 0.75: + target_audio_corrupt = augmentations.frequency_corruption( + [target_audio_corrupt], self.sample_rate + )[0] + + # peak normalize again before passing through dynamic range compressor + input_audio_corrupt /= input_audio_corrupt.abs().max() + input_audio_corrupt *= 10 ** (-12.0 / 20) # with min 3 dBFS headroom + + if self.drc_corrupt and torch.rand(1).sum() < 0.75: + target_audio_corrupt = augmentations.dynamic_range_compression( + [target_audio_corrupt], self.sample_rate + )[0] + + return input_audio_corrupt, target_audio_corrupt + + def __getitem__(self, _): + """ """ + + # increment counter + self.items_since_load += 1 + + # load next chunk into buffer if needed + if self.items_since_load > self.buffer_reload_rate: + self.load_audio_buffer() + + # generate pairs for style training + input_audio, target_audio = self.generate_pair() + + # ------------------------ Conform length of files ------------------- + input_audio = utils.conform_length(input_audio, int(self.length * 2)) + target_audio = utils.conform_length(target_audio, int(self.length * 2)) + + # ------------------------ Apply fade in and fade out ------------------- + input_audio = utils.linear_fade(input_audio, sample_rate=self.sample_rate) + target_audio = utils.linear_fade(target_audio, sample_rate=self.sample_rate) + + # ------------------------ Final normalizeation ---------------------- + # always peak normalize final input to -12 dBFS + input_audio /= input_audio.abs().max() + input_audio *= 10 ** (-12.0 / 20.0) + + # always peak normalize the target to -12 dBFS + target_audio /= target_audio.abs().max() + target_audio *= 10 ** (-12.0 / 20.0) + + return input_audio, target_audio + + @staticmethod + def get_random_file_id(keys): + # generate a random index into the keys of the input files + rand_input_idx = torch.randint(0, len(keys) - 1, [1])[0] + # find the key (file_id) correponding to the random index + rand_input_file_id = list(keys)[rand_input_idx] + + return rand_input_file_id + + @staticmethod + def get_random_patch(audio_file, length, check_silence=True): + silent = True + count = 0 + while silent: + count += 1 + start_idx = torch.randint(0, audio_file.num_frames - length - 1, [1])[0] + # int(torch.rand(1) * (audio_file.num_frames - length)) + stop_idx = start_idx + length + patch = audio_file.audio[:, start_idx:stop_idx].clone().detach() + + length = patch.shape[-1] + first_patch = patch[..., : length // 2] + second_patch = patch[..., length // 2 :] + + if ( + (first_patch**2).mean() > 1e-5 and (second_patch**2).mean() > 1e-5 + ) or not check_silence: + silent = False + + if count > 100: + print("get_random_patch count", count) + return -1, -1 + # break + + return start_idx, stop_idx + + def get_file_id(self, filepath): + """Given a filepath extract the DAPS file id. + + Args: + filepath (str): Path to an audio files in the DAPS dataset. + + Returns: + file_id (str): DAPS file id of the form _ + file_set (str): The DAPS set to which the file belongs. + """ + file_id = os.path.basename(filepath).split("_")[:2] + file_id = "_".join(file_id) + return file_id + + def get_file_set(self, filepath): + """Given a filepath extract the DAPS file set name. + + Args: + filepath (str): Path to an audio files in the DAPS dataset. + + Returns: + file_set (str): The DAPS set to which the file belongs. + """ + file_set = os.path.basename(filepath).split("_")[2:] + file_set = "_".join(file_set) + file_set = file_set.replace(f".{self.ext}", "") + return file_set diff --git a/deepafx_st/data/proxy.py b/deepafx_st/data/proxy.py new file mode 100755 index 0000000..79d6afd --- /dev/null +++ b/deepafx_st/data/proxy.py @@ -0,0 +1,181 @@ +import os +import json +import glob +import torch +import random +from tqdm import tqdm + +# from deepafx_st.plugins.channel import Channel +from deepafx_st.processors.processor import Processor +from deepafx_st.data.audio import AudioFile +import deepafx_st.utils as utils + + +class DSPProxyDataset(torch.utils.data.Dataset): + """Class for generating input-output audio from Python DSP effects. + + Args: + input_dir (List[str]): List of paths to the directories containing input audio files. + processor (Processor): Processor object to create proxy of. + processor_type (str): Processor name. + subset (str, optional): Dataset subset. One of ["train", "val", "test"]. Default: "train" + buffer_size_gb (float, optional): Size of audio to read into RAM in GB at any given time. Default: 10.0 + Note: This is the buffer size PER DataLoader worker. So total RAM = buffer_size_gb * num_workers + buffer_reload_rate (int, optional): Number of items to generate before loading next chunk of dataset. Default: 10000 + length (int, optional): Number of samples to load for each example. Default: 65536 + num_examples_per_epoch (int, optional): Define an epoch as certain number of audio examples. Default: 10000 + ext (str, optional): Expected audio file extension. Default: "wav" + hard_clip (bool, optional): Hard clip outputs between -1 and 1. Default: True + """ + + def __init__( + self, + input_dir: str, + processor: Processor, + processor_type: str, + subset="train", + length=65536, + buffer_size_gb=1.0, + buffer_reload_rate=1000, + half=False, + num_examples_per_epoch=10000, + ext="wav", + soft_clip=True, + ): + super().__init__() + self.input_dir = input_dir + self.processor = processor + self.processor_type = processor_type + self.subset = subset + self.length = length + self.buffer_size_gb = buffer_size_gb + self.buffer_reload_rate = buffer_reload_rate + self.half = half + self.num_examples_per_epoch = num_examples_per_epoch + self.ext = ext + self.soft_clip = soft_clip + + search_path = os.path.join(input_dir, f"*.{ext}") + self.input_filepaths = glob.glob(search_path) + self.input_filepaths = sorted(self.input_filepaths) + + if len(self.input_filepaths) < 1: + raise RuntimeError(f"No files found in {input_dir}.") + + # get training split + self.input_filepaths = utils.split_dataset( + self.input_filepaths, self.subset, 0.9 + ) + + # get details about audio files + cnt = 0 + self.input_files = {} + for input_filepath in tqdm(self.input_filepaths, ncols=80): + file_id = os.path.basename(input_filepath) + audio_file = AudioFile( + input_filepath, + preload=False, + half=half, + ) + if audio_file.num_frames < self.length: + continue + self.input_files[file_id] = audio_file + self.sample_rate = self.input_files[file_id].sample_rate + cnt += 1 + if cnt > 1000: + break + + # some setup for iteratble loading of the dataset into RAM + self.items_since_load = self.buffer_reload_rate + + def __len__(self): + return self.num_examples_per_epoch + + def load_audio_buffer(self): + self.input_files_loaded = {} # clear audio buffer + self.items_since_load = 0 # reset iteration counter + nbytes_loaded = 0 # counter for data in RAM + + # different subset in each + random.shuffle(self.input_filepaths) + + # load files into RAM + for input_filepath in self.input_filepaths: + file_id = os.path.basename(input_filepath) + audio_file = AudioFile( + input_filepath, + preload=True, + half=self.half, + ) + + if audio_file.num_frames < self.length: + continue + + self.input_files_loaded[file_id] = audio_file + + nbytes = audio_file.audio.element_size() * audio_file.audio.nelement() + nbytes_loaded += nbytes + + if nbytes_loaded > self.buffer_size_gb * 1e9: + break + + def __getitem__(self, _): + """ """ + + # increment counter + self.items_since_load += 1 + + # load next chunk into buffer if needed + if self.items_since_load > self.buffer_reload_rate: + self.load_audio_buffer() + + rand_input_file_id = utils.get_random_file_id(self.input_files_loaded.keys()) + # use this random key to retrieve an input file + input_file = self.input_files_loaded[rand_input_file_id] + + # load the audio data if needed + if not input_file.loaded: + input_file.load() + + # get a random patch of size `self.length` + # start_idx, stop_idx = utils.get_random_patch(input_file, self.sample_rate, self.length) + start_idx, stop_idx = utils.get_random_patch(input_file, self.length) + input_audio = input_file.audio[:, start_idx:stop_idx].clone().detach() + + # random scaling + input_audio /= input_audio.abs().max() + scale_dB = (torch.rand(1).squeeze().numpy() * 12) + 12 + input_audio *= 10 ** (-scale_dB / 20.0) + + # generate random parameters (uniform) over 0 to 1 + params = torch.rand(self.processor.num_control_params) + + # expects batch dim + # apply plugins with random parameters + if self.processor_type == "channel": + params[-1] = 0.5 # set makeup gain to 0dB + target_audio = self.processor( + input_audio.view(1, 1, -1), + params.view(1, -1), + ) + target_audio = target_audio.view(1, -1) + elif self.processor_type == "peq": + target_audio = self.processor( + input_audio.view(1, 1, -1).numpy(), + params.view(1, -1).numpy(), + ) + target_audio = torch.tensor(target_audio).view(1, -1) + elif self.processor_type == "comp": + params[-1] = 0.5 # set makeup gain to 0dB + target_audio = self.processor( + input_audio.view(1, 1, -1).numpy(), + params.view(1, -1).numpy(), + ) + target_audio = torch.tensor(target_audio).view(1, -1) + + # clip + if self.soft_clip: + # target_audio = target_audio.clamp(-2.0, 2.0) + target_audio = torch.tanh(target_audio / 2.0) * 2.0 + + return input_audio, target_audio, params diff --git a/deepafx_st/data/style.py b/deepafx_st/data/style.py new file mode 100644 index 0000000..d1a44cc --- /dev/null +++ b/deepafx_st/data/style.py @@ -0,0 +1,62 @@ +import os +import glob +import torch +import torchaudio +from tqdm import tqdm + + +class StyleDataset(torch.utils.data.Dataset): + def __init__( + self, + audio_dir: str, + subset: str = "train", + sample_rate: int = 24000, + length: int = 131072, + ) -> None: + super().__init__() + self.audio_dir = audio_dir + self.subset = subset + self.sample_rate = sample_rate + self.length = length + + self.style_dirs = glob.glob(os.path.join(audio_dir, subset, "*")) + self.style_dirs = [sd for sd in self.style_dirs if os.path.isdir(sd)] + self.num_classes = len(self.style_dirs) + self.class_labels = {"broadcast" : 0, "telephone": 1, "neutral": 2, "bright": 3, "warm": 4} + + self.examples = [] + for n, style_dir in enumerate(self.style_dirs): + + # get all files in style dir + style_filepaths = glob.glob(os.path.join(style_dir, "*.wav")) + style_name = os.path.basename(style_dir) + for style_filepath in tqdm(style_filepaths, ncols=120): + # load audio file + x, sr = torchaudio.load(style_filepath) + + # sum to mono if needed + if x.shape[0] > 1: + x = x.mean(dim=0, keepdim=True) + + # resample + if sr != self.sample_rate: + x = torchaudio.transforms.Resample(sr, self.sample_rate)(x) + + # crop length after resample + if x.shape[-1] >= self.length: + x = x[...,:self.length] + + # store example + example = (x, self.class_labels[style_name]) + self.examples.append(example) + + print(f"Loaded {len(self.examples)} examples for {subset} subset.") + + def __len__(self): + return len(self.examples) + + def __getitem__(self, idx): + example = self.examples[idx] + x = example[0] + y = example[1] + return x, y diff --git a/deepafx_st/metrics.py b/deepafx_st/metrics.py new file mode 100755 index 0000000..ca5ea20 --- /dev/null +++ b/deepafx_st/metrics.py @@ -0,0 +1,157 @@ +import torch +import auraloss +import resampy +import torchaudio +from pesq import pesq +import pyloudnorm as pyln + + +def crest_factor(x): + """Compute the crest factor of waveform.""" + + peak, _ = x.abs().max(dim=-1) + rms = torch.sqrt((x ** 2).mean(dim=-1)) + + return 20 * torch.log(peak / rms.clamp(1e-8)) + + +def rms_energy(x): + + rms = torch.sqrt((x ** 2).mean(dim=-1)) + + return 20 * torch.log(rms.clamp(1e-8)) + + +def spectral_centroid(x): + """Compute the crest factor of waveform. + + See: https://gist.github.com/endolith/359724 + + """ + + spectrum = torch.fft.rfft(x).abs() + normalized_spectrum = spectrum / spectrum.sum() + normalized_frequencies = torch.linspace(0, 1, spectrum.shape[-1]) + spectral_centroid = torch.sum(normalized_frequencies * normalized_spectrum) + + return spectral_centroid + + +def loudness(x, sample_rate): + """Compute the loudness in dB LUFS of waveform.""" + meter = pyln.Meter(sample_rate) + + # add stereo dim if needed + if x.shape[0] < 2: + x = x.repeat(2, 1) + + return torch.tensor(meter.integrated_loudness(x.permute(1, 0).numpy())) + + +class MelSpectralDistance(torch.nn.Module): + def __init__(self, sample_rate, length=65536): + super().__init__() + self.error = auraloss.freq.MelSTFTLoss( + sample_rate, + fft_size=length, + hop_size=length, + win_length=length, + w_sc=0, + w_log_mag=1, + w_lin_mag=1, + n_mels=128, + scale_invariance=False, + ) + + # I think scale invariance may not work well, + # since aspects of the phase may be considered? + + def forward(self, input, target): + return self.error(input, target) + + +class PESQ(torch.nn.Module): + def __init__(self, sample_rate): + super().__init__() + self.sample_rate = sample_rate + + def forward(self, input, target): + if self.sample_rate != 16000: + target = resampy.resample( + target.view(-1).numpy(), + self.sample_rate, + 16000, + ) + input = resampy.resample( + input.view(-1).numpy(), + self.sample_rate, + 16000, + ) + + return pesq( + 16000, + target, + input, + "wb", + ) + + +class CrestFactorError(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, target): + return torch.nn.functional.l1_loss( + crest_factor(input), + crest_factor(target), + ).item() + + +class RMSEnergyError(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, target): + return torch.nn.functional.l1_loss( + rms_energy(input), + rms_energy(target), + ).item() + + +class SpectralCentroidError(torch.nn.Module): + def __init__(self, sample_rate, n_fft=2048, hop_length=512): + super().__init__() + + self.spectral_centroid = torchaudio.transforms.SpectralCentroid( + sample_rate, + n_fft=n_fft, + hop_length=hop_length, + ) + + def forward(self, input, target): + return torch.nn.functional.l1_loss( + self.spectral_centroid(input + 1e-16).mean(), + self.spectral_centroid(target + 1e-16).mean(), + ).item() + + +class LoudnessError(torch.nn.Module): + def __init__(self, sample_rate: int, peak_normalize: bool = False): + super().__init__() + self.sample_rate = sample_rate + self.peak_normalize = peak_normalize + + def forward(self, input, target): + + if self.peak_normalize: + # peak normalize + x = input / input.abs().max() + y = target / target.abs().max() + else: + x = input + y = target + + return torch.nn.functional.l1_loss( + loudness(x.view(1, -1), self.sample_rate), + loudness(y.view(1, -1), self.sample_rate), + ).item() diff --git a/deepafx_st/models/baselines.py b/deepafx_st/models/baselines.py new file mode 100755 index 0000000..806caca --- /dev/null +++ b/deepafx_st/models/baselines.py @@ -0,0 +1,280 @@ +import torch +import torchaudio +import scipy.signal +import numpy as np +import pyloudnorm as pyln +import matplotlib.pyplot as plt +from deepafx_st.processors.dsp.compressor import compressor + +from tqdm import tqdm + + +class BaselineEQ(torch.nn.Module): + def __init__( + self, + ntaps: int = 63, + n_fft: int = 65536, + sample_rate: float = 44100, + ): + super().__init__() + self.ntaps = ntaps + self.n_fft = n_fft + self.sample_rate = sample_rate + + # compute the target spectrum + # print("Computing target spectrum...") + # self.target_spec, self.sm_target_spec = self.analyze_speech_dataset(filepaths) + # self.plot_spectrum(self.target_spec, filename="targetEQ") + # self.plot_spectrum(self.sm_target_spec, filename="targetEQsm") + + def forward(self, x, y): + + bs, ch, s = x.size() + + x = x.view(bs * ch, -1) + y = y.view(bs * ch, -1) + + in_spec = self.get_average_spectrum(x) + ref_spec = self.get_average_spectrum(y) + + sm_in_spec = self.smooth_spectrum(in_spec) + sm_ref_spec = self.smooth_spectrum(ref_spec) + + # self.plot_spectrum(in_spec, filename="inSpec") + # self.plot_spectrum(sm_in_spec, filename="inSpecsm") + + # design inverse FIR filter to match target EQ + freqs = np.linspace(0, 1.0, num=(self.n_fft // 2) + 1) + response = sm_ref_spec / sm_in_spec + response[-1] = 0.0 # zero gain at nyquist + + b = scipy.signal.firwin2( + self.ntaps, + freqs * (self.sample_rate / 2), + response, + fs=self.sample_rate, + ) + + # scale the coefficients for less intense filter + # clearb *= 0.5 + + # apply the filter + x_filt = scipy.signal.lfilter(b, [1.0], x.numpy()) + x_filt = torch.tensor(x_filt.astype("float32")) + + if False: + # plot the filter response + w, h = scipy.signal.freqz(b, fs=self.sample_rate, worN=response.shape[-1]) + + fig, ax1 = plt.subplots() + ax1.set_title("Digital filter frequency response") + ax1.plot(w, 20 * np.log10(abs(h + 1e-8))) + ax1.plot(w, 20 * np.log10(abs(response + 1e-8))) + + ax1.set_xscale("log") + ax1.set_ylim([-12, 12]) + plt.grid(c="lightgray") + plt.savefig(f"inverse.png") + + x_filt_avg_spec = self.get_average_spectrum(x_filt) + sm_x_filt_avg_spec = self.smooth_spectrum(x_filt_avg_spec) + y_avg_spec = self.get_average_spectrum(y) + sm_y_avg_spec = self.smooth_spectrum(y_avg_spec) + compare = torch.stack( + [ + torch.tensor(sm_in_spec), + torch.tensor(sm_x_filt_avg_spec), + torch.tensor(sm_ref_spec), + torch.tensor(sm_y_avg_spec), + ] + ) + self.plot_multi_spectrum( + compare, + legend=["in", "out", "target curve", "actual target"], + filename="outSpec", + ) + + return x_filt + + def analyze_speech_dataset(self, filepaths, peak=-3.0): + avg_spec = [] + for filepath in tqdm(filepaths, ncols=80): + x, sr = torchaudio.load(filepath) + x /= x.abs().max() + x *= 10 ** (peak / 20.0) + avg_spec.append(self.get_average_spectrum(x)) + avg_specs = torch.stack(avg_spec) + + avg_spec = avg_specs.mean(dim=0).numpy() + avg_spec_std = avg_specs.std(dim=0).numpy() + + # self.plot_multi_spectrum(avg_specs, filename="allTargetEQs") + # self.plot_spectrum_stats(avg_spec, avg_spec_std, filename="targetEQstats") + + sm_avg_spec = self.smooth_spectrum(avg_spec) + + return avg_spec, sm_avg_spec + + def smooth_spectrum(self, H): + # apply Savgol filter for smoothed target curve + return scipy.signal.savgol_filter(H, 1025, 2) + + def get_average_spectrum(self, x): + + # x = x[:, : self.n_fft] + X = torch.stft(x, self.n_fft, return_complex=True, normalized=True) + # fft_size = self.next_power_of_2(x.shape[-1]) + # X = torch.fft.rfft(x, n=fft_size) + + X = X.abs() # convert to magnitude + X = X.mean(dim=-1).view(-1) # average across frames + + return X + + @staticmethod + def next_power_of_2(x): + return 1 if x == 0 else int(2 ** np.ceil(np.log2(x))) + + def plot_multi_spectrum(self, Hs, legend=[], filename=None): + + bin_width = (self.sample_rate / 2) / (self.n_fft // 2) + freqs = np.arange(0, (self.sample_rate / 2) + bin_width, step=bin_width) + + fig, ax1 = plt.subplots() + + for H in Hs: + ax1.plot( + freqs, + 20 * np.log10(abs(H) + 1e-8), + ) + + plt.legend(legend) + + # avg_spec = Hs.mean(dim=0).numpy() + # ax1.plot(freqs, 20 * np.log10(avg_spec), color="k", linewidth=2) + + ax1.set_xscale("log") + ax1.set_ylim([-80, 0]) + plt.grid(c="lightgray") + + if filename is not None: + plt.savefig(f"{filename}.png") + + def plot_spectrum_stats(self, H_mean, H_std, filename=None): + bin_width = (self.sample_rate / 2) / (self.n_fft // 2) + freqs = np.arange(0, (self.sample_rate / 2) + bin_width, step=bin_width) + + fig, ax1 = plt.subplots() + ax1.plot(freqs, 20 * np.log10(H_mean)) + ax1.plot( + freqs, + (20 * np.log10(H_mean)) + (20 * np.log10(H_std)), + linestyle="--", + color="k", + ) + ax1.plot( + freqs, + (20 * np.log10(H_mean)) - (20 * np.log10(H_std)), + linestyle="--", + color="k", + ) + + ax1.set_xscale("log") + ax1.set_ylim([-80, 0]) + plt.grid(c="lightgray") + + if filename is not None: + plt.savefig(f"{filename}.png") + + def plot_spectrum(self, H, legend=[], filename=None): + + bin_width = (self.sample_rate / 2) / (self.n_fft // 2) + freqs = np.arange(0, (self.sample_rate / 2) + bin_width, step=bin_width) + + fig, ax1 = plt.subplots() + ax1.plot(freqs, 20 * np.log10(H)) + ax1.set_xscale("log") + ax1.set_ylim([-80, 0]) + plt.grid(c="lightgray") + + plt.legend(legend) + + if filename is not None: + plt.savefig(f"{filename}.png") + + +class BaslineComp(torch.nn.Module): + def __init__( + self, + sample_rate: float = 44100, + ): + super().__init__() + self.sample_rate = sample_rate + self.meter = pyln.Meter(sample_rate) + + def forward(self, x, y): + + x_lufs = self.meter.integrated_loudness(x.view(-1).numpy()) + y_lufs = self.meter.integrated_loudness(y.view(-1).numpy()) + + delta_lufs = y_lufs - x_lufs + + threshold = 0.0 + x_comp = x + x_comp_new = x + while delta_lufs > 0.5 and threshold > -80.0: + x_comp = x_comp_new # use the last setting + x_comp_new = compressor( + x.view(-1).numpy(), + self.sample_rate, + threshold=threshold, + ratio=3, + attack_time=0.001, + release_time=0.05, + knee_dB=6.0, + makeup_gain_dB=0.0, + ) + x_comp_new = torch.tensor(x_comp_new) + x_comp_new /= x_comp_new.abs().max() + x_comp_new *= 10 ** (-12.0 / 20) + x_lufs = self.meter.integrated_loudness(x_comp_new.view(-1).numpy()) + delta_lufs = y_lufs - x_lufs + threshold -= 0.5 + + return x_comp.view(1, 1, -1) + + +class BaselineEQAndComp(torch.nn.Module): + def __init__( + self, + ntaps=63, + n_fft=65536, + sample_rate=44100, + block_size=1024, + plugin_config=None, + ): + super().__init__() + self.eq = BaselineEQ(ntaps, n_fft, sample_rate) + self.comp = BaslineComp(sample_rate) + + def forward(self, x, y): + + with torch.inference_mode(): + x /= x.abs().max() + y /= y.abs().max() + x *= 10 ** (-12.0 / 20) + y *= 10 ** (-12.0 / 20) + + x = self.eq(x, y) + + x /= x.abs().max() + y /= y.abs().max() + x *= 10 ** (-12.0 / 20) + y *= 10 ** (-12.0 / 20) + + x = self.comp(x, y) + + x /= x.abs().max() + x *= 10 ** (-12.0 / 20) + + return x diff --git a/deepafx_st/models/controller.py b/deepafx_st/models/controller.py new file mode 100755 index 0000000..859e552 --- /dev/null +++ b/deepafx_st/models/controller.py @@ -0,0 +1,75 @@ +import torch + +class StyleTransferController(torch.nn.Module): + def __init__( + self, + num_control_params, + edim, + hidden_dim=256, + agg_method="mlp", + ): + """Plugin parameter controller module to map from input to target style. + + Args: + num_control_params (int): Number of plugin parameters to predicted. + edim (int): Size of the encoder representations. + hidden_dim (int, optional): Hidden size of the 3-layer parameter predictor MLP. Default: 256 + agg_method (str, optional): Input/reference embed aggregation method ["conv" or "linear", "mlp"]. Default: "mlp" + """ + super().__init__() + self.num_control_params = num_control_params + self.edim = edim + self.hidden_dim = hidden_dim + self.agg_method = agg_method + + if agg_method == "conv": + self.agg = torch.nn.Conv1d( + 2, + 1, + kernel_size=129, + stride=1, + padding="same", + bias=False, + ) + mlp_in_dim = edim + elif agg_method == "linear": + self.agg = torch.nn.Linear(edim * 2, edim) + elif agg_method == "mlp": + self.agg = None + mlp_in_dim = edim * 2 + else: + raise ValueError(f"Invalid agg_method = {self.agg_method}.") + + self.mlp = torch.nn.Sequential( + torch.nn.Linear(mlp_in_dim, hidden_dim), + torch.nn.LeakyReLU(0.01), + torch.nn.Linear(hidden_dim, hidden_dim), + torch.nn.LeakyReLU(0.01), + torch.nn.Linear(hidden_dim, num_control_params), + torch.nn.Sigmoid(), # normalize between 0 and 1 + ) + + def forward(self, e_x, e_y, z=None): + """Forward pass to generate plugin parameters. + + Args: + e_x (tensor): Input signal embedding of shape (batch, edim) + e_y (tensor): Target signal embedding of shape (batch, edim) + Returns: + p (tensor): Estimated control parameters of shape (batch, num_control_params) + """ + + # use learnable projection + if self.agg_method == "conv": + e_xy = torch.stack((e_x, e_y), dim=1) # concat on channel dim + e_xy = self.agg(e_xy) + elif self.agg_method == "linear": + e_xy = torch.cat((e_x, e_y), dim=-1) # concat on embed dim + e_xy = self.agg(e_xy) + else: + e_xy = torch.cat((e_x, e_y), dim=-1) # concat on embed dim + + # pass through MLP to project to control parametesr + p = self.mlp(e_xy.squeeze(1)) + + return p diff --git a/deepafx_st/models/efficient_net/LICENSE b/deepafx_st/models/efficient_net/LICENSE new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/deepafx_st/models/efficient_net/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/deepafx_st/models/efficient_net/__init__.py b/deepafx_st/models/efficient_net/__init__.py new file mode 100644 index 0000000..2b529df --- /dev/null +++ b/deepafx_st/models/efficient_net/__init__.py @@ -0,0 +1,9 @@ +__version__ = "0.7.1" +from .model import EfficientNet, VALID_MODELS +from .utils import ( + GlobalParams, + BlockArgs, + BlockDecoder, + efficientnet, + get_model_params, +) diff --git a/deepafx_st/models/efficient_net/model.py b/deepafx_st/models/efficient_net/model.py new file mode 100755 index 0000000..ce850cd --- /dev/null +++ b/deepafx_st/models/efficient_net/model.py @@ -0,0 +1,419 @@ +"""model.py - Model and module class for EfficientNet. + They are built to mirror those in the official TensorFlow implementation. +""" + +# Author: lukemelas (github username) +# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch +# With adjustments and added comments by workingcoder (github username). + +import torch +from torch import nn +from torch.nn import functional as F +from .utils import ( + round_filters, + round_repeats, + drop_connect, + get_same_padding_conv2d, + get_model_params, + efficientnet_params, + load_pretrained_weights, + Swish, + MemoryEfficientSwish, + calculate_output_image_size +) + + +VALID_MODELS = ( + 'efficientnet-b0', 'efficientnet-b1', 'efficientnet-b2', 'efficientnet-b3', + 'efficientnet-b4', 'efficientnet-b5', 'efficientnet-b6', 'efficientnet-b7', + 'efficientnet-b8', + + # Support the construction of 'efficientnet-l2' without pretrained weights + 'efficientnet-l2' +) + + +class MBConvBlock(nn.Module): + """Mobile Inverted Residual Bottleneck Block. + + Args: + block_args (namedtuple): BlockArgs, defined in utils.py. + global_params (namedtuple): GlobalParam, defined in utils.py. + image_size (tuple or list): [image_height, image_width]. + + References: + [1] https://arxiv.org/abs/1704.04861 (MobileNet v1) + [2] https://arxiv.org/abs/1801.04381 (MobileNet v2) + [3] https://arxiv.org/abs/1905.02244 (MobileNet v3) + """ + + def __init__(self, block_args, global_params, image_size=None): + super().__init__() + self._block_args = block_args + self._bn_mom = 1 - global_params.batch_norm_momentum # pytorch's difference from tensorflow + self._bn_eps = global_params.batch_norm_epsilon + self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) + self.id_skip = block_args.id_skip # whether to use skip connection and drop connect + + # Expansion phase (Inverted Bottleneck) + inp = self._block_args.input_filters # number of input channels + oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels + if self._block_args.expand_ratio != 1: + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) + self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + # image_size = calculate_output_image_size(image_size, 1) <-- this wouldn't modify image_size + + # Depthwise convolution phase + k = self._block_args.kernel_size + s = self._block_args.stride + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._depthwise_conv = Conv2d( + in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise + kernel_size=k, stride=s, bias=False) + self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) + image_size = calculate_output_image_size(image_size, s) + + # Squeeze and Excitation layer, if desired + if self.has_se: + Conv2d = get_same_padding_conv2d(image_size=(1, 1)) + num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) + self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) + self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) + + # Pointwise convolution phase + final_oup = self._block_args.output_filters + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) + self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) + self._swish = MemoryEfficientSwish() + + def forward(self, inputs, drop_connect_rate=None): + """MBConvBlock's forward function. + + Args: + inputs (tensor): Input tensor. + drop_connect_rate (bool): Drop connect rate (float, between 0 and 1). + + Returns: + Output of this block after processing. + """ + + # Expansion and Depthwise Convolution + x = inputs + if self._block_args.expand_ratio != 1: + x = self._expand_conv(inputs) + x = self._bn0(x) + x = self._swish(x) + + x = self._depthwise_conv(x) + x = self._bn1(x) + x = self._swish(x) + + # Squeeze and Excitation + if self.has_se: + x_squeezed = F.adaptive_avg_pool2d(x, 1) + x_squeezed = self._se_reduce(x_squeezed) + x_squeezed = self._swish(x_squeezed) + x_squeezed = self._se_expand(x_squeezed) + x = torch.sigmoid(x_squeezed) * x + + # Pointwise Convolution + x = self._project_conv(x) + x = self._bn2(x) + + # Skip connection and drop connect + input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters + if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: + # The combination of skip connection and drop connect brings about stochastic depth. + if drop_connect_rate: + x = drop_connect(x, p=drop_connect_rate, training=self.training) + x = x + inputs # skip connection + return x + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export). + + Args: + memory_efficient (bool): Whether to use memory-efficient version of swish. + """ + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + + +class EfficientNet(nn.Module): + """EfficientNet model. + Most easily loaded with the .from_name or .from_pretrained methods. + + Args: + blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks. + global_params (namedtuple): A set of GlobalParams shared between blocks. + + References: + [1] https://arxiv.org/abs/1905.11946 (EfficientNet) + + Example: + >>> import torch + >>> from efficientnet.model import EfficientNet + >>> inputs = torch.rand(1, 3, 224, 224) + >>> model = EfficientNet.from_pretrained('efficientnet-b0') + >>> model.eval() + >>> outputs = model(inputs) + """ + + def __init__(self, blocks_args=None, global_params=None): + super().__init__() + assert isinstance(blocks_args, list), 'blocks_args should be a list' + assert len(blocks_args) > 0, 'block args must be greater than 0' + self._global_params = global_params + self._blocks_args = blocks_args + + # Batch norm parameters + bn_mom = 1 - self._global_params.batch_norm_momentum + bn_eps = self._global_params.batch_norm_epsilon + + # Get stem static or dynamic convolution depending on image size + image_size = global_params.image_size + Conv2d = get_same_padding_conv2d(image_size=image_size) + + # Stem + in_channels = 3 # rgb + out_channels = round_filters(32, self._global_params) # number of output channels + self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) + self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + image_size = calculate_output_image_size(image_size, 2) + + # Build blocks + self._blocks = nn.ModuleList([]) + for block_args in self._blocks_args: + + # Update block input and output filters based on depth multiplier. + block_args = block_args._replace( + input_filters=round_filters(block_args.input_filters, self._global_params), + output_filters=round_filters(block_args.output_filters, self._global_params), + num_repeat=round_repeats(block_args.num_repeat, self._global_params) + ) + + # The first block needs to take care of stride and filter size increase. + self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) + image_size = calculate_output_image_size(image_size, block_args.stride) + if block_args.num_repeat > 1: # modify block_args to keep same output size + block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) + for _ in range(block_args.num_repeat - 1): + self._blocks.append(MBConvBlock(block_args, self._global_params, image_size=image_size)) + # image_size = calculate_output_image_size(image_size, block_args.stride) # stride = 1 + + # Head + in_channels = block_args.output_filters # output of final block + out_channels = round_filters(1280, self._global_params) + Conv2d = get_same_padding_conv2d(image_size=image_size) + self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) + + # Final linear layer + self._avg_pooling = nn.AdaptiveAvgPool2d(1) + if self._global_params.include_top: + self._dropout = nn.Dropout(self._global_params.dropout_rate) + self._fc = nn.Linear(out_channels, self._global_params.num_classes) + + # set activation to memory efficient swish by default + self._swish = MemoryEfficientSwish() + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export). + + Args: + memory_efficient (bool): Whether to use memory-efficient version of swish. + """ + self._swish = MemoryEfficientSwish() if memory_efficient else Swish() + for block in self._blocks: + block.set_swish(memory_efficient) + + def extract_endpoints(self, inputs): + """Use convolution layer to extract features + from reduction levels i in [1, 2, 3, 4, 5]. + + Args: + inputs (tensor): Input tensor. + + Returns: + Dictionary of last intermediate features + with reduction levels i in [1, 2, 3, 4, 5]. + Example: + >>> import torch + >>> from efficientnet.model import EfficientNet + >>> inputs = torch.rand(1, 3, 224, 224) + >>> model = EfficientNet.from_pretrained('efficientnet-b0') + >>> endpoints = model.extract_endpoints(inputs) + >>> print(endpoints['reduction_1'].shape) # torch.Size([1, 16, 112, 112]) + >>> print(endpoints['reduction_2'].shape) # torch.Size([1, 24, 56, 56]) + >>> print(endpoints['reduction_3'].shape) # torch.Size([1, 40, 28, 28]) + >>> print(endpoints['reduction_4'].shape) # torch.Size([1, 112, 14, 14]) + >>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7]) + >>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7]) + """ + endpoints = dict() + + # Stem + x = self._swish(self._bn0(self._conv_stem(inputs))) + prev_x = x + + # Blocks + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate + x = block(x, drop_connect_rate=drop_connect_rate) + if prev_x.size(2) > x.size(2): + endpoints['reduction_{}'.format(len(endpoints) + 1)] = prev_x + elif idx == len(self._blocks) - 1: + endpoints['reduction_{}'.format(len(endpoints) + 1)] = x + prev_x = x + + # Head + x = self._swish(self._bn1(self._conv_head(x))) + endpoints['reduction_{}'.format(len(endpoints) + 1)] = x + + return endpoints + + def extract_features(self, inputs): + """use convolution layer to extract feature . + + Args: + inputs (tensor): Input tensor. + + Returns: + Output of the final convolution + layer in the efficientnet model. + """ + # Stem + x = self._swish(self._bn0(self._conv_stem(inputs))) + + # Blocks + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self._blocks) # scale drop connect_rate + x = block(x, drop_connect_rate=drop_connect_rate) + + # Head + x = self._swish(self._bn1(self._conv_head(x))) + + return x + + def forward(self, inputs): + """EfficientNet's forward function. + Calls extract_features to extract features, applies final linear layer, and returns logits. + + Args: + inputs (tensor): Input tensor. + + Returns: + Output of this model after processing. + """ + # Convolution layers + x = self.extract_features(inputs) + # Pooling and final linear layer + x = self._avg_pooling(x) + if self._global_params.include_top: + x = x.flatten(start_dim=1) + x = self._dropout(x) + x = self._fc(x) + return x + + @classmethod + def from_name(cls, model_name, in_channels=3, **override_params): + """Create an efficientnet model according to name. + + Args: + model_name (str): Name for efficientnet. + in_channels (int): Input data's channel number. + override_params (other key word params): + Params to override model's global_params. + Optional key: + 'width_coefficient', 'depth_coefficient', + 'image_size', 'dropout_rate', + 'num_classes', 'batch_norm_momentum', + 'batch_norm_epsilon', 'drop_connect_rate', + 'depth_divisor', 'min_depth' + + Returns: + An efficientnet model. + """ + cls._check_model_name_is_valid(model_name) + blocks_args, global_params = get_model_params(model_name, override_params) + model = cls(blocks_args, global_params) + model._change_in_channels(in_channels) + return model + + @classmethod + def from_pretrained(cls, model_name, weights_path=None, advprop=False, + in_channels=3, num_classes=1000, **override_params): + """Create an efficientnet model according to name. + + Args: + model_name (str): Name for efficientnet. + weights_path (None or str): + str: path to pretrained weights file on the local disk. + None: use pretrained weights downloaded from the Internet. + advprop (bool): + Whether to load pretrained weights + trained with advprop (valid when weights_path is None). + in_channels (int): Input data's channel number. + num_classes (int): + Number of categories for classification. + It controls the output size for final linear layer. + override_params (other key word params): + Params to override model's global_params. + Optional key: + 'width_coefficient', 'depth_coefficient', + 'image_size', 'dropout_rate', + 'batch_norm_momentum', + 'batch_norm_epsilon', 'drop_connect_rate', + 'depth_divisor', 'min_depth' + + Returns: + A pretrained efficientnet model. + """ + model = cls.from_name(model_name, num_classes=num_classes, **override_params) + load_pretrained_weights(model, model_name, weights_path=weights_path, + load_fc=(num_classes == 1000), advprop=advprop) + model._change_in_channels(in_channels) + return model + + @classmethod + def get_image_size(cls, model_name): + """Get the input image size for a given efficientnet model. + + Args: + model_name (str): Name for efficientnet. + + Returns: + Input image size (resolution). + """ + cls._check_model_name_is_valid(model_name) + _, _, res, _ = efficientnet_params(model_name) + return res + + @classmethod + def _check_model_name_is_valid(cls, model_name): + """Validates model name. + + Args: + model_name (str): Name for efficientnet. + + Returns: + bool: Is a valid name or not. + """ + if model_name not in VALID_MODELS: + raise ValueError('model_name should be one of: ' + ', '.join(VALID_MODELS)) + + def _change_in_channels(self, in_channels): + """Adjust model's first convolution layer to in_channels, if in_channels not equals 3. + + Args: + in_channels (int): Input data's channel number. + """ + if in_channels != 3: + Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size) + out_channels = round_filters(32, self._global_params) + self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) diff --git a/deepafx_st/models/efficient_net/utils.py b/deepafx_st/models/efficient_net/utils.py new file mode 100755 index 0000000..826a627 --- /dev/null +++ b/deepafx_st/models/efficient_net/utils.py @@ -0,0 +1,616 @@ +"""utils.py - Helper functions for building the model and for loading model parameters. + These helper functions are built to mirror those in the official TensorFlow implementation. +""" + +# Author: lukemelas (github username) +# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch +# With adjustments and added comments by workingcoder (github username). + +import re +import math +import collections +from functools import partial +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils import model_zoo + + +################################################################################ +# Help functions for model architecture +################################################################################ + +# GlobalParams and BlockArgs: Two namedtuples +# Swish and MemoryEfficientSwish: Two implementations of the method +# round_filters and round_repeats: +# Functions to calculate params for scaling model width and depth ! ! ! +# get_width_and_height_from_size and calculate_output_image_size +# drop_connect: A structural design +# get_same_padding_conv2d: +# Conv2dDynamicSamePadding +# Conv2dStaticSamePadding +# get_same_padding_maxPool2d: +# MaxPool2dDynamicSamePadding +# MaxPool2dStaticSamePadding +# It's an additional function, not used in EfficientNet, +# but can be used in other model (such as EfficientDet). + +# Parameters for the entire model (stem, all blocks, and head) +GlobalParams = collections.namedtuple('GlobalParams', [ + 'width_coefficient', 'depth_coefficient', 'image_size', 'dropout_rate', + 'num_classes', 'batch_norm_momentum', 'batch_norm_epsilon', + 'drop_connect_rate', 'depth_divisor', 'min_depth', 'include_top']) + +# Parameters for an individual model block +BlockArgs = collections.namedtuple('BlockArgs', [ + 'num_repeat', 'kernel_size', 'stride', 'expand_ratio', + 'input_filters', 'output_filters', 'se_ratio', 'id_skip']) + +# Set GlobalParams and BlockArgs's defaults +GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) +BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) + +# Swish activation function +if hasattr(nn, 'SiLU'): + Swish = nn.SiLU +else: + # For compatibility with old PyTorch versions + class Swish(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +# A memory-efficient implementation of Swish function +class SwishImplementation(torch.autograd.Function): + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_tensors[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class MemoryEfficientSwish(nn.Module): + def forward(self, x): + return SwishImplementation.apply(x) + + +def round_filters(filters, global_params): + """Calculate and round number of filters based on width multiplier. + Use width_coefficient, depth_divisor and min_depth of global_params. + + Args: + filters (int): Filters number to be calculated. + global_params (namedtuple): Global params of the model. + + Returns: + new_filters: New filters number after calculating. + """ + multiplier = global_params.width_coefficient + if not multiplier: + return filters + # TODO: modify the params names. + # maybe the names (width_divisor,min_width) + # are more suitable than (depth_divisor,min_depth). + divisor = global_params.depth_divisor + min_depth = global_params.min_depth + filters *= multiplier + min_depth = min_depth or divisor # pay attention to this line when using min_depth + # follow the formula transferred from official TensorFlow implementation + new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) + if new_filters < 0.9 * filters: # prevent rounding by more than 10% + new_filters += divisor + return int(new_filters) + + +def round_repeats(repeats, global_params): + """Calculate module's repeat number of a block based on depth multiplier. + Use depth_coefficient of global_params. + + Args: + repeats (int): num_repeat to be calculated. + global_params (namedtuple): Global params of the model. + + Returns: + new repeat: New repeat number after calculating. + """ + multiplier = global_params.depth_coefficient + if not multiplier: + return repeats + # follow the formula transferred from official TensorFlow implementation + return int(math.ceil(multiplier * repeats)) + + +def drop_connect(inputs, p, training): + """Drop connect. + + Args: + input (tensor: BCWH): Input of this structure. + p (float: 0.0~1.0): Probability of drop connection. + training (bool): The running mode. + + Returns: + output: Output after drop connection. + """ + assert 0 <= p <= 1, 'p must be in range of [0,1]' + + if not training: + return inputs + + batch_size = inputs.shape[0] + keep_prob = 1 - p + + # generate binary_tensor mask according to probability (p for 0, 1-p for 1) + random_tensor = keep_prob + random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) + binary_tensor = torch.floor(random_tensor) + + output = inputs / keep_prob * binary_tensor + return output + + +def get_width_and_height_from_size(x): + """Obtain height and width from x. + + Args: + x (int, tuple or list): Data size. + + Returns: + size: A tuple or list (H,W). + """ + if isinstance(x, int): + return x, x + if isinstance(x, list) or isinstance(x, tuple): + return x + else: + raise TypeError() + + +def calculate_output_image_size(input_image_size, stride): + """Calculates the output image size when using Conv2dSamePadding with a stride. + Necessary for static padding. Thanks to mannatsingh for pointing this out. + + Args: + input_image_size (int, tuple or list): Size of input image. + stride (int, tuple or list): Conv2d operation's stride. + + Returns: + output_image_size: A list [H,W]. + """ + if input_image_size is None: + return None + image_height, image_width = get_width_and_height_from_size(input_image_size) + stride = stride if isinstance(stride, int) else stride[0] + image_height = int(math.ceil(image_height / stride)) + image_width = int(math.ceil(image_width / stride)) + return [image_height, image_width] + + +# Note: +# The following 'SamePadding' functions make output size equal ceil(input size/stride). +# Only when stride equals 1, can the output size be the same as input size. +# Don't be confused by their function names ! ! ! + +def get_same_padding_conv2d(image_size=None): + """Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. + + Args: + image_size (int or tuple): Size of the image. + + Returns: + Conv2dDynamicSamePadding or Conv2dStaticSamePadding. + """ + if image_size is None: + return Conv2dDynamicSamePadding + else: + return partial(Conv2dStaticSamePadding, image_size=image_size) + + +class Conv2dDynamicSamePadding(nn.Conv2d): + """2D Convolutions like TensorFlow, for a dynamic image size. + The padding is operated in forward function by calculating dynamically. + """ + + # Tips for 'SAME' mode padding. + # Given the following: + # i: width or height + # s: stride + # k: kernel size + # d: dilation + # p: padding + # Output after Conv2d: + # o = floor((i+p-((k-1)*d+1))/s+1) + # If o equals i, i = floor((i+p-((k-1)*d+1))/s+1), + # => p = (i-1)*s+((k-1)*d+1)-i + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): + super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) # change the output size according to stride ! ! ! + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class Conv2dStaticSamePadding(nn.Conv2d): + """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size. + The padding mudule is calculated in construction function, then used in forward. + """ + + # With the same calculation as Conv2dDynamicSamePadding + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, image_size=None, **kwargs): + super().__init__(in_channels, out_channels, kernel_size, stride, **kwargs) + self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 + + # Calculate padding based on image size and save it + assert image_size is not None + ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size + kh, kw = self.weight.size()[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, + pad_h // 2, pad_h - pad_h // 2)) + else: + self.static_padding = nn.Identity() + + def forward(self, x): + x = self.static_padding(x) + x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return x + + +def get_same_padding_maxPool2d(image_size=None): + """Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. + + Args: + image_size (int or tuple): Size of the image. + + Returns: + MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding. + """ + if image_size is None: + return MaxPool2dDynamicSamePadding + else: + return partial(MaxPool2dStaticSamePadding, image_size=image_size) + + +class MaxPool2dDynamicSamePadding(nn.MaxPool2d): + """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size. + The padding is operated in forward function by calculating dynamically. + """ + + def __init__(self, kernel_size, stride, padding=0, dilation=1, return_indices=False, ceil_mode=False): + super().__init__(kernel_size, stride, padding, dilation, return_indices, ceil_mode) + self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride + self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size + self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation + + def forward(self, x): + ih, iw = x.size()[-2:] + kh, kw = self.kernel_size + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.max_pool2d(x, self.kernel_size, self.stride, self.padding, + self.dilation, self.ceil_mode, self.return_indices) + + +class MaxPool2dStaticSamePadding(nn.MaxPool2d): + """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size. + The padding mudule is calculated in construction function, then used in forward. + """ + + def __init__(self, kernel_size, stride, image_size=None, **kwargs): + super().__init__(kernel_size, stride, **kwargs) + self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride + self.kernel_size = [self.kernel_size] * 2 if isinstance(self.kernel_size, int) else self.kernel_size + self.dilation = [self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation + + # Calculate padding based on image size and save it + assert image_size is not None + ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size + kh, kw = self.kernel_size + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) + else: + self.static_padding = nn.Identity() + + def forward(self, x): + x = self.static_padding(x) + x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding, + self.dilation, self.ceil_mode, self.return_indices) + return x + + +################################################################################ +# Helper functions for loading model params +################################################################################ + +# BlockDecoder: A Class for encoding and decoding BlockArgs +# efficientnet_params: A function to query compound coefficient +# get_model_params and efficientnet: +# Functions to get BlockArgs and GlobalParams for efficientnet +# url_map and url_map_advprop: Dicts of url_map for pretrained weights +# load_pretrained_weights: A function to load pretrained weights + +class BlockDecoder(object): + """Block Decoder for readability, + straight from the official TensorFlow repository. + """ + + @staticmethod + def _decode_block_string(block_string): + """Get a block through a string notation of arguments. + + Args: + block_string (str): A string notation of arguments. + Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'. + + Returns: + BlockArgs: The namedtuple defined at the top of this file. + """ + assert isinstance(block_string, str) + + ops = block_string.split('_') + options = {} + for op in ops: + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # Check stride + assert (('s' in options and len(options['s']) == 1) or + (len(options['s']) == 2 and options['s'][0] == options['s'][1])) + + return BlockArgs( + num_repeat=int(options['r']), + kernel_size=int(options['k']), + stride=[int(options['s'][0])], + expand_ratio=int(options['e']), + input_filters=int(options['i']), + output_filters=int(options['o']), + se_ratio=float(options['se']) if 'se' in options else None, + id_skip=('noskip' not in block_string)) + + @staticmethod + def _encode_block_string(block): + """Encode a block to a string. + + Args: + block (namedtuple): A BlockArgs type argument. + + Returns: + block_string: A String form of BlockArgs. + """ + args = [ + 'r%d' % block.num_repeat, + 'k%d' % block.kernel_size, + 's%d%d' % (block.strides[0], block.strides[1]), + 'e%s' % block.expand_ratio, + 'i%d' % block.input_filters, + 'o%d' % block.output_filters + ] + if 0 < block.se_ratio <= 1: + args.append('se%s' % block.se_ratio) + if block.id_skip is False: + args.append('noskip') + return '_'.join(args) + + @staticmethod + def decode(string_list): + """Decode a list of string notations to specify blocks inside the network. + + Args: + string_list (list[str]): A list of strings, each string is a notation of block. + + Returns: + blocks_args: A list of BlockArgs namedtuples of block args. + """ + assert isinstance(string_list, list) + blocks_args = [] + for block_string in string_list: + blocks_args.append(BlockDecoder._decode_block_string(block_string)) + return blocks_args + + @staticmethod + def encode(blocks_args): + """Encode a list of BlockArgs to a list of strings. + + Args: + blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args. + + Returns: + block_strings: A list of strings, each string is a notation of block. + """ + block_strings = [] + for block in blocks_args: + block_strings.append(BlockDecoder._encode_block_string(block)) + return block_strings + + +def efficientnet_params(model_name): + """Map EfficientNet model name to parameter coefficients. + + Args: + model_name (str): Model name to be queried. + + Returns: + params_dict[model_name]: A (width,depth,res,dropout) tuple. + """ + params_dict = { + # Coefficients: width,depth,res,dropout + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), + } + return params_dict[model_name] + + +def efficientnet(width_coefficient=None, depth_coefficient=None, image_size=None, + dropout_rate=0.2, drop_connect_rate=0.2, num_classes=1000, include_top=True): + """Create BlockArgs and GlobalParams for efficientnet model. + + Args: + width_coefficient (float) + depth_coefficient (float) + image_size (int) + dropout_rate (float) + drop_connect_rate (float) + num_classes (int) + + Meaning as the name suggests. + + Returns: + blocks_args, global_params. + """ + + # Blocks args for the whole model(efficientnet-b0 by default) + # It will be modified in the construction of EfficientNet Class according to model + blocks_args = [ + 'r1_k3_s11_e1_i32_o16_se0.25', + 'r2_k3_s22_e6_i16_o24_se0.25', + 'r2_k5_s22_e6_i24_o40_se0.25', + 'r3_k3_s22_e6_i40_o80_se0.25', + 'r3_k5_s11_e6_i80_o112_se0.25', + 'r4_k5_s22_e6_i112_o192_se0.25', + 'r1_k3_s11_e6_i192_o320_se0.25', + ] + blocks_args = BlockDecoder.decode(blocks_args) + + global_params = GlobalParams( + width_coefficient=width_coefficient, + depth_coefficient=depth_coefficient, + image_size=image_size, + dropout_rate=dropout_rate, + + num_classes=num_classes, + batch_norm_momentum=0.99, + batch_norm_epsilon=1e-3, + drop_connect_rate=drop_connect_rate, + depth_divisor=8, + min_depth=None, + include_top=include_top, + ) + + return blocks_args, global_params + + +def get_model_params(model_name, override_params): + """Get the block args and global params for a given model name. + + Args: + model_name (str): Model's name. + override_params (dict): A dict to modify global_params. + + Returns: + blocks_args, global_params + """ + if model_name.startswith('efficientnet'): + w, d, s, p = efficientnet_params(model_name) + # note: all models have drop connect rate = 0.2 + blocks_args, global_params = efficientnet( + width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s) + else: + raise NotImplementedError('model name is not pre-defined: {}'.format(model_name)) + if override_params: + # ValueError will be raised here if override_params has fields not included in global_params. + global_params = global_params._replace(**override_params) + return blocks_args, global_params + + +# train with Standard methods +# check more details in paper(EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks) +url_map = { + 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth', + 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth', + 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth', + 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth', + 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth', + 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth', + 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth', + 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth', +} + +# train with Adversarial Examples(AdvProp) +# check more details in paper(Adversarial Examples Improve Image Recognition) +url_map_advprop = { + 'efficientnet-b0': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b0-b64d5a18.pth', + 'efficientnet-b1': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b1-0f3ce85a.pth', + 'efficientnet-b2': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b2-6e9d97e5.pth', + 'efficientnet-b3': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b3-cdd7c0f4.pth', + 'efficientnet-b4': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b4-44fb3a87.pth', + 'efficientnet-b5': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b5-86493f6b.pth', + 'efficientnet-b6': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b6-ac80338e.pth', + 'efficientnet-b7': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b7-4652b6dd.pth', + 'efficientnet-b8': 'https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/adv-efficientnet-b8-22a8fe65.pth', +} + +# TODO: add the petrained weights url map of 'efficientnet-l2' + + +def load_pretrained_weights(model, model_name, weights_path=None, load_fc=True, advprop=False, verbose=True): + """Loads pretrained weights from weights path or download using url. + + Args: + model (Module): The whole model of efficientnet. + model_name (str): Model name of efficientnet. + weights_path (None or str): + str: path to pretrained weights file on the local disk. + None: use pretrained weights downloaded from the Internet. + load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model. + advprop (bool): Whether to load pretrained weights + trained with advprop (valid when weights_path is None). + """ + if isinstance(weights_path, str): + state_dict = torch.load(weights_path) + else: + # AutoAugment or Advprop (different preprocessing) + url_map_ = url_map_advprop if advprop else url_map + state_dict = model_zoo.load_url(url_map_[model_name]) + + if load_fc: + ret = model.load_state_dict(state_dict, strict=False) + assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) + else: + state_dict.pop('_fc.weight') + state_dict.pop('_fc.bias') + ret = model.load_state_dict(state_dict, strict=False) + assert set(ret.missing_keys) == set( + ['_fc.weight', '_fc.bias']), 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys) + assert not ret.unexpected_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.unexpected_keys) + + if verbose: + print('Loaded pretrained weights for {}'.format(model_name)) diff --git a/deepafx_st/models/encoder.py b/deepafx_st/models/encoder.py new file mode 100755 index 0000000..9c7bc63 --- /dev/null +++ b/deepafx_st/models/encoder.py @@ -0,0 +1,113 @@ +import torch + +from deepafx_st.models.mobilenetv2 import MobileNetV2 +from deepafx_st.models.efficient_net import EfficientNet + + +class SpectralEncoder(torch.nn.Module): + def __init__( + self, + num_params, + sample_rate, + encoder_model="mobilenet_v2", + embed_dim=1028, + width_mult=1, + min_level_db=-80, + ): + """Encoder operating on spectrograms. + + Args: + num_params (int): Number of processor parameters to generate. + sample_rate (float): Audio sample rate for computing melspectrogram. + encoder_model (str, optional): Encoder model architecture. Default: "mobilenet_v2" + embed_dim (int, optional): Dimentionality of the encoder representations. + width_mult (int, optional): Encoder size. Default: 1 + min_level_db (float, optional): Minimal dB value for the spectrogram. Default: -80 + """ + super().__init__() + self.num_params = num_params + self.sample_rate = sample_rate + self.encoder_model = encoder_model + self.embed_dim = embed_dim + self.width_mult = width_mult + self.min_level_db = min_level_db + + # load model from torch.hub + if encoder_model == "mobilenet_v2": + self.encoder = MobileNetV2(embed_dim=embed_dim, width_mult=width_mult) + elif encoder_model == "efficient_net": + self.encoder = EfficientNet.from_name( + "efficientnet-b2", + in_channels=1, + image_size=(128, 65), + include_top=False, + ) + self.embedding_projection = torch.nn.Conv2d( + in_channels=1408, + out_channels=embed_dim, + kernel_size=(1, 1), + stride=(1, 1), + padding=(0, 0), + bias=True, + ) + + else: + raise ValueError(f"Invalid encoder_model: {encoder_model}.") + + self.window = torch.nn.Parameter(torch.hann_window(4096)) + + def forward(self, x): + """ + Args: + x (Tensor): Input waveform of shape [batch x channels x samples] + + Returns: + e (Tensor): Latent embedding produced by Encoder. [batch x embed_dim] + """ + bs, chs, samp = x.size() + + # compute spectrogram of waveform + X = torch.stft( + x.view(bs, -1), + 4096, + 2048, + window=self.window, + return_complex=True, + ) + X_db = torch.pow(X.abs() + 1e-8, 0.3) + X_db_norm = X_db + + # standardize (0, 1) 0.322970 0.278452 + X_db_norm -= 0.322970 + X_db_norm /= 0.278452 + X_db_norm = X_db_norm.unsqueeze(1).permute(0, 1, 3, 2) + + if self.encoder_model == "mobilenet_v2": + # repeat channels by 3 to fit vision model + X_db_norm = X_db_norm.repeat(1, 3, 1, 1) + + # pass melspectrogram through encoder + e = self.encoder(X_db_norm) + + # apply avg pooling across time for encoder embeddings + e = torch.nn.functional.adaptive_avg_pool2d(e, 1).reshape(e.shape[0], -1) + + # normalize by L2 norm + norm = torch.norm(e, p=2, dim=-1, keepdim=True) + e_norm = e / norm + + elif self.encoder_model == "efficient_net": + + # Efficient Net internal downsamples by 32 on time and freq axis, then average pools the rest + e = self.encoder(X_db_norm) + + # Adding 1x1 conv to project down or up to the requested embedding size + e = self.embedding_projection(e) + e = torch.squeeze(e, dim=3) + e = torch.squeeze(e, dim=2) + + # normalize by L2 norm + norm = torch.norm(e, p=2, dim=-1, keepdim=True) + e_norm = e / norm + + return e_norm diff --git a/deepafx_st/models/mobilenetv2.py b/deepafx_st/models/mobilenetv2.py new file mode 100755 index 0000000..20e5c56 --- /dev/null +++ b/deepafx_st/models/mobilenetv2.py @@ -0,0 +1,226 @@ +# BSD 3-Clause License + +# Copyright (c) Soumith Chintala 2016, +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Adaptation of the PyTorch torchvision MobileNetV2 without a classifier. +# See source here: https://pytorch.org/vision/0.8/_modules/torchvision/models/mobilenet.html#mobilenet_v2 +from torch import nn + + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__( + self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None + ): + padding = (kernel_size - 1) // 2 + if norm_layer is None: + norm_layer = nn.BatchNorm2d + super(ConvBNReLU, self).__init__( + nn.Conv2d( + in_planes, + out_planes, + kernel_size, + stride, + padding, + groups=groups, + bias=False, + ), + norm_layer(out_planes), + nn.ReLU6(inplace=True), + ) + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append( + ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer) + ) + layers.extend( + [ + # dw + ConvBNReLU( + hidden_dim, + hidden_dim, + stride=stride, + groups=hidden_dim, + norm_layer=norm_layer, + ), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + norm_layer(oup), + ] + ) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__( + self, + embed_dim=1028, + width_mult=1.0, + inverted_residual_setting=None, + round_nearest=8, + block=None, + norm_layer=None, + ): + """ + MobileNet V2 main class + + Args: + embed_dim (int): Number of channels in the final output. + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + norm_layer: Module specifying the normalization layer to use + + """ + super(MobileNetV2, self).__init__() + + if block is None: + block = InvertedResidual + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + + input_channel = 32 + last_channel = embed_dim / width_mult + + if inverted_residual_setting is None: + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + [6, 160, 3, 2], + [6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if ( + len(inverted_residual_setting) == 0 + or len(inverted_residual_setting[0]) != 4 + ): + raise ValueError( + "inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting) + ) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible( + last_channel * max(1.0, width_mult), round_nearest + ) + features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append( + block( + input_channel, + output_channel, + stride, + expand_ratio=t, + norm_layer=norm_layer, + ) + ) + input_channel = output_channel + # building last several layers + features.append( + ConvBNReLU( + input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer + ) + ) + # make it nn.Sequential + self.features = nn.Sequential(*features) + + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out") + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + return self.features(x) + # return the features directly, no classifier or pooling + + def forward(self, x): + return self._forward_impl(x) diff --git a/deepafx_st/probes/cdpam_encoder.py b/deepafx_st/probes/cdpam_encoder.py new file mode 100644 index 0000000..5e20a4d --- /dev/null +++ b/deepafx_st/probes/cdpam_encoder.py @@ -0,0 +1,68 @@ +# MIT License + +# Copyright (c) 2021 Pranay Manocha + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# code adapated from https://github.com/pranaymanocha/PerceptualAudio + +import cdpam +import torch + + +class CDPAMEncoder(torch.nn.Module): + def __init__(self, cdpam_ckpt: str): + super().__init__() + + # pre-trained model parameterss + encoder_layers = 16 + encoder_filters = 64 + input_size = 512 + proj_ndim = [512, 256] + ndim = [16, 6] + classif_BN = 0 + classif_act = "no" + proj_dp = 0.1 + proj_BN = 1 + classif_dp = 0.05 + + model = cdpam.models.FINnet( + encoder_layers=encoder_layers, + encoder_filters=encoder_filters, + ndim=ndim, + classif_dp=classif_dp, + classif_BN=classif_BN, + classif_act=classif_act, + input_size=input_size, + ) + + state = torch.load(cdpam_ckpt, map_location="cpu")["state"] + model.load_state_dict(state) + model.eval() + + self.model = model + self.embed_dim = 512 + + def forward(self, x): + + with torch.no_grad(): + _, a1, c1 = self.model.base_encoder.forward(x) + a1 = torch.nn.functional.normalize(a1, dim=1) + + return a1 diff --git a/deepafx_st/probes/probe_system.py b/deepafx_st/probes/probe_system.py new file mode 100644 index 0000000..9dd61c7 --- /dev/null +++ b/deepafx_st/probes/probe_system.py @@ -0,0 +1,307 @@ +import torch +import julius +import torchopenl3 +import torchmetrics +import pytorch_lightning as pl +from typing import Tuple, List, Dict +from argparse import ArgumentParser + +from deepafx_st.probes.cdpam_encoder import CDPAMEncoder +from deepafx_st.probes.random_mel import RandomMelProjection + +import deepafx_st.utils as utils +from deepafx_st.utils import DSPMode +from deepafx_st.system import System +from deepafx_st.data.style import StyleDataset + + +class ProbeSystem(pl.LightningModule): + def __init__( + self, + audio_dir=None, + num_classes=5, + task="style", + encoder_type="deepafx_st_autodiff", + deepafx_st_autodiff_ckpt=None, + deepafx_st_spsa_ckpt=None, + deepafx_st_proxy0_ckpt=None, + probe_type="linear", + batch_size=32, + lr=3e-4, + lr_patience=20, + patience=10, + preload=False, + sample_rate=24000, + shuffle=True, + num_workers=16, + **kwargs, + ): + super().__init__() + self.save_hyperparameters() + + if "deepafx_st" in self.hparams.encoder_type: + + if "autodiff" in self.hparams.encoder_type: + self.hparams.deepafx_st_ckpt = self.hparams.deepafx_st_autodiff_ckpt + elif "spsa" in self.hparams.encoder_type: + self.hparams.deepafx_st_ckpt = self.hparams.deepafx_st_spsa_ckpt + elif "proxy0" in self.hparams.encoder_type: + self.hparams.deepafx_st_ckpt = self.hparams.deepafx_st_proxy0_ckpt + + else: + raise RuntimeError(f"Invalid encoder_type: {self.hparams.encoder_type}") + + if self.hparams.deepafx_st_ckpt is None: + raise RuntimeError( + f"Must supply {self.hparams.encoder_type}_ckpt checkpoint." + ) + use_dsp = DSPMode.NONE + system = System.load_from_checkpoint( + self.hparams.deepafx_st_ckpt, + use_dsp=use_dsp, + batch_size=self.hparams.batch_size, + spsa_parallel=False, + proxy_ckpts=[], + strict=False, + ) + system.eval() + self.encoder = system.encoder + self.hparams.embed_dim = self.encoder.embed_dim + + # freeze weights + for name, param in self.encoder.named_parameters(): + param.requires_grad = False + + elif self.hparams.encoder_type == "openl3": + self.encoder = torchopenl3.models.load_audio_embedding_model( + input_repr=self.hparams.openl3_input_repr, + embedding_size=self.hparams.openl3_embedding_size, + content_type=self.hparams.openl3_content_type, + ) + self.hparams.embed_dim = 6144 + elif self.hparams.encoder_type == "random_mel": + self.encoder = RandomMelProjection( + self.hparams.sample_rate, + self.hparams.random_mel_embedding_size, + self.hparams.random_mel_n_mels, + self.hparams.random_mel_n_fft, + self.hparams.random_mel_hop_size, + ) + self.hparams.embed_dim = self.hparams.random_mel_embedding_size + elif self.hparams.encoder_type == "cdpam": + self.encoder = CDPAMEncoder(self.hparams.cdpam_ckpt) + self.encoder.eval() + self.hparams.embed_dim = self.encoder.embed_dim + else: + raise ValueError(f"Invalid encoder_type: {self.hparams.encoder_type}") + + if self.hparams.probe_type == "linear": + if self.hparams.task == "style": + self.probe = torch.nn.Sequential( + torch.nn.Linear(self.hparams.embed_dim, self.hparams.num_classes), + # torch.nn.Softmax(-1), + ) + elif self.hparams.probe_type == "mlp": + if self.hparams.task == "style": + self.probe = torch.nn.Sequential( + torch.nn.Linear(self.hparams.embed_dim, 512), + torch.nn.ReLU(), + torch.nn.Linear(512, 512), + torch.nn.ReLU(), + torch.nn.Linear(512, self.hparams.num_classes), + ) + self.accuracy = torchmetrics.Accuracy() + self.f1_score = torchmetrics.F1Score(self.hparams.num_classes) + + def forward(self, x): + bs, chs, samp = x.size() + with torch.no_grad(): + if "deepafx_st" in self.hparams.encoder_type: + x /= x.abs().max() + x *= 10 ** (-12.0 / 20) # with min 12 dBFS headroom + e = self.encoder(x) + norm = torch.norm(e, p=2, dim=-1, keepdim=True) + e = e / norm + elif self.hparams.encoder_type == "openl3": + # x = julius.resample_frac(x, self.hparams.sample_rate, 48000) + e, ts = torchopenl3.get_audio_embedding( + x, + 48000, + model=self.encoder, + input_repr="mel128", + content_type="music", + ) + e = e.permute(0, 2, 1) + e = e.mean(dim=-1) + # normalize by L2 norm + norm = torch.norm(e, p=2, dim=-1, keepdim=True) + e = e / norm + elif self.hparams.encoder_type == "random_mel": + e = self.encoder(x) + norm = torch.norm(e, p=2, dim=-1, keepdim=True) + e = e / norm + elif self.hparams.encoder_type == "cdpam": + # x = julius.resample_frac(x, self.hparams.sample_rate, 22050) + x = torch.round(x * 32768) + e = self.encoder(x) + + return self.probe(e) + + def common_step( + self, + batch: Tuple, + batch_idx: int, + optimizer_idx: int = 0, + train: bool = True, + ): + loss = 0 + x, y = batch + + y_hat = self(x) + + # compute CE + if self.hparams.task == "style": + loss = torch.nn.functional.cross_entropy(y_hat, y) + + if not train: + # store audio data + data_dict = {"x": x.float().cpu()} + else: + data_dict = {} + + self.log( + "train_loss" if train else "val_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=False, + logger=True, + sync_dist=True, + ) + + if not train and self.hparams.task == "style": + self.log("val_acc_step", self.accuracy(y_hat, y)) + self.log("val_f1_step", self.f1_score(y_hat, y)) + + return loss, data_dict + + def training_step(self, batch, batch_idx, optimizer_idx=0): + loss, _ = self.common_step(batch, batch_idx) + return loss + + def validation_step(self, batch, batch_idx): + loss, data_dict = self.common_step(batch, batch_idx, train=False) + + if batch_idx == 0: + return data_dict + + def validation_epoch_end(self, outputs) -> None: + if self.hparams.task == "style": + self.log("val_acc_epoch", self.accuracy.compute()) + self.log("val_f1_epoch", self.f1_score.compute()) + + return super().validation_epoch_end(outputs) + + def configure_optimizers(self): + optimizer = torch.optim.AdamW( + self.probe.parameters(), + lr=self.hparams.lr, + betas=(0.9, 0.999), + ) + + ms1 = int(self.hparams.max_epochs * 0.8) + ms2 = int(self.hparams.max_epochs * 0.95) + print( + "Learning rate schedule:", + f"0 {self.hparams.lr:0.2e} -> ", + f"{ms1} {self.hparams.lr*0.1:0.2e} -> ", + f"{ms2} {self.hparams.lr*0.01:0.2e}", + ) + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[ms1, ms2], + gamma=0.1, + ) + + return [optimizer], {"scheduler": scheduler, "monitor": "val_loss"} + + def train_dataloader(self): + + if self.hparams.task == "style": + train_dataset = StyleDataset( + self.hparams.audio_dir, + "train", + sample_rate=self.hparams.encoder_sample_rate, + ) + + g = torch.Generator() + g.manual_seed(0) + + return torch.utils.data.DataLoader( + train_dataset, + num_workers=self.hparams.num_workers, + batch_size=self.hparams.batch_size, + shuffle=True, + worker_init_fn=utils.seed_worker, + generator=g, + pin_memory=True, + ) + + def val_dataloader(self): + + if self.hparams.task == "style": + val_dataset = StyleDataset( + self.hparams.audio_dir, + subset="val", + sample_rate=self.hparams.encoder_sample_rate, + ) + + g = torch.Generator() + g.manual_seed(0) + + return torch.utils.data.DataLoader( + val_dataset, + num_workers=self.hparams.num_workers, + batch_size=self.hparams.batch_size, + worker_init_fn=utils.seed_worker, + generator=g, + pin_memory=True, + ) + + # add any model hyperparameters here + @staticmethod + def add_model_specific_args(parent_parser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) + # --- Model --- + parser.add_argument("--encoder_type", type=str, default="deeapfx2") + parser.add_argument("--probe_type", type=str, default="linear") + parser.add_argument("--task", type=str, default="style") + parser.add_argument("--encoder_sample_rate", type=int, default=24000) + # --- deeapfx2 --- + parser.add_argument("--deepafx_st_autodiff_ckpt", type=str) + parser.add_argument("--deepafx_st_spsa_ckpt", type=str) + parser.add_argument("--deepafx_st_proxy0_ckpt", type=str) + + # --- cdpam --- + parser.add_argument("--cdpam_ckpt", type=str) + # --- openl3 --- + parser.add_argument("--openl3_input_repr", type=str, default="mel128") + parser.add_argument("--openl3_content_type", type=str, default="env") + parser.add_argument("--openl3_embedding_size", type=int, default=6144) + # --- random_mel --- + parser.add_argument("--random_mel_embedding_size", type=str, default=4096) + parser.add_argument("--random_mel_n_fft", type=str, default=4096) + parser.add_argument("--random_mel_hop_size", type=str, default=1024) + parser.add_argument("--random_mel_n_mels", type=str, default=128) + # --- Training --- + parser.add_argument("--audio_dir", type=str) + parser.add_argument("--num_classes", type=int, default=5) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--lr", type=float, default=3e-4) + parser.add_argument("--lr_patience", type=int, default=20) + parser.add_argument("--patience", type=int, default=10) + parser.add_argument("--preload", action="store_true") + parser.add_argument("--sample_rate", type=int, default=24000) + parser.add_argument("--num_workers", type=int, default=8) + + return parser diff --git a/deepafx_st/probes/random_mel.py b/deepafx_st/probes/random_mel.py new file mode 100644 index 0000000..a83db53 --- /dev/null +++ b/deepafx_st/probes/random_mel.py @@ -0,0 +1,93 @@ +import math +import torch +import librosa + +# based on https://github.com/neuralaudio/hear-baseline/blob/main/hearbaseline/naive.py + + +class RandomMelProjection(torch.nn.Module): + def __init__( + self, + sample_rate, + embed_dim=4096, + n_mels=128, + n_fft=4096, + hop_size=1024, + seed=0, + epsilon=1e-4, + ): + super().__init__() + self.sample_rate = sample_rate + self.embed_dim = embed_dim + self.n_mels = n_mels + self.n_fft = n_fft + self.hop_size = hop_size + self.seed = seed + self.epsilon = epsilon + + # Set random seed + torch.random.manual_seed(self.seed) + + # Create a Hann window buffer to apply to frames prior to FFT. + self.register_buffer("window", torch.hann_window(self.n_fft)) + + # Create a mel filter buffer. + mel_scale = torch.tensor( + librosa.filters.mel( + self.sample_rate, + n_fft=self.n_fft, + n_mels=self.n_mels, + ) + ) + self.register_buffer("mel_scale", mel_scale) + + # Projection matrices. + normalization = math.sqrt(self.n_mels) + self.projection = torch.nn.Parameter( + torch.rand(self.n_mels, self.embed_dim) / normalization, + requires_grad=False, + ) + + def forward(self, x): + bs, chs, samp = x.size() + + x = torch.stft( + x.view(bs, -1), + self.n_fft, + self.hop_size, + window=self.window, + return_complex=True, + ) + x = x.unsqueeze(1).permute(0, 1, 3, 2) + + # Apply the mel-scale filter to the power spectrum. + x = torch.matmul(x.abs(), self.mel_scale.transpose(0, 1)) + + # power scale + x = torch.pow(x + self.epsilon, 0.3) + + # apply random projection + e = x.matmul(self.projection) + + # take mean across temporal dim + e = e.mean(dim=2).view(bs, -1) + + return e + + def compute_frame_embedding(self, x): + # Compute the real-valued Fourier transform on windowed input signal. + x = torch.fft.rfft(x * self.window) + + # Convert to a power spectrum. + x = torch.abs(x) ** 2.0 + + # Apply the mel-scale filter to the power spectrum. + x = torch.matmul(x, self.mel_scale.transpose(0, 1)) + + # Convert to a log mel spectrum. + x = torch.log(x + self.epsilon) + + # Apply projection to get a 4096 dimension embedding + embedding = x.matmul(self.projection) + + return embedding diff --git a/deepafx_st/processors/autodiff/__init__.py b/deepafx_st/processors/autodiff/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/deepafx_st/processors/autodiff/channel.py b/deepafx_st/processors/autodiff/channel.py new file mode 100755 index 0000000..e48a3cc --- /dev/null +++ b/deepafx_st/processors/autodiff/channel.py @@ -0,0 +1,28 @@ +import torch + +from deepafx_st.processors.autodiff.compressor import Compressor +from deepafx_st.processors.autodiff.peq import ParametricEQ +from deepafx_st.processors.autodiff.fir import FIRFilter + + +class AutodiffChannel(torch.nn.Module): + def __init__(self, sample_rate): + super().__init__() + + self.peq = ParametricEQ(sample_rate) + self.comp = Compressor(sample_rate) + self.ports = [self.peq.ports, self.comp.ports] + self.num_control_params = ( + self.peq.num_control_params + self.comp.num_control_params + ) + + def forward(self, x, p, sample_rate=24000, **kwargs): + + # split params between EQ and Comp. + p_peq = p[:, : self.peq.num_control_params] + p_comp = p[:, self.peq.num_control_params :] + + y = self.peq(x, p_peq, sample_rate) + y = self.comp(y, p_comp, sample_rate) + + return y diff --git a/deepafx_st/processors/autodiff/compressor.py b/deepafx_st/processors/autodiff/compressor.py new file mode 100755 index 0000000..2e81cae --- /dev/null +++ b/deepafx_st/processors/autodiff/compressor.py @@ -0,0 +1,169 @@ +import math +import torch +import scipy.signal + +import deepafx_st.processors.autodiff.signal +from deepafx_st.processors.processor import Processor + + +@torch.jit.script +def compressor( + x: torch.Tensor, + sample_rate: float, + threshold: torch.Tensor, + ratio: torch.Tensor, + attack_time: torch.Tensor, + release_time: torch.Tensor, + knee_dB: torch.Tensor, + makeup_gain_dB: torch.Tensor, + eps: float = 1e-8, +): + """Note the `release` parameter is not used.""" + # print(f"autodiff comp fs = {sample_rate}") + + s = x.size() # should be one 1d + + threshold = threshold.squeeze() + ratio = ratio.squeeze() + attack_time = attack_time.squeeze() + makeup_gain_dB = makeup_gain_dB.squeeze() + + # uni-polar dB signal + # Turn the input signal into a uni-polar signal on the dB scale + x_G = 20 * torch.log10(torch.abs(x) + 1e-8) # x_uni casts type + + # Ensure there are no values of negative infinity + x_G = torch.clamp(x_G, min=-96) + + # Static characteristics with knee + y_G = torch.zeros(s).type_as(x) + + ratio = ratio.view(-1) + threshold = threshold.view(-1) + attack_time = attack_time.view(-1) + release_time = release_time.view(-1) + knee_dB = knee_dB.view(-1) + makeup_gain_dB = makeup_gain_dB.view(-1) + + # Below knee + idx = torch.where((2 * (x_G - threshold)) < -knee_dB)[0] + y_G[idx] = x_G[idx] + + # At knee + idx = torch.where((2 * torch.abs(x_G - threshold)) <= knee_dB)[0] + y_G[idx] = x_G[idx] + ( + (1 / ratio) * (((x_G[idx] - threshold + knee_dB) / 2) ** 2) + ) / (2 * knee_dB) + + # Above knee threshold + idx = torch.where((2 * (x_G - threshold)) > knee_dB)[0] + y_G[idx] = threshold + ((x_G[idx] - threshold) / ratio) + + x_L = x_G - y_G + + # design 1-pole butterworth lowpass + fc = 1.0 / (attack_time * sample_rate) + b, a = deepafx_st.processors.autodiff.signal.butter(fc) + + # apply FIR approx of IIR filter + y_L = deepafx_st.processors.autodiff.signal.approx_iir_filter(b, a, x_L) + + lin_y_L = torch.pow(10.0, -y_L / 20.0) # convert back to linear + y = lin_y_L * x # apply gain + + # apply makeup gain + y *= torch.pow(10.0, makeup_gain_dB / 20.0) + + return y + + +class Compressor(Processor): + def __init__( + self, + sample_rate, + max_threshold=0.0, + min_threshold=-80, + max_ratio=20.0, + min_ratio=1.0, + max_attack=0.1, + min_attack=0.0001, + max_release=1.0, + min_release=0.005, + max_knee=12.0, + min_knee=0.0, + max_mkgain=48.0, + min_mkgain=-48.0, + eps=1e-8, + ): + """ """ + super().__init__() + self.sample_rate = sample_rate + self.eps = eps + self.ports = [ + { + "name": "Threshold", + "min": min_threshold, + "max": max_threshold, + "default": -12.0, + "units": "dB", + }, + { + "name": "Ratio", + "min": min_ratio, + "max": max_ratio, + "default": 2.0, + "units": "", + }, + { + "name": "Attack", + "min": min_attack, + "max": max_attack, + "default": 0.001, + "units": "s", + }, + { + # this is a dummy parameter + "name": "Release (dummy)", + "min": min_release, + "max": max_release, + "default": 0.045, + "units": "s", + }, + { + "name": "Knee", + "min": min_knee, + "max": max_knee, + "default": 6.0, + "units": "dB", + }, + { + "name": "Makeup Gain", + "min": min_mkgain, + "max": max_mkgain, + "default": 0.0, + "units": "dB", + }, + ] + + self.num_control_params = len(self.ports) + + def forward(self, x, p, sample_rate=24000, **kwargs): + """ + + Assume that parameters in p are normalized between 0 and 1. + + x (tensor): Shape batch x 1 x samples + p (tensor): shape batch x params + + """ + bs, ch, s = x.size() + + inputs = torch.split(x, 1, 0) + params = torch.split(p, 1, 0) + + y = [] # loop over batch dimension + for input, param in zip(inputs, params): + denorm_param = self.denormalize_params(param.view(-1)) + y.append(compressor(input.view(-1), sample_rate, *denorm_param)) + + return torch.stack(y, dim=0).view(bs, 1, -1) diff --git a/deepafx_st/processors/autodiff/fir.py b/deepafx_st/processors/autodiff/fir.py new file mode 100755 index 0000000..c4d1aa1 --- /dev/null +++ b/deepafx_st/processors/autodiff/fir.py @@ -0,0 +1,68 @@ +import torch + + +class FIRFilter(torch.nn.Module): + def __init__(self, num_control_params=63): + super().__init__() + self.num_control_params = num_control_params + self.adaptor = torch.nn.Linear(num_control_params, num_control_params) + #self.batched_lfilter = torch.vmap(self.lfilter) + + def forward(self, x, b, **kwargs): + """Forward pass by appling FIR filter to each batch element. + + Args: + x (tensor): Input signals with shape (batch x 1 x samples) + b (tensor): Matrix of FIR filter coefficients with shape (batch x ntaps) + + """ + bs, ch, s = x.size() + b = self.adaptor(b) + + # pad input + x = torch.nn.functional.pad(x, (b.shape[-1] // 2, b.shape[-1] // 2)) + + # add extra dim for virutal batch dim + x = x.view(bs, 1, ch, -1) + b = b.view(bs, 1, 1, -1) + + # exlcuding vmap for now + y = self.batched_lfilter(x, b).view(bs, ch, s) + + return y + + @staticmethod + def lfilter(x, b): + return torch.nn.functional.conv1d(x, b) + + +class FrequencyDomainFIRFilter(torch.nn.Module): + def __init__(self, num_control_params=31): + super().__init__() + self.num_control_params = num_control_params + self.adaptor = torch.nn.Linear(num_control_params, num_control_params) + + def forward(self, x, b, **kwargs): + """Forward pass by appling FIR filter to each batch element. + + Args: + x (tensor): Input signals with shape (batch x 1 x samples) + b (tensor): Matrix of FIR filter coefficients with shape (batch x ntaps) + """ + bs, c, s = x.size() + + b = self.adaptor(b) + + # transform input to freq. domain + X = torch.fft.rfft(x.view(bs, -1)) + + # frequency response of filter + H = torch.fft.rfft(b.view(bs, -1)) + + # apply filter as multiplication in freq. domain + Y = X * H + + # transform back to time domain + y = torch.fft.ifft(Y).view(bs, 1, -1) + + return y diff --git a/deepafx_st/processors/autodiff/peq.py b/deepafx_st/processors/autodiff/peq.py new file mode 100755 index 0000000..04e35bb --- /dev/null +++ b/deepafx_st/processors/autodiff/peq.py @@ -0,0 +1,274 @@ +import torch + +import deepafx_st.processors.autodiff.signal +from deepafx_st.processors.processor import Processor + + +@torch.jit.script +def parametric_eq( + x: torch.Tensor, + sample_rate: float, + low_shelf_gain_dB: torch.Tensor, + low_shelf_cutoff_freq: torch.Tensor, + low_shelf_q_factor: torch.Tensor, + first_band_gain_dB: torch.Tensor, + first_band_cutoff_freq: torch.Tensor, + first_band_q_factor: torch.Tensor, + second_band_gain_dB: torch.Tensor, + second_band_cutoff_freq: torch.Tensor, + second_band_q_factor: torch.Tensor, + third_band_gain_dB: torch.Tensor, + third_band_cutoff_freq: torch.Tensor, + third_band_q_factor: torch.Tensor, + fourth_band_gain_dB: torch.Tensor, + fourth_band_cutoff_freq: torch.Tensor, + fourth_band_q_factor: torch.Tensor, + high_shelf_gain_dB: torch.Tensor, + high_shelf_cutoff_freq: torch.Tensor, + high_shelf_q_factor: torch.Tensor, +): + """Six-band parametric EQ. + + Low-shelf -> Band 1 -> Band 2 -> Band 3 -> Band 4 -> High-shelf + + Args: + x (torch.Tensor): 1d signal. + + + """ + a_s, b_s = [], [] + #print(f"autodiff peq fs = {sample_rate}") + + # -------- apply low-shelf filter -------- + b, a = deepafx_st.processors.autodiff.signal.biqaud( + low_shelf_gain_dB, + low_shelf_cutoff_freq, + low_shelf_q_factor, + sample_rate, + "low_shelf", + ) + b_s.append(b) + a_s.append(a) + + # -------- apply first-band peaking filter -------- + b, a = deepafx_st.processors.autodiff.signal.biqaud( + first_band_gain_dB, + first_band_cutoff_freq, + first_band_q_factor, + sample_rate, + "peaking", + ) + b_s.append(b) + a_s.append(a) + + # -------- apply second-band peaking filter -------- + b, a = deepafx_st.processors.autodiff.signal.biqaud( + second_band_gain_dB, + second_band_cutoff_freq, + second_band_q_factor, + sample_rate, + "peaking", + ) + b_s.append(b) + a_s.append(a) + + # -------- apply third-band peaking filter -------- + b, a = deepafx_st.processors.autodiff.signal.biqaud( + third_band_gain_dB, + third_band_cutoff_freq, + third_band_q_factor, + sample_rate, + "peaking", + ) + b_s.append(b) + a_s.append(a) + + # -------- apply fourth-band peaking filter -------- + b, a = deepafx_st.processors.autodiff.signal.biqaud( + fourth_band_gain_dB, + fourth_band_cutoff_freq, + fourth_band_q_factor, + sample_rate, + "peaking", + ) + b_s.append(b) + a_s.append(a) + + # -------- apply high-shelf filter -------- + b, a = deepafx_st.processors.autodiff.signal.biqaud( + high_shelf_gain_dB, + high_shelf_cutoff_freq, + high_shelf_q_factor, + sample_rate, + "high_shelf", + ) + b_s.append(b) + a_s.append(a) + + x = deepafx_st.processors.autodiff.signal.approx_iir_filter_cascade( + b_s, a_s, x.view(-1) + ) + + return x + + +class ParametricEQ(Processor): + def __init__( + self, + sample_rate, + min_gain_dB=-24.0, + default_gain_dB=0.0, + max_gain_dB=24.0, + min_q_factor=0.1, + default_q_factor=0.707, + max_q_factor=10, + eps=1e-8, + ): + """ """ + super().__init__() + self.sample_rate = sample_rate + self.eps = eps + self.ports = [ + { + "name": "Lowshelf gain", + "min": min_gain_dB, + "max": max_gain_dB, + "default": default_gain_dB, + "units": "dB", + }, + { + "name": "Lowshelf cutoff", + "min": 20.0, + "max": 200.0, + "default": 100.0, + "units": "Hz", + }, + { + "name": "Lowshelf Q", + "min": min_q_factor, + "max": max_q_factor, + "default": default_q_factor, + "units": "", + }, + { + "name": "First band gain", + "min": min_gain_dB, + "max": max_gain_dB, + "default": default_gain_dB, + "units": "dB", + }, + { + "name": "First band cutoff", + "min": 200.0, + "max": 2000.0, + "default": 400.0, + "units": "Hz", + }, + { + "name": "First band Q", + "min": min_q_factor, + "max": max_q_factor, + "default": 0.707, + "units": "", + }, + { + "name": "Second band gain", + "min": min_gain_dB, + "max": max_gain_dB, + "default": default_gain_dB, + "units": "dB", + }, + { + "name": "Second band cutoff", + "min": 200.0, + "max": 4000.0, + "default": 1000.0, + "units": "Hz", + }, + { + "name": "Second band Q", + "min": min_q_factor, + "max": max_q_factor, + "default": default_q_factor, + "units": "", + }, + { + "name": "Third band gain", + "min": min_gain_dB, + "max": max_gain_dB, + "default": default_gain_dB, + "units": "dB", + }, + { + "name": "Third band cutoff", + "min": 2000.0, + "max": 8000.0, + "default": 4000.0, + "units": "Hz", + }, + { + "name": "Third band Q", + "min": min_q_factor, + "max": max_q_factor, + "default": default_q_factor, + "units": "", + }, + { + "name": "Fourth band gain", + "min": min_gain_dB, + "max": max_gain_dB, + "default": default_gain_dB, + "units": "dB", + }, + { + "name": "Fourth band cutoff", + "min": 4000.0, + "max": (24000 // 2) * 0.9, + "default": 8000.0, + "units": "Hz", + }, + { + "name": "Fourth band Q", + "min": min_q_factor, + "max": max_q_factor, + "default": default_q_factor, + "units": "", + }, + { + "name": "Highshelf gain", + "min": min_gain_dB, + "max": max_gain_dB, + "default": default_gain_dB, + "units": "dB", + }, + { + "name": "Highshelf cutoff", + "min": 4000.0, + "max": (24000 // 2) * 0.9, + "default": 8000.0, + "units": "Hz", + }, + { + "name": "Highshelf Q", + "min": min_q_factor, + "max": max_q_factor, + "default": default_q_factor, + "units": "", + }, + ] + + self.num_control_params = len(self.ports) + + def forward(self, x, p, sample_rate=24000, **kwargs): + + bs, chs, s = x.size() + + inputs = torch.split(x, 1, 0) + params = torch.split(p, 1, 0) + + y = [] # loop over batch dimension + for input, param in zip(inputs, params): + denorm_param = self.denormalize_params(param.view(-1)) + y.append(parametric_eq(input.view(-1), sample_rate, *denorm_param)) + + return torch.stack(y, dim=0).view(bs, 1, -1) diff --git a/deepafx_st/processors/autodiff/signal.py b/deepafx_st/processors/autodiff/signal.py new file mode 100755 index 0000000..e8223b7 --- /dev/null +++ b/deepafx_st/processors/autodiff/signal.py @@ -0,0 +1,194 @@ +import math +import torch +from typing import List + + +def butter(fc, fs: float = 2.0): + """ + + Recall Butterworth polynomials + N = 1 s + 1 + N = 2 s^2 + sqrt(2s) + 1 + N = 3 (s^2 + s + 1)(s + 1) + N = 4 (s^2 + 0.76536s + 1)(s^2 + 1.84776s + 1) + + Scaling + LP to LP: s -> s/w_c + LP to HP: s -> w_c/s + + Bilinear transform: + s = 2/T_d * (1 - z^-1)/(1 + z^-1) + + For 1-pole butterworth lowpass + + 1 / (s + 1) 1-pole prototype + 1 / (s/w_c + 1) LP to LP + 1 / (2/T_d * (1 - z^-1)/(1 + z^-1))/w_c + 1) Bilinear transform + + """ + + # apply pre-warping to the cutoff + T_d = 1 / fs + w_d = (2 * math.pi * fc) / fs + # sys.exit() + w_c = (2 / T_d) * torch.tan(w_d / 2) + + a0 = 2 + (T_d * w_c) + a1 = (T_d * w_c) - 2 + b0 = T_d * w_c + b1 = T_d * w_c + + b = torch.stack([b0, b1], dim=0).view(-1) + a = torch.stack([a0, a1], dim=0).view(-1) + + # normalize + b = b.type_as(fc) / a0 + a = a.type_as(fc) / a0 + + return b, a + + +def biqaud( + gain_dB: torch.Tensor, + cutoff_freq: torch.Tensor, + q_factor: torch.Tensor, + sample_rate: float, + filter_type: str = "peaking", +): + + # convert inputs to Tensors if needed + # gain_dB = torch.tensor([gain_dB]) + # cutoff_freq = torch.tensor([cutoff_freq]) + # q_factor = torch.tensor([q_factor]) + + A = 10 ** (gain_dB / 40.0) + w0 = 2 * math.pi * (cutoff_freq / sample_rate) + alpha = torch.sin(w0) / (2 * q_factor) + cos_w0 = torch.cos(w0) + sqrt_A = torch.sqrt(A) + + if filter_type == "high_shelf": + b0 = A * ((A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha) + b1 = -2 * A * ((A - 1) + (A + 1) * cos_w0) + b2 = A * ((A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha) + a0 = (A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha + a1 = 2 * ((A - 1) - (A + 1) * cos_w0) + a2 = (A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha + elif filter_type == "low_shelf": + b0 = A * ((A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha) + b1 = 2 * A * ((A - 1) - (A + 1) * cos_w0) + b2 = A * ((A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha) + a0 = (A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha + a1 = -2 * ((A - 1) + (A + 1) * cos_w0) + a2 = (A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha + elif filter_type == "peaking": + b0 = 1 + alpha * A + b1 = -2 * cos_w0 + b2 = 1 - alpha * A + a0 = 1 + (alpha / A) + a1 = -2 * cos_w0 + a2 = 1 - (alpha / A) + else: + raise ValueError(f"Invalid filter_type: {filter_type}.") + + b = torch.stack([b0, b1, b2], dim=0).view(-1) + a = torch.stack([a0, a1, a2], dim=0).view(-1) + + # normalize + b = b.type_as(gain_dB) / a0 + a = a.type_as(gain_dB) / a0 + + return b, a + + +def freqz(b, a, n_fft: int = 512): + + B = torch.fft.rfft(b, n_fft) + A = torch.fft.rfft(a, n_fft) + + H = B / A + + return H + + +def freq_domain_filter(x, H, n_fft): + + X = torch.fft.rfft(x, n_fft) + + # move H to same device as input x + H = H.type_as(X) + + Y = X * H + + y = torch.fft.irfft(Y, n_fft) + + return y + + +def approx_iir_filter(b, a, x): + """Approimxate the application of an IIR filter. + + Args: + b (Tensor): The numerator coefficients. + + """ + + # round up to nearest power of 2 for FFT + # n_fft = 2 ** math.ceil(math.log2(x.shape[-1] + x.shape[-1] - 1)) + + n_fft = 2 ** torch.ceil(torch.log2(torch.tensor(x.shape[-1] + x.shape[-1] - 1))) + n_fft = n_fft.int() + + # move coefficients to same device as x + b = b.type_as(x).view(-1) + a = a.type_as(x).view(-1) + + # compute complex response + H = freqz(b, a, n_fft=n_fft).view(-1) + + # apply filter + y = freq_domain_filter(x, H, n_fft) + + # crop + y = y[: x.shape[-1]] + + return y + + +def approx_iir_filter_cascade( + b_s: List[torch.Tensor], + a_s: List[torch.Tensor], + x: torch.Tensor, +): + """Apply a cascade of IIR filters. + + Args: + b (list[Tensor]): List of tensors of shape (3) + a (list[Tensor]): List of tensors of (3) + x (torch.Tensor): 1d Tensor. + """ + + if len(b_s) != len(a_s): + raise RuntimeError( + f"Must have same number of coefficients. Got b: {len(b_s)} and a: {len(a_s)}." + ) + + # round up to nearest power of 2 for FFT + # n_fft = 2 ** math.ceil(math.log2(x.shape[-1] + x.shape[-1] - 1)) + n_fft = 2 ** torch.ceil(torch.log2(torch.tensor(x.shape[-1] + x.shape[-1] - 1))) + n_fft = n_fft.int() + + # this could be done in parallel + b = torch.stack(b_s, dim=0).type_as(x) + a = torch.stack(a_s, dim=0).type_as(x) + + H = freqz(b, a, n_fft=n_fft) + H = torch.prod(H, dim=0).view(-1) + + # apply filter + y = freq_domain_filter(x, H, n_fft) + + # crop + y = y[: x.shape[-1]] + + return y diff --git a/deepafx_st/processors/dsp/compressor.py b/deepafx_st/processors/dsp/compressor.py new file mode 100755 index 0000000..ab515f9 --- /dev/null +++ b/deepafx_st/processors/dsp/compressor.py @@ -0,0 +1,177 @@ +import sys +import torch +import numpy as np +import scipy.signal +from numba import jit + +from deepafx_st.processors.processor import Processor + + +# Adapted from: https://github.com/drscotthawley/signaltrain/blob/master/signaltrain/audio.py +@jit(nopython=True) +def my_clip_min( + x: np.ndarray, + clip_min: float, +): # does the work of np.clip(), which numba doesn't support yet + # TODO: keep an eye on Numba PR https://github.com/numba/numba/pull/3468 that fixes this + inds = np.where(x < clip_min) + x[inds] = clip_min + return x + + +@jit(nopython=True) +def compressor( + x: np.ndarray, + sample_rate: float, + threshold: float = -24.0, + ratio: float = 2.0, + attack_time: float = 0.01, + release_time: float = 0.01, + knee_dB: float = 0.0, + makeup_gain_dB: float = 0.0, + dtype=np.float32, +): + """ + + Args: + x (np.ndarray): Input signal. + sample_rate (float): Sample rate in Hz. + threshold (float): Threhold in dB. + ratio (float): Ratio (should be >=1 , i.e. ratio:1). + attack_time (float): Attack time in seconds. + release_time (float): Release time in seconds. + knee_dB (float): Knee. + makeup_gain_dB (float): Makeup Gain. + dtype (type): Output type. Default: np.float32 + + Returns: + y (np.ndarray): Output signal. + + """ + # print(f"dsp comp fs = {sample_rate}") + + N = len(x) + dtype = x.dtype + y = np.zeros(N, dtype=dtype) + + # Initialize separate attack and release times + # Where do these numbers come from + alpha_A = np.exp(-np.log(9) / (sample_rate * attack_time)) + alpha_R = np.exp(-np.log(9) / (sample_rate * release_time)) + + # Turn the input signal into a uni-polar signal on the dB scale + x_G = 20 * np.log10(np.abs(x) + 1e-8) # x_uni casts type + + # Ensure there are no values of negative infinity + x_G = my_clip_min(x_G, -96) + + # Static characteristics with knee + y_G = np.zeros(N, dtype=dtype) + + # Below knee + idx = np.where((2 * (x_G - threshold)) < -knee_dB) + y_G[idx] = x_G[idx] + + # At knee + idx = np.where((2 * np.abs(x_G - threshold)) <= knee_dB) + y_G[idx] = x_G[idx] + ( + (1 / ratio) * (((x_G[idx] - threshold + knee_dB) / 2) ** 2) + ) / (2 * knee_dB) + + # Above knee threshold + idx = np.where((2 * (x_G - threshold)) > knee_dB) + y_G[idx] = threshold + ((x_G[idx] - threshold) / ratio) + + x_L = x_G - y_G + + # this loop is slow but not vectorizable due to its cumulative, sequential nature. @autojit makes it fast(er). + y_L = np.zeros(N, dtype=dtype) + for n in range(1, N): + # smooth over the gainChange + if x_L[n] > y_L[n - 1]: # attack mode + y_L[n] = (alpha_A * y_L[n - 1]) + ((1 - alpha_A) * x_L[n]) + else: # release + y_L[n] = (alpha_R * y_L[n - 1]) + ((1 - alpha_R) * x_L[n]) + + # Convert to linear amplitude scalar; i.e. map from dB to amplitude + lin_y_L = np.power(10.0, (-y_L / 20.0)) + y = lin_y_L * x # Apply linear amplitude to input sample + + y *= np.power(10.0, makeup_gain_dB / 20.0) # apply makeup gain + + return y.astype(dtype) + + +class Compressor(Processor): + def __init__( + self, + sample_rate, + max_threshold=0.0, + min_threshold=-80, + max_ratio=20.0, + min_ratio=1.0, + max_attack=0.1, + min_attack=0.0001, + max_release=1.0, + min_release=0.005, + max_knee=12.0, + min_knee=0.0, + max_mkgain=48.0, + min_mkgain=-48.0, + eps=1e-8, + ): + """ """ + super().__init__() + self.sample_rate = sample_rate + self.eps = eps + self.ports = [ + { + "name": "Threshold", + "min": min_threshold, + "max": max_threshold, + "default": -12.0, + "units": "", + }, + { + "name": "Ratio", + "min": min_ratio, + "max": max_ratio, + "default": 2.0, + "units": "", + }, + { + "name": "Attack Time", + "min": min_attack, + "max": max_attack, + "default": 0.001, + "units": "s", + }, + { + "name": "Release Time", + "min": min_release, + "max": max_release, + "default": 0.045, + "units": "s", + }, + { + "name": "Knee", + "min": min_knee, + "max": max_knee, + "default": 6.0, + "units": "dB", + }, + { + "name": "Makeup Gain", + "min": min_mkgain, + "max": max_mkgain, + "default": 0.0, + "units": "dB", + }, + ] + + self.num_control_params = len(self.ports) + self.process_fn = compressor + + def forward(self, x, p, sample_rate=24000, **kwargs): + "All processing in the forward is in numpy." + return self.run_series(x, p, sample_rate) diff --git a/deepafx_st/processors/dsp/peq.py b/deepafx_st/processors/dsp/peq.py new file mode 100755 index 0000000..8083b6d --- /dev/null +++ b/deepafx_st/processors/dsp/peq.py @@ -0,0 +1,323 @@ +import torch +import numpy as np +import scipy.signal +from numba import jit + +from deepafx_st.processors.processor import Processor + + +@jit(nopython=True) +def biqaud( + gain_dB: float, + cutoff_freq: float, + q_factor: float, + sample_rate: float, + filter_type: str, +): + """Use design parameters to generate coeffieicnets for a specific filter type. + + Args: + gain_dB (float): Shelving filter gain in dB. + cutoff_freq (float): Cutoff frequency in Hz. + q_factor (float): Q factor. + sample_rate (float): Sample rate in Hz. + filter_type (str): Filter type. + One of "low_shelf", "high_shelf", or "peaking" + + Returns: + b (np.ndarray): Numerator filter coefficients stored as [b0, b1, b2] + a (np.ndarray): Denominator filter coefficients stored as [a0, a1, a2] + """ + + A = 10 ** (gain_dB / 40.0) + w0 = 2.0 * np.pi * (cutoff_freq / sample_rate) + alpha = np.sin(w0) / (2.0 * q_factor) + + cos_w0 = np.cos(w0) + sqrt_A = np.sqrt(A) + + if filter_type == "high_shelf": + b0 = A * ((A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha) + b1 = -2 * A * ((A - 1) + (A + 1) * cos_w0) + b2 = A * ((A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha) + a0 = (A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha + a1 = 2 * ((A - 1) - (A + 1) * cos_w0) + a2 = (A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha + elif filter_type == "low_shelf": + b0 = A * ((A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha) + b1 = 2 * A * ((A - 1) - (A + 1) * cos_w0) + b2 = A * ((A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha) + a0 = (A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha + a1 = -2 * ((A - 1) + (A + 1) * cos_w0) + a2 = (A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha + elif filter_type == "peaking": + b0 = 1 + alpha * A + b1 = -2 * cos_w0 + b2 = 1 - alpha * A + a0 = 1 + alpha / A + a1 = -2 * cos_w0 + a2 = 1 - alpha / A + else: + pass + # raise ValueError(f"Invalid filter_type: {filter_type}.") + + b = np.array([b0, b1, b2]) / a0 + a = np.array([a0, a1, a2]) / a0 + + return b, a + + +# Adapted from https://github.com/csteinmetz1/pyloudnorm/blob/master/pyloudnorm/iirfilter.py +def parametric_eq( + x: np.ndarray, + sample_rate: float, + low_shelf_gain_dB: float = 0.0, + low_shelf_cutoff_freq: float = 80.0, + low_shelf_q_factor: float = 0.707, + first_band_gain_dB: float = 0.0, + first_band_cutoff_freq: float = 300.0, + first_band_q_factor: float = 0.707, + second_band_gain_dB: float = 0.0, + second_band_cutoff_freq: float = 1000.0, + second_band_q_factor: float = 0.707, + third_band_gain_dB: float = 0.0, + third_band_cutoff_freq: float = 4000.0, + third_band_q_factor: float = 0.707, + fourth_band_gain_dB: float = 0.0, + fourth_band_cutoff_freq: float = 8000.0, + fourth_band_q_factor: float = 0.707, + high_shelf_gain_dB: float = 0.0, + high_shelf_cutoff_freq: float = 1000.0, + high_shelf_q_factor: float = 0.707, + dtype=np.float32, +): + """Six-band parametric EQ. + + Low-shelf -> Band 1 -> Band 2 -> Band 3 -> Band 4 -> High-shelf + + Args: + + + """ + # print(f"autodiff peq fs = {sample_rate}") + + # -------- apply low-shelf filter -------- + b, a = biqaud( + low_shelf_gain_dB, + low_shelf_cutoff_freq, + low_shelf_q_factor, + sample_rate, + "low_shelf", + ) + sos0 = np.concatenate((b, a)) + x = scipy.signal.lfilter(b, a, x) + + # -------- apply first-band peaking filter -------- + b, a = biqaud( + first_band_gain_dB, + first_band_cutoff_freq, + first_band_q_factor, + sample_rate, + "peaking", + ) + sos1 = np.concatenate((b, a)) + x = scipy.signal.lfilter(b, a, x) + + # -------- apply second-band peaking filter -------- + b, a = biqaud( + second_band_gain_dB, + second_band_cutoff_freq, + second_band_q_factor, + sample_rate, + "peaking", + ) + sos2 = np.concatenate((b, a)) + x = scipy.signal.lfilter(b, a, x) + + # -------- apply third-band peaking filter -------- + b, a = biqaud( + third_band_gain_dB, + third_band_cutoff_freq, + third_band_q_factor, + sample_rate, + "peaking", + ) + sos3 = np.concatenate((b, a)) + x = scipy.signal.lfilter(b, a, x) + + # -------- apply fourth-band peaking filter -------- + b, a = biqaud( + fourth_band_gain_dB, + fourth_band_cutoff_freq, + fourth_band_q_factor, + sample_rate, + "peaking", + ) + sos4 = np.concatenate((b, a)) + x = scipy.signal.lfilter(b, a, x) + + # -------- apply high-shelf filter -------- + b, a = biqaud( + high_shelf_gain_dB, + high_shelf_cutoff_freq, + high_shelf_q_factor, + sample_rate, + "high_shelf", + ) + sos5 = np.concatenate((b, a)) + x = scipy.signal.lfilter(b, a, x) + + return x.astype(dtype) + + +class ParametricEQ(Processor): + def __init__( + self, + sample_rate, + min_gain_dB=-24.0, + default_gain_dB=0.0, + max_gain_dB=24.0, + min_q_factor=0.1, + default_q_factor=0.707, + max_q_factor=10, + eps=1e-8, + ): + """ """ + super().__init__() + self.sample_rate = sample_rate + self.eps = eps + self.ports = [ + { + "name": "Lowshelf gain", + "min": min_gain_dB, + "max": max_gain_dB, + "default": default_gain_dB, + "units": "dB", + }, + { + "name": "Lowshelf cutoff", + "min": 20.0, + "max": 200.0, + "default": 100.0, + "units": "Hz", + }, + { + "name": "Lowshelf Q", + "min": min_q_factor, + "max": max_q_factor, + "default": default_q_factor, + "units": "", + }, + { + "name": "First band gain", + "min": min_gain_dB, + "max": max_gain_dB, + "default": default_gain_dB, + "units": "dB", + }, + { + "name": "First band cutoff", + "min": 200.0, + "max": 2000.0, + "default": 400.0, + "units": "Hz", + }, + { + "name": "First band Q", + "min": min_q_factor, + "max": max_q_factor, + "default": 0.707, + "units": "", + }, + { + "name": "Second band gain", + "min": min_gain_dB, + "max": max_gain_dB, + "default": default_gain_dB, + "units": "dB", + }, + { + "name": "Second band cutoff", + "min": 800.0, + "max": 4000.0, + "default": 1000.0, + "units": "Hz", + }, + { + "name": "Second band Q", + "min": min_q_factor, + "max": max_q_factor, + "default": default_q_factor, + "units": "", + }, + { + "name": "Third band gain", + "min": min_gain_dB, + "max": max_gain_dB, + "default": default_gain_dB, + "units": "dB", + }, + { + "name": "Third band cutoff", + "min": 2000.0, + "max": 8000.0, + "default": 4000.0, + "units": "Hz", + }, + { + "name": "Third band Q", + "min": min_q_factor, + "max": max_q_factor, + "default": default_q_factor, + "units": "", + }, + { + "name": "Fourth band gain", + "min": min_gain_dB, + "max": max_gain_dB, + "default": default_gain_dB, + "units": "dB", + }, + { + "name": "Fourth band cutoff", + "min": 4000.0, + "max": (24000 // 2) * 0.9, + "default": 8000.0, + "units": "Hz", + }, + { + "name": "Fourth band Q", + "min": min_q_factor, + "max": max_q_factor, + "default": default_q_factor, + "units": "", + }, + { + "name": "Highshelf gain", + "min": min_gain_dB, + "max": max_gain_dB, + "default": default_gain_dB, + "units": "dB", + }, + { + "name": "Highshelf cutoff", + "min": 4000.0, + "max": (24000 // 2) * 0.9, + "default": 8000.0, + "units": "Hz", + }, + { + "name": "Highshelf Q", + "min": min_q_factor, + "max": max_q_factor, + "default": default_q_factor, + "units": "", + }, + ] + + self.num_control_params = len(self.ports) + self.process_fn = parametric_eq + + def forward(self, x, p, sample_rate=24000, **kwargs): + "All processing in the forward is in numpy." + return self.run_series(x, p, sample_rate) diff --git a/deepafx_st/processors/processor.py b/deepafx_st/processors/processor.py new file mode 100755 index 0000000..558a994 --- /dev/null +++ b/deepafx_st/processors/processor.py @@ -0,0 +1,87 @@ +import torch +import multiprocessing +from abc import ABC, abstractmethod +import deepafx_st.utils as utils +import numpy as np + + +class Processor(torch.nn.Module, ABC): + """Processor base class.""" + + def __init__( + self, + ): + super().__init__() + + def denormalize_params(self, p): + """This method takes a tensor of parameters scaled from 0-1 and + restores them back to the original parameter range.""" + + # check if the number of parameters is correct + params = p # torch.split(p, 1, -1) + if len(params) != self.num_control_params: + raise RuntimeError( + f"Invalid number of parameters. ", + f"Expected {self.num_control_params} but found {len(params)} {params.shape}.", + ) + + # iterate over the parameters and expand from 0-1 to full range + denorm_params = [] + for param, port in zip(params, self.ports): + # check if parameter exceeds range + if param > 1.0 or param < 0.0: + raise RuntimeError( + f"""Parameter '{port["name"]}' exceeds range: {param}""" + ) + + # denormalize and store result + denorm_params.append(utils.denormalize(param, port["max"], port["min"])) + + return denorm_params + + def normalize_params(self, *params): + """This method creates a vector of parameters normalized from 0-1.""" + + # check if the number of parameters is correct + if len(params) != self.num_control_params: + raise RuntimeError( + f"Invalid number of parameters. ", + f"Expected {self.num_control_params} but found {len(params)}.", + ) + + norm_params = [] + for param, port in zip(params, self.ports): + norm_params.append(utils.normalize(param, port["max"], port["min"])) + + p = torch.tensor(norm_params).view(1, -1) + + return p + + # def run_series(self, inputs, params): + # """Run the process function in a loop given a list of inputs and parameters""" + # p_b_denorm = [p for p in self.denormalize_params(params)] + # y = self.process_fn(inputs, self.sample_rate, *p_b_denorm) + # return y + + def run_series(self, inputs, params, sample_rate=24000): + """Run the process function in a loop given a list of inputs and parameters""" + if params.ndim == 1: + params = np.reshape(params, (1, -1)) + inputs = np.reshape(inputs, (1, -1)) + bs = inputs.shape[0] + ys = [] + params = np.clip(params, 0, 1) + for bidx in range(bs): + p_b_denorm = [p for p in self.denormalize_params(params[bidx, :])] + y = self.process_fn( + inputs[bidx, ...].reshape(-1), + sample_rate, + *p_b_denorm, + ) + ys.append(y) + y = np.stack(ys, axis=0) + return y + + @abstractmethod + def forward(self, x, p): + pass diff --git a/deepafx_st/processors/proxy/channel.py b/deepafx_st/processors/proxy/channel.py new file mode 100755 index 0000000..297b6ac --- /dev/null +++ b/deepafx_st/processors/proxy/channel.py @@ -0,0 +1,130 @@ +import torch +from deepafx_st.processors.proxy.proxy_system import ProxySystem +from deepafx_st.utils import DSPMode + + +class ProxyChannel(torch.nn.Module): + def __init__( + self, + proxy_system_ckpts: list, + freeze_proxies: bool = True, + dsp_mode: DSPMode = DSPMode.NONE, + num_tcns: int = 2, + tcn_nblocks: int = 4, + tcn_dilation_growth: int = 8, + tcn_channel_width: int = 64, + tcn_kernel_size: int = 13, + sample_rate: int = 24000, + ): + super().__init__() + self.freeze_proxies = freeze_proxies + self.dsp_mode = dsp_mode + self.num_tcns = num_tcns + + # load the proxies + self.proxies = torch.nn.ModuleList() + self.num_control_params = 0 + self.ports = [] + for proxy_system_ckpt in proxy_system_ckpts: + proxy = ProxySystem.load_from_checkpoint(proxy_system_ckpt) + # freeze model parameters + if freeze_proxies: + for param in proxy.parameters(): + param.requires_grad = False + self.proxies.append(proxy) + if proxy.hparams.processor == "channel": + self.ports = proxy.processor.ports + else: + self.ports.append(proxy.processor.ports) + self.num_control_params += proxy.processor.num_control_params + + if len(proxy_system_ckpts) == 0: + if self.num_tcns == 2: + peq_proxy = ProxySystem( + processor="peq", + output_gain=False, + nblocks=tcn_nblocks, + dilation_growth=tcn_dilation_growth, + kernel_size=tcn_kernel_size, + channel_width=tcn_channel_width, + sample_rate=sample_rate, + ) + self.proxies.append(peq_proxy) + self.ports.append(peq_proxy.processor.ports) + self.num_control_params += peq_proxy.processor.num_control_params + comp_proxy = ProxySystem( + processor="comp", + output_gain=True, + nblocks=tcn_nblocks, + dilation_growth=tcn_dilation_growth, + kernel_size=tcn_kernel_size, + channel_width=tcn_channel_width, + sample_rate=sample_rate, + ) + self.proxies.append(comp_proxy) + self.ports.append(comp_proxy.processor.ports) + self.num_control_params += comp_proxy.processor.num_control_params + elif self.num_tcns == 1: + channel_proxy = ProxySystem( + processor="channel", + output_gain=True, + nblocks=tcn_nblocks, + dilation_growth=tcn_dilation_growth, + kernel_size=tcn_kernel_size, + channel_width=tcn_channel_width, + sample_rate=sample_rate, + ) + self.proxies.append(channel_proxy) + for port_list in channel_proxy.processor.ports: + self.ports.append(port_list) + self.num_control_params += channel_proxy.processor.num_control_params + else: + raise ValueError(f"num_tcns must be <= 2. Asked for {self.num_tcns}.") + + def forward( + self, + x: torch.Tensor, + p: torch.Tensor, + dsp_mode: DSPMode = DSPMode.NONE, + sample_rate: int = 24000, + **kwargs, + ): + # loop over the proxies and pass parameters + stop_idx = 0 + for proxy in self.proxies: + start_idx = stop_idx + stop_idx += proxy.processor.num_control_params + p_subset = p[:, start_idx:stop_idx] + if dsp_mode.name == DSPMode.NONE.name: + x = proxy( + x, + p_subset, + use_dsp=False, + ) + elif dsp_mode.name == DSPMode.INFER.name: + x = proxy( + x, + p_subset, + use_dsp=True, + sample_rate=sample_rate, + ) + elif dsp_mode.name == DSPMode.TRAIN_INFER.name: + # Mimic gumbel softmax implementation to replace grads similar to + # https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f + x_hard = proxy( + x, + p_subset, + use_dsp=True, + sample_rate=sample_rate, + ) + x = proxy( + x, + p_subset, + use_dsp=False, + sample_rate=sample_rate, + ) + x = (x_hard - x).detach() + x + else: + assert 0, "invalid dsp model for proxy" + + return x diff --git a/deepafx_st/processors/proxy/proxy_system.py b/deepafx_st/processors/proxy/proxy_system.py new file mode 100755 index 0000000..ed695a5 --- /dev/null +++ b/deepafx_st/processors/proxy/proxy_system.py @@ -0,0 +1,289 @@ +from re import X +import torch +import auraloss +import pytorch_lightning as pl +from typing import Tuple, List, Dict +from argparse import ArgumentParser + + +import deepafx_st.utils as utils +from deepafx_st.data.proxy import DSPProxyDataset +from deepafx_st.processors.proxy.tcn import ConditionalTCN +from deepafx_st.processors.spsa.channel import SPSAChannel +from deepafx_st.processors.dsp.peq import ParametricEQ +from deepafx_st.processors.dsp.compressor import Compressor + + +class ProxySystem(pl.LightningModule): + def __init__( + self, + causal=True, + nblocks=4, + dilation_growth=8, + kernel_size=13, + channel_width=64, + input_dir=None, + processor="channel", + batch_size=32, + lr=3e-4, + lr_patience=20, + patience=10, + preload=False, + sample_rate=24000, + shuffle=True, + train_length=65536, + train_examples_per_epoch=10000, + val_length=131072, + val_examples_per_epoch=1000, + num_workers=16, + output_gain=False, + **kwargs, + ): + super().__init__() + self.save_hyperparameters() + #print(f"Proxy Processor: {processor} @ fs={sample_rate} Hz") + + # construct both the true DSP... + if self.hparams.processor == "peq": + self.processor = ParametricEQ(self.hparams.sample_rate) + elif self.hparams.processor == "comp": + self.processor = Compressor(self.hparams.sample_rate) + elif self.hparams.processor == "channel": + self.processor = SPSAChannel(self.hparams.sample_rate) + + # and the neural network proxy + self.proxy = ConditionalTCN( + self.hparams.sample_rate, + num_control_params=self.processor.num_control_params, + causal=self.hparams.causal, + nblocks=self.hparams.nblocks, + channel_width=self.hparams.channel_width, + kernel_size=self.hparams.kernel_size, + dilation_growth=self.hparams.dilation_growth, + ) + + self.receptive_field = self.proxy.compute_receptive_field() + + self.recon_losses = {} + self.recon_loss_weights = {} + + self.recon_losses["mrstft"] = auraloss.freq.MultiResolutionSTFTLoss( + fft_sizes=[32, 128, 512, 2048, 8192, 32768], + hop_sizes=[16, 64, 256, 1024, 4096, 16384], + win_lengths=[32, 128, 512, 2048, 8192, 32768], + w_sc=0.0, + w_phs=0.0, + w_lin_mag=1.0, + w_log_mag=1.0, + ) + self.recon_loss_weights["mrstft"] = 1.0 + + self.recon_losses["l1"] = torch.nn.L1Loss() + self.recon_loss_weights["l1"] = 100.0 + + def forward(self, x, p, use_dsp=False, sample_rate=24000, **kwargs): + """Use the pre-trained neural network proxy effect.""" + bs, chs, samp = x.size() + if not use_dsp: + y = self.proxy(x, p) + # manually apply the makeup gain parameter + if self.hparams.output_gain and not self.hparams.processor == "peq": + gain_db = (p[..., -1] * 96) - 48 + gain_ln = 10 ** (gain_db / 20.0) + y *= gain_ln.view(bs, chs, 1) + else: + with torch.no_grad(): + bs, chs, s = x.shape + + if self.hparams.output_gain and not self.hparams.processor == "peq": + # override makeup gain + gain_db = (p[..., -1] * 96) - 48 + gain_ln = 10 ** (gain_db / 20.0) + p[..., -1] = 0.5 + + if self.hparams.processor == "channel": + y_temp = self.processor(x.cpu(), p.cpu()) + y_temp = y_temp.view(bs, chs, s).type_as(x) + else: + y_temp = self.processor( + x.cpu().numpy(), + p.cpu().numpy(), + sample_rate, + ) + y_temp = torch.tensor(y_temp).view(bs, chs, s).type_as(x) + + y = y_temp.type_as(x).view(bs, 1, -1) + + if self.hparams.output_gain and not self.hparams.processor == "peq": + y *= gain_ln.view(bs, chs, 1) + + return y + + def common_step( + self, + batch: Tuple, + batch_idx: int, + optimizer_idx: int = 0, + train: bool = True, + ): + loss = 0 + x, y, p = batch + + y_hat = self(x, p) + + # compute loss + for loss_idx, (loss_name, loss_fn) in enumerate(self.recon_losses.items()): + tmp_loss = loss_fn(y_hat.float(), y.float()) + loss += self.recon_loss_weights[loss_name] * tmp_loss + + self.log( + f"train_loss/{loss_name}" if train else f"val_loss/{loss_name}", + tmp_loss, + on_step=True, + on_epoch=True, + prog_bar=False, + logger=True, + sync_dist=True, + ) + + if not train: + # store audio data + data_dict = { + "x": x.float().cpu(), + "y": y.float().cpu(), + "p": p.float().cpu(), + "y_hat": y_hat.float().cpu(), + } + else: + data_dict = {} + + self.log( + "train_loss" if train else "val_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=False, + logger=True, + sync_dist=True, + ) + + return loss, data_dict + + def training_step(self, batch, batch_idx, optimizer_idx=0): + loss, _ = self.common_step(batch, batch_idx) + return loss + + def validation_step(self, batch, batch_idx): + loss, data_dict = self.common_step(batch, batch_idx, train=False) + + if batch_idx == 0: + return data_dict + + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.proxy.parameters(), + lr=self.hparams.lr, + betas=(0.9, 0.999), + ) + + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + patience=self.hparams.lr_patience, + verbose=True, + ) + + return [optimizer], {"scheduler": scheduler, "monitor": "val_loss"} + + def train_dataloader(self): + + train_dataset = DSPProxyDataset( + self.hparams.input_dir, + self.processor, + self.hparams.processor, # name + subset="train", + length=self.hparams.train_length, + num_examples_per_epoch=self.hparams.train_examples_per_epoch, + half=True if self.hparams.precision == 16 else False, + buffer_size_gb=self.hparams.buffer_size_gb, + buffer_reload_rate=self.hparams.buffer_reload_rate, + ) + + g = torch.Generator() + g.manual_seed(0) + + return torch.utils.data.DataLoader( + train_dataset, + num_workers=self.hparams.num_workers, + batch_size=self.hparams.batch_size, + worker_init_fn=utils.seed_worker, + generator=g, + pin_memory=True, + ) + + def val_dataloader(self): + + val_dataset = DSPProxyDataset( + self.hparams.input_dir, + self.processor, + self.hparams.processor, # name + subset="val", + length=self.hparams.val_length, + num_examples_per_epoch=self.hparams.val_examples_per_epoch, + half=True if self.hparams.precision == 16 else False, + buffer_size_gb=self.hparams.buffer_size_gb, + buffer_reload_rate=self.hparams.buffer_reload_rate, + ) + + g = torch.Generator() + g.manual_seed(0) + + return torch.utils.data.DataLoader( + val_dataset, + num_workers=self.hparams.num_workers, + batch_size=self.hparams.batch_size, + worker_init_fn=utils.seed_worker, + generator=g, + pin_memory=True, + ) + + @staticmethod + def count_control_params(plugin_config): + num_control_params = 0 + + for plugin in plugin_config["plugins"]: + for port in plugin["ports"]: + if port["optim"]: + num_control_params += 1 + + return num_control_params + + # add any model hyperparameters here + @staticmethod + def add_model_specific_args(parent_parser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) + # --- Model --- + parser.add_argument("--causal", action="store_true") + parser.add_argument("--output_gain", action="store_true") + parser.add_argument("--dilation_growth", type=int, default=8) + parser.add_argument("--nblocks", type=int, default=4) + parser.add_argument("--kernel_size", type=int, default=13) + parser.add_argument("--channel_width", type=int, default=13) + # --- Training --- + parser.add_argument("--input_dir", type=str) + parser.add_argument("--processor", type=str) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--lr", type=float, default=3e-4) + parser.add_argument("--lr_patience", type=int, default=20) + parser.add_argument("--patience", type=int, default=10) + parser.add_argument("--preload", action="store_true") + parser.add_argument("--sample_rate", type=int, default=24000) + parser.add_argument("--shuffle", type=bool, default=True) + parser.add_argument("--train_length", type=int, default=65536) + parser.add_argument("--train_examples_per_epoch", type=int, default=10000) + parser.add_argument("--val_length", type=int, default=131072) + parser.add_argument("--val_examples_per_epoch", type=int, default=1000) + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--buffer_reload_rate", type=int, default=1000) + parser.add_argument("--buffer_size_gb", type=float, default=1.0) + + return parser diff --git a/deepafx_st/processors/proxy/tcn.py b/deepafx_st/processors/proxy/tcn.py new file mode 100755 index 0000000..a7e0004 --- /dev/null +++ b/deepafx_st/processors/proxy/tcn.py @@ -0,0 +1,199 @@ +# Copyright 2022 Christian J. Steinmetz + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TCN implementation adapted from: +# https://github.com/csteinmetz1/micro-tcn/blob/main/microtcn/tcn.py + +import torch +from argparse import ArgumentParser + +from deepafx_st.utils import center_crop, causal_crop + + +class FiLM(torch.nn.Module): + def __init__(self, num_features, cond_dim): + super().__init__() + self.num_features = num_features + self.bn = torch.nn.BatchNorm1d(num_features, affine=False) + self.adaptor = torch.nn.Linear(cond_dim, num_features * 2) + + def forward(self, x, cond): + + # project conditioning to 2 x num. conv channels + cond = self.adaptor(cond) + + # split the projection into gain and bias + g, b = torch.chunk(cond, 2, dim=-1) + + # add virtual channel dim if needed + if g.ndim == 2: + g = g.unsqueeze(1) + b = b.unsqueeze(1) + + # reshape for application + g = g.permute(0, 2, 1) + b = b.permute(0, 2, 1) + + x = self.bn(x) # apply BatchNorm without affine + x = (x * g) + b # then apply conditional affine + + return x + + +class ConditionalTCNBlock(torch.nn.Module): + def __init__( + self, in_ch, out_ch, cond_dim, kernel_size=3, dilation=1, causal=False, **kwargs + ): + super().__init__() + + self.in_ch = in_ch + self.out_ch = out_ch + self.kernel_size = kernel_size + self.dilation = dilation + self.causal = causal + + self.conv1 = torch.nn.Conv1d( + in_ch, + out_ch, + kernel_size=kernel_size, + padding=0, + dilation=dilation, + bias=True, + ) + self.film = FiLM(out_ch, cond_dim) + self.relu = torch.nn.PReLU(out_ch) + self.res = torch.nn.Conv1d( + in_ch, out_ch, kernel_size=1, groups=in_ch, bias=False + ) + + def forward(self, x, p): + x_in = x + + x = self.conv1(x) + x = self.film(x, p) # apply FiLM conditioning + x = self.relu(x) + x_res = self.res(x_in) + + if self.causal: + x = x + causal_crop(x_res, x.shape[-1]) + else: + x = x + center_crop(x_res, x.shape[-1]) + + return x + + +class ConditionalTCN(torch.nn.Module): + """Temporal convolutional network with conditioning module. + Args: + sample_rate (float): Audio sample rate. + num_control_params (int, optional): Dimensionality of the conditioning signal. Default: 24 + ninputs (int, optional): Number of input channels (mono = 1, stereo 2). Default: 1 + noutputs (int, optional): Number of output channels (mono = 1, stereo 2). Default: 1 + nblocks (int, optional): Number of total TCN blocks. Default: 10 + kernel_size (int, optional: Width of the convolutional kernels. Default: 3 + dialation_growth (int, optional): Compute the dilation factor at each block as dilation_growth ** (n % stack_size). Default: 1 + channel_growth (int, optional): Compute the output channels at each black as in_ch * channel_growth. Default: 2 + channel_width (int, optional): When channel_growth = 1 all blocks use convolutions with this many channels. Default: 64 + stack_size (int, optional): Number of blocks that constitute a single stack of blocks. Default: 10 + causal (bool, optional): Causal TCN configuration does not consider future input values. Default: False + """ + + def __init__( + self, + sample_rate, + num_control_params=24, + ninputs=1, + noutputs=1, + nblocks=10, + kernel_size=15, + dilation_growth=2, + channel_growth=1, + channel_width=64, + stack_size=10, + causal=False, + skip_connections=False, + **kwargs, + ): + super().__init__() + self.num_control_params = num_control_params + self.ninputs = ninputs + self.noutputs = noutputs + self.nblocks = nblocks + self.kernel_size = kernel_size + self.dilation_growth = dilation_growth + self.channel_growth = channel_growth + self.channel_width = channel_width + self.stack_size = stack_size + self.causal = causal + self.skip_connections = skip_connections + self.sample_rate = sample_rate + + self.blocks = torch.nn.ModuleList() + for n in range(nblocks): + in_ch = out_ch if n > 0 else ninputs + + if self.channel_growth > 1: + out_ch = in_ch * self.channel_growth + else: + out_ch = self.channel_width + + dilation = self.dilation_growth ** (n % self.stack_size) + + self.blocks.append( + ConditionalTCNBlock( + in_ch, + out_ch, + self.num_control_params, + kernel_size=self.kernel_size, + dilation=dilation, + padding="same" if self.causal else "valid", + causal=self.causal, + ) + ) + + self.output = torch.nn.Conv1d(out_ch, noutputs, kernel_size=1) + self.receptive_field = self.compute_receptive_field() + # print( + # f"TCN receptive field: {self.receptive_field} samples", + # f" or {(self.receptive_field/self.sample_rate)*1e3:0.3f} ms", + # ) + + def forward(self, x, p, **kwargs): + + # causally pad input signal + x = torch.nn.functional.pad(x, (self.receptive_field - 1, 0)) + + # iterate over blocks passing conditioning + for idx, block in enumerate(self.blocks): + x = block(x, p) + if self.skip_connections: + if idx == 0: + skips = x + else: + skips = center_crop(skips, x[-1]) + x + else: + skips = 0 + + # final 1x1 convolution to collapse channels + out = self.output(x + skips) + + return out + + def compute_receptive_field(self): + """Compute the receptive field in samples.""" + rf = self.kernel_size + for n in range(1, self.nblocks): + dilation = self.dilation_growth ** (n % self.stack_size) + rf = rf + ((self.kernel_size - 1) * dilation) + return rf diff --git a/deepafx_st/processors/spsa/channel.py b/deepafx_st/processors/spsa/channel.py new file mode 100755 index 0000000..3595af7 --- /dev/null +++ b/deepafx_st/processors/spsa/channel.py @@ -0,0 +1,179 @@ +import torch +import numpy as np +import torch.multiprocessing as mp + +from deepafx_st.processors.dsp.peq import ParametricEQ +from deepafx_st.processors.dsp.compressor import Compressor +from deepafx_st.processors.spsa.spsa_func import SPSAFunction +from deepafx_st.utils import rademacher + + +def dsp_func(x, p, dsp, sample_rate=24000): + + (peq, comp), meta = dsp + + p_peq = p[:meta] + p_comp = p[meta:] + + y = peq(x, p_peq, sample_rate) + y = comp(y, p_comp, sample_rate) + + return y + + +class SPSAChannel(torch.nn.Module): + """ + + Args: + sample_rate (float): Sample rate of the plugin instance + parallel (bool, optional): Use parallel workers for DSP. + + By default, this utilizes parallelized instances of the plugin channel, + where the number of workers is equal to the batch size. + """ + + def __init__( + self, + sample_rate: int, + parallel: bool = False, + batch_size: int = 8, + ): + super().__init__() + + self.batch_size = batch_size + self.parallel = parallel + + if self.parallel: + self.apply_func = SPSAFunction.apply + + procs = {} + for b in range(self.batch_size): + + peq = ParametricEQ(sample_rate) + comp = Compressor(sample_rate) + dsp = ((peq, comp), peq.num_control_params) + + parent_conn, child_conn = mp.Pipe() + p = mp.Process(target=SPSAChannel.worker_pipe, args=(child_conn, dsp)) + p.start() + procs[b] = [p, parent_conn, child_conn] + #print(b, p) + + # Update stuff for external public members TODO: fix + self.ports = [peq.ports, comp.ports] + self.num_control_params = ( + comp.num_control_params + peq.num_control_params + ) + + self.procs = procs + #print(self.procs) + + else: + self.peq = ParametricEQ(sample_rate) + self.comp = Compressor(sample_rate) + self.apply_func = SPSAFunction.apply + self.ports = [self.peq.ports, self.comp.ports] + self.num_control_params = ( + self.comp.num_control_params + self.peq.num_control_params + ) + self.dsp = ((self.peq, self.comp), self.peq.num_control_params) + + # add one param for wet/dry mix + # self.num_control_params += 1 + + def __del__(self): + if hasattr(self, "procs"): + for proc_idx, proc in self.procs.items(): + #print(f"Closing {proc_idx}...") + proc[0].terminate() + + def forward(self, x, p, epsilon=0.001, sample_rate=24000, **kwargs): + """ + Args: + x (Tensor): Input signal with shape: [batch x channels x samples] + p (Tensor): Audio effect control parameters with shape: [batch x parameters] + epsilon (float, optional): Twiddle parameter range for SPSA gradient estimation. + + Returns: + y (Tensor): Processed audio signal. + + """ + if self.parallel: + y = self.apply_func(x, p, None, epsilon, self, sample_rate) + + else: + # this will process on CPU in NumPy + y = self.apply_func(x, p, None, epsilon, self, sample_rate) + + return y.type_as(x) + + @staticmethod + def static_backward(dsp, value): + + ( + batch_index, + x, + params, + needs_input_grad, + needs_param_grad, + grad_output, + epsilon, + ) = value + + grads_input = None + grads_params = None + ps = params.shape[-1] + factors = [1.0] + + # estimate gradient w.r.t input + if needs_input_grad: + delta_k = rademacher(x.shape).numpy() + J_plus = dsp_func(x + epsilon * delta_k, params, dsp) + J_minus = dsp_func(x - epsilon * delta_k, params, dsp) + grads_input = (J_plus - J_minus) / (2.0 * epsilon) + + # estimate gradient w.r.t params + grads_params_runs = [] + if needs_param_grad: + for factor in factors: + params_sublist = [] + delta_k = rademacher(params.shape).numpy() + + # compute output in two random directions of the parameter space + params_plus = np.clip(params + (factor * epsilon * delta_k), 0, 1) + J_plus = dsp_func(x, params_plus, dsp) + + params_minus = np.clip(params - (factor * epsilon * delta_k), 0, 1) + J_minus = dsp_func(x, params_minus, dsp) + grad_param = J_plus - J_minus + + # compute gradient for each parameter as a function of epsilon and random direction + for sub_p_idx in range(ps): + grad_p = grad_param / (2 * epsilon * delta_k[sub_p_idx]) + params_sublist.append(np.sum(grad_output * grad_p)) + + grads_params = np.array(params_sublist) + grads_params_runs.append(grads_params) + + # average gradients + grads_params = np.mean(grads_params_runs, axis=0) + + return grads_input, grads_params + + @staticmethod + def static_forward(dsp, value): + batch_index, x, p, sample_rate = value + y = dsp_func(x, p, dsp, sample_rate) + return y + + @staticmethod + def worker_pipe(child_conn, dsp): + + while True: + msg, value = child_conn.recv() + if msg == "forward": + child_conn.send(SPSAChannel.static_forward(dsp, value)) + elif msg == "backward": + child_conn.send(SPSAChannel.static_backward(dsp, value)) + elif msg == "shutdown": + break diff --git a/deepafx_st/processors/spsa/eps_scheduler.py b/deepafx_st/processors/spsa/eps_scheduler.py new file mode 100755 index 0000000..abcee22 --- /dev/null +++ b/deepafx_st/processors/spsa/eps_scheduler.py @@ -0,0 +1,32 @@ +import torch + + +class EpsilonScheduler: + def __init__( + self, + epsilon: float = 0.001, + patience: int = 10, + factor: float = 0.5, + verbose: bool = False, + ): + self.epsilon = epsilon + self.patience = patience + self.factor = factor + self.best = 1e16 + self.count = 0 + self.verbose = verbose + + def step(self, metric: float): + + if metric < self.best: + self.best = metric + self.count = 0 + else: + self.count += 1 + if self.verbose: + print(f"Train loss has not improved for {self.count} epochs.") + if self.count >= self.patience: + self.count = 0 + self.epsilon *= self.factor + if self.verbose: + print(f"Reducing epsilon to {self.epsilon:0.2e}...") diff --git a/deepafx_st/processors/spsa/spsa_func.py b/deepafx_st/processors/spsa/spsa_func.py new file mode 100755 index 0000000..a657789 --- /dev/null +++ b/deepafx_st/processors/spsa/spsa_func.py @@ -0,0 +1,131 @@ +import torch + + +def spsa_func(input, params, process, i, sample_rate=24000): + return process(input.cpu(), params.cpu(), i, sample_rate).type_as(input) + + +class SPSAFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + input, + params, + process, + epsilon, + thread_context, + sample_rate=24000, + ): + """Apply processor to a batch of tensors using given parameters. + + Args: + input (Tensor): Audio with shape: batch x 2 x samples + params (Tensor): Processor parameters with shape: batch x params + process (function): Function that will apply processing. + epsilon (float): Perturbation strength for SPSA computation. + + Returns: + output (Tensor): Processed audio with same shape as input. + """ + ctx.save_for_backward(input, params) + ctx.epsilon = epsilon + ctx.process = process + ctx.thread_context = thread_context + + if thread_context.parallel: + + for i in range(input.shape[0]): + msg = ( + "forward", + ( + i, + input[i].view(-1).detach().cpu().numpy(), + params[i].view(-1).detach().cpu().numpy(), + sample_rate, + ), + ) + thread_context.procs[i][1].send(msg) + + z = torch.empty_like(input) + for i in range(input.shape[0]): + z[i] = torch.from_numpy(thread_context.procs[i][1].recv()) + else: + z = torch.empty_like(input) + for i in range(input.shape[0]): + value = ( + i, + input[i].view(-1).detach().cpu().numpy(), + params[i].view(-1).detach().cpu().numpy(), + sample_rate, + ) + z[i] = torch.from_numpy( + thread_context.static_forward(thread_context.dsp, value) + ) + + return z + + @staticmethod + def backward(ctx, grad_output): + """Estimate gradients using SPSA.""" + + input, params = ctx.saved_tensors + epsilon = ctx.epsilon + needs_input_grad = ctx.needs_input_grad[0] + needs_param_grad = ctx.needs_input_grad[1] + thread_context = ctx.thread_context + + grads_input = None + grads_params = None + + # Receive grads + if needs_input_grad: + grads_input = torch.empty_like(input) + if needs_param_grad: + grads_params = torch.empty_like(params) + + if thread_context.parallel: + + for i in range(input.shape[0]): + msg = ( + "backward", + ( + i, + input[i].view(-1).detach().cpu().numpy(), + params[i].view(-1).detach().cpu().numpy(), + needs_input_grad, + needs_param_grad, + grad_output[i].view(-1).detach().cpu().numpy(), + epsilon, + ), + ) + thread_context.procs[i][1].send(msg) + + # Wait for output + for i in range(input.shape[0]): + temp1, temp2 = thread_context.procs[i][1].recv() + + if temp1 is not None: + grads_input[i] = torch.from_numpy(temp1) + + if temp2 is not None: + grads_params[i] = torch.from_numpy(temp2) + + return grads_input, grads_params, None, None, None, None + else: + for i in range(input.shape[0]): + value = ( + i, + input[i].view(-1).detach().cpu().numpy(), + params[i].view(-1).detach().cpu().numpy(), + needs_input_grad, + needs_param_grad, + grad_output[i].view(-1).detach().cpu().numpy(), + epsilon, + ) + temp1, temp2 = thread_context.static_backward(thread_context.dsp, value) + if temp1 is not None: + grads_input[i] = torch.from_numpy(temp1) + + if temp2 is not None: + grads_params[i] = torch.from_numpy(temp2) + return grads_input, grads_params, None, None, None, None diff --git a/deepafx_st/system.py b/deepafx_st/system.py new file mode 100755 index 0000000..449afa5 --- /dev/null +++ b/deepafx_st/system.py @@ -0,0 +1,563 @@ +import torch +import auraloss +import torchaudio +from itertools import chain +import pytorch_lightning as pl +from argparse import ArgumentParser +from typing import Tuple, List, Dict + +import deepafx_st.utils as utils +from deepafx_st.utils import DSPMode +from deepafx_st.data.dataset import AudioDataset +from deepafx_st.models.encoder import SpectralEncoder +from deepafx_st.models.controller import StyleTransferController +from deepafx_st.processors.spsa.channel import SPSAChannel +from deepafx_st.processors.spsa.eps_scheduler import EpsilonScheduler +from deepafx_st.processors.proxy.channel import ProxyChannel +from deepafx_st.processors.autodiff.channel import AutodiffChannel + + +class System(pl.LightningModule): + def __init__( + self, + ext="wav", + dsp_sample_rate=24000, + **kwargs, + ): + super().__init__() + self.save_hyperparameters() + + self.eps_scheduler = EpsilonScheduler( + self.hparams.spsa_epsilon, + self.hparams.spsa_patience, + self.hparams.spsa_factor, + self.hparams.spsa_verbose, + ) + + self.hparams.dsp_mode = DSPMode.NONE + + # first construct the processor, since this will dictate encoder + if self.hparams.processor_model == "spsa": + self.processor = SPSAChannel( + self.hparams.dsp_sample_rate, + self.hparams.spsa_parallel, + self.hparams.batch_size, + ) + elif self.hparams.processor_model == "autodiff": + self.processor = AutodiffChannel(self.hparams.dsp_sample_rate) + elif self.hparams.processor_model == "proxy0": + # print('self.hparams.proxy_ckpts,',self.hparams.proxy_ckpts) + self.hparams.dsp_mode = DSPMode.NONE + self.processor = ProxyChannel( + self.hparams.proxy_ckpts, + self.hparams.freeze_proxies, + self.hparams.dsp_mode, + sample_rate=self.hparams.dsp_sample_rate, + ) + elif self.hparams.processor_model == "proxy1": + # print('self.hparams.proxy_ckpts,',self.hparams.proxy_ckpts) + self.hparams.dsp_mode = DSPMode.INFER + self.processor = ProxyChannel( + self.hparams.proxy_ckpts, + self.hparams.freeze_proxies, + self.hparams.dsp_mode, + sample_rate=self.hparams.dsp_sample_rate, + ) + elif self.hparams.processor_model == "proxy2": + # print('self.hparams.proxy_ckpts,',self.hparams.proxy_ckpts) + self.hparams.dsp_mode = DSPMode.TRAIN_INFER + self.processor = ProxyChannel( + self.hparams.proxy_ckpts, + self.hparams.freeze_proxies, + self.hparams.dsp_mode, + sample_rate=self.hparams.dsp_sample_rate, + ) + elif self.hparams.processor_model == "tcn1": + # self.processor = ConditionalTCN(self.hparams.sample_rate) + self.hparams.dsp_mode = DSPMode.NONE + self.processor = ProxyChannel( + [], + freeze_proxies=False, + dsp_mode=self.hparams.dsp_mode, + tcn_nblocks=self.hparams.tcn_nblocks, + tcn_dilation_growth=self.hparams.tcn_dilation_growth, + tcn_channel_width=self.hparams.tcn_channel_width, + tcn_kernel_size=self.hparams.tcn_kernel_size, + num_tcns=1, + sample_rate=self.hparams.sample_rate, + ) + elif self.hparams.processor_model == "tcn2": + self.hparams.dsp_mode = DSPMode.NONE + self.processor = ProxyChannel( + [], + freeze_proxies=False, + dsp_mode=self.hparams.dsp_mode, + tcn_nblocks=self.hparams.tcn_nblocks, + tcn_dilation_growth=self.hparams.tcn_dilation_growth, + tcn_channel_width=self.hparams.tcn_channel_width, + tcn_kernel_size=self.hparams.tcn_kernel_size, + num_tcns=2, + sample_rate=self.hparams.sample_rate, + ) + else: + raise ValueError(f"Invalid processor_model: {self.hparams.processor_model}") + + if self.hparams.encoder_ckpt is not None: + # load encoder weights from a pre-trained system + system = System.load_from_checkpoint(self.hparams.encoder_ckpt) + self.encoder = system.encoder + self.hparams.encoder_embed_dim = system.encoder.embed_dim + else: + self.encoder = SpectralEncoder( + self.processor.num_control_params, + self.hparams.sample_rate, + encoder_model=self.hparams.encoder_model, + embed_dim=self.hparams.encoder_embed_dim, + width_mult=self.hparams.encoder_width_mult, + ) + + if self.hparams.encoder_freeze: + for param in self.encoder.parameters(): + param.requires_grad = False + + self.controller = StyleTransferController( + self.processor.num_control_params, + self.hparams.encoder_embed_dim, + ) + + if len(self.hparams.recon_losses) != len(self.hparams.recon_loss_weights): + raise ValueError("Must supply same number of weights as losses.") + + self.recon_losses = torch.nn.ModuleDict() + for recon_loss in self.hparams.recon_losses: + if recon_loss == "mrstft": + self.recon_losses[recon_loss] = auraloss.freq.MultiResolutionSTFTLoss( + fft_sizes=[32, 128, 512, 2048, 8192, 32768], + hop_sizes=[16, 64, 256, 1024, 4096, 16384], + win_lengths=[32, 128, 512, 2048, 8192, 32768], + w_sc=0.0, + w_phs=0.0, + w_lin_mag=1.0, + w_log_mag=1.0, + ) + elif recon_loss == "mrstft-md": + self.recon_losses[recon_loss] = auraloss.freq.MultiResolutionSTFTLoss( + fft_sizes=[128, 512, 2048, 8192], + hop_sizes=[32, 128, 512, 2048], # 1 / 4 + win_lengths=[128, 512, 2048, 8192], + w_sc=0.0, + w_phs=0.0, + w_lin_mag=1.0, + w_log_mag=1.0, + ) + elif recon_loss == "mrstft-sm": + self.recon_losses[recon_loss] = auraloss.freq.MultiResolutionSTFTLoss( + fft_sizes=[512, 2048, 8192], + hop_sizes=[256, 1024, 4096], # 1 / 4 + win_lengths=[512, 2048, 8192], + w_sc=0.0, + w_phs=0.0, + w_lin_mag=1.0, + w_log_mag=1.0, + ) + elif recon_loss == "melfft": + self.recon_losses[recon_loss] = auraloss.freq.MelSTFTLoss( + self.hparams.sample_rate, + fft_size=self.hparams.train_length, + hop_size=self.hparams.train_length // 2, + win_length=self.hparams.train_length, + n_mels=128, + w_sc=0.0, + device="cuda" if self.hparams.gpus > 0 else "cpu", + ) + elif recon_loss == "melstft": + self.recon_losses[recon_loss] = auraloss.freq.MelSTFTLoss( + self.hparams.sample_rate, + device="cuda" if self.hparams.gpus > 0 else "cpu", + ) + elif recon_loss == "l1": + self.recon_losses[recon_loss] = torch.nn.L1Loss() + elif recon_loss == "sisdr": + self.recon_losses[recon_loss] = auraloss.time.SISDRLoss() + else: + raise ValueError( + f"Invalid reconstruction loss: {self.hparams.recon_losses}" + ) + + def forward( + self, + x: torch.Tensor, + y: torch.Tensor = None, + e_y: torch.Tensor = None, + z: torch.Tensor = None, + dsp_mode: DSPMode = DSPMode.NONE, + analysis_length: int = 0, + sample_rate: int = 24000, + ): + """Forward pass through the system subnetworks. + + Args: + x (tensor): Input audio tensor with shape (batch x 1 x samples) + y (tensor): Target audio tensor with shape (batch x 1 x samples) + e_y (tensor): Target embedding with shape (batch x edim) + z (tensor): Bottleneck latent. + dsp_mode (DSPMode): Mode of operation for the DSP blocks. + analysis_length (optional, int): Only analyze the first N samples. + sample_rate (optional, int): Desired sampling rate for the DSP blocks. + + You must supply target audio `y`, `z`, or an embedding for the target `e_y`. + + Returns: + y_hat (tensor): Output audio. + p (tensor): + e (tensor): + + """ + bs, chs, samp = x.size() + + if sample_rate != self.hparams.sample_rate: + x_enc = torchaudio.transforms.Resample( + sample_rate, self.hparams.sample_rate + ).to(x.device)(x) + if y is not None: + y_enc = torchaudio.transforms.Resample( + sample_rate, self.hparams.sample_rate + ).to(x.device)(y) + else: + x_enc = x + y_enc = y + + if analysis_length > 0: + x_enc = x_enc[..., :analysis_length] + if y is not None: + y_enc = y_enc[..., :analysis_length] + + e_x = self.encoder(x_enc) # generate latent embedding for input + + if y is not None: + e_y = self.encoder(y_enc) # generate latent embedding for target + elif e_y is None: + raise RuntimeError("Must supply y, z, or e_y. None supplied.") + + # learnable comparision + p = self.controller(e_x, e_y, z=z) + + # process audio conditioned on parameters + # if there are multiple channels process them using same parameters + y_hat = torch.zeros(x.shape).type_as(x) + for ch_idx in range(chs): + y_hat_ch = self.processor( + x[:, ch_idx : ch_idx + 1, :], + p, + epsilon=self.eps_scheduler.epsilon, + dsp_mode=dsp_mode, + sample_rate=sample_rate, + ) + y_hat[:, ch_idx : ch_idx + 1, :] = y_hat_ch + + return y_hat, p, e_x + + def common_paired_step( + self, + batch: Tuple, + batch_idx: int, + optimizer_idx: int = 0, + train: bool = False, + ): + """Model step used for validation and training. + + Args: + batch (Tuple[Tensor, Tensor]): Batch items containing input audio (x) and target audio (y). + batch_idx (int): Index of the batch within the current epoch. + optimizer_idx (int): Index of the optimizer, this step is called once for each optimizer. + The firs optimizer corresponds to the generator and the second optimizer, + corresponds to the adversarial loss (when in use). + train (bool): Whether step is called during training (True) or validation (False). + """ + x, y = batch + loss = 0 + dsp_mode = self.hparams.dsp_mode + + if train and dsp_mode.INFER.name == DSPMode.INFER.name: + dsp_mode = DSPMode.NONE + + # proces input audio through model + if self.hparams.style_transfer: + length = x.shape[-1] + + x_A = x[..., : length // 2] + x_B = x[..., length // 2 :] + + y_A = y[..., : length // 2] + y_B = y[..., length // 2 :] + + if torch.rand(1).sum() > 0.5: + y_ref = y_B + y = y_A + x = x_A + else: + y_ref = y_A + y = y_B + x = x_B + + y_hat, p, e = self(x, y=y_ref, dsp_mode=dsp_mode) + else: + y_ref = None + y_hat, p, e = self(x, dsp_mode=dsp_mode) + + # compute reconstruction loss terms + for loss_idx, (loss_name, recon_loss_fn) in enumerate( + self.recon_losses.items() + ): + temp_loss = recon_loss_fn(y_hat, y) # reconstruction loss + loss += float(self.hparams.recon_loss_weights[loss_idx]) * temp_loss + + self.log( + ("train" if train else "val") + f"_loss/{loss_name}", + temp_loss, + on_step=True, + on_epoch=True, + prog_bar=False, + logger=True, + sync_dist=True, + ) + + # log the overall aggregate loss + self.log( + ("train" if train else "val") + "_loss/loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=False, + logger=True, + sync_dist=True, + ) + + # store audio data + data_dict = { + "x": x.cpu(), + "y": y.cpu(), + "p": p.cpu(), + "e": e.cpu(), + "y_hat": y_hat.cpu(), + } + + if y_ref is not None: + data_dict["y_ref"] = y_ref.cpu() + + return loss, data_dict + + def training_step(self, batch, batch_idx, optimizer_idx=0): + loss, _ = self.common_paired_step( + batch, + batch_idx, + optimizer_idx, + train=True, + ) + + return loss + + def training_epoch_end(self, training_step_outputs): + if self.hparams.spsa_schedule and self.hparams.processor_model == "spsa": + self.eps_scheduler.step( + self.trainer.callback_metrics[self.hparams.train_monitor], + ) + + def validation_step(self, batch, batch_idx): + loss, data_dict = self.common_paired_step(batch, batch_idx) + + return data_dict + + def optimizer_step( + self, + epoch, + batch_idx, + optimizer, + optimizer_idx, + optimizer_closure, + on_tpu=False, + using_native_amp=False, + using_lbfgs=False, + ): + if optimizer_idx == 0: + optimizer.step(closure=optimizer_closure) + + def configure_optimizers(self): + # we need additional optimizer for the discriminator + optimizers = [] + g_optimizer = torch.optim.Adam( + chain( + self.encoder.parameters(), + self.processor.parameters(), + self.controller.parameters(), + ), + lr=self.hparams.lr, + betas=(0.9, 0.999), + ) + optimizers.append(g_optimizer) + + g_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + g_optimizer, + patience=self.hparams.lr_patience, + verbose=True, + ) + ms1 = int(self.hparams.max_epochs * 0.8) + ms2 = int(self.hparams.max_epochs * 0.95) + print( + "Learning rate schedule:", + f"0 {self.hparams.lr:0.2e} -> ", + f"{ms1} {self.hparams.lr*0.1:0.2e} -> ", + f"{ms2} {self.hparams.lr*0.01:0.2e}", + ) + g_scheduler = torch.optim.lr_scheduler.MultiStepLR( + g_optimizer, + milestones=[ms1, ms2], + gamma=0.1, + ) + + lr_schedulers = { + "scheduler": g_scheduler, + } + + return optimizers, lr_schedulers + + def train_dataloader(self): + + train_dataset = AudioDataset( + self.hparams.audio_dir, + subset="train", + train_frac=self.hparams.train_frac, + half=self.hparams.half, + length=self.hparams.train_length, + input_dirs=self.hparams.input_dirs, + random_scale_input=self.hparams.random_scale_input, + random_scale_target=self.hparams.random_scale_target, + buffer_size_gb=self.hparams.buffer_size_gb, + buffer_reload_rate=self.hparams.buffer_reload_rate, + num_examples_per_epoch=self.hparams.train_examples_per_epoch, + augmentations={ + "pitch": {"sr": self.hparams.sample_rate}, + "tempo": {"sr": self.hparams.sample_rate}, + }, + freq_corrupt=self.hparams.freq_corrupt, + drc_corrupt=self.hparams.drc_corrupt, + ext=self.hparams.ext, + ) + + g = torch.Generator() + g.manual_seed(0) + + return torch.utils.data.DataLoader( + train_dataset, + num_workers=self.hparams.num_workers, + batch_size=self.hparams.batch_size, + worker_init_fn=utils.seed_worker, + generator=g, + pin_memory=True, + persistent_workers=True, + timeout=60, + ) + + def val_dataloader(self): + + val_dataset = AudioDataset( + self.hparams.audio_dir, + subset="val", + half=self.hparams.half, + train_frac=self.hparams.train_frac, + length=self.hparams.val_length, + input_dirs=self.hparams.input_dirs, + buffer_size_gb=self.hparams.buffer_size_gb, + buffer_reload_rate=self.hparams.buffer_reload_rate, + random_scale_input=self.hparams.random_scale_input, + random_scale_target=self.hparams.random_scale_target, + num_examples_per_epoch=self.hparams.val_examples_per_epoch, + augmentations={}, + freq_corrupt=self.hparams.freq_corrupt, + drc_corrupt=self.hparams.drc_corrupt, + ext=self.hparams.ext, + ) + + self.val_dataset = val_dataset + + g = torch.Generator() + g.manual_seed(0) + + return torch.utils.data.DataLoader( + val_dataset, + num_workers=1, + batch_size=self.hparams.batch_size, + worker_init_fn=utils.seed_worker, + generator=g, + pin_memory=True, + persistent_workers=True, + timeout=60, + ) + def shutdown(self): + del self.processor + + # add any model hyperparameters here + @staticmethod + def add_model_specific_args(parent_parser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) + # --- Training --- + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--lr", type=float, default=3e-4) + parser.add_argument("--lr_patience", type=int, default=20) + parser.add_argument("--recon_losses", nargs="+", default=["l1"]) + parser.add_argument("--recon_loss_weights", nargs="+", default=[1.0]) + # --- Controller --- + parser.add_argument( + "--processor_model", + type=str, + help="autodiff, spsa, tcn1, tcn2, proxy0, proxy1, proxy2", + ) + parser.add_argument("--controller_hidden_dim", type=int, default=256) + parser.add_argument("--style_transfer", action="store_true") + # --- Encoder --- + parser.add_argument("--encoder_model", type=str, default="mobilenet_v2") + parser.add_argument("--encoder_embed_dim", type=int, default=128) + parser.add_argument("--encoder_width_mult", type=int, default=2) + parser.add_argument("--encoder_ckpt", type=str, default=None) + parser.add_argument("--encoder_freeze", action="store_true", default=False) + # --- TCN --- + parser.add_argument("--tcn_causal", action="store_true") + parser.add_argument("--tcn_nblocks", type=int, default=4) + parser.add_argument("--tcn_dilation_growth", type=int, default=8) + parser.add_argument("--tcn_channel_width", type=int, default=32) + parser.add_argument("--tcn_kernel_size", type=int, default=13) + # --- SPSA --- + parser.add_argument("--plugin_config_file", type=str, default=None) + parser.add_argument("--spsa_epsilon", type=float, default=0.001) + parser.add_argument("--spsa_schedule", action="store_true") + parser.add_argument("--spsa_patience", type=int, default=10) + parser.add_argument("--spsa_verbose", action="store_true") + parser.add_argument("--spsa_factor", type=float, default=0.5) + parser.add_argument("--spsa_parallel", action="store_true") + # --- Proxy ---- + parser.add_argument("--proxy_ckpts", nargs="+") + parser.add_argument("--freeze_proxies", action="store_true", default=False) + parser.add_argument("--use_dsp", action="store_true", default=False) + parser.add_argument("--dsp_mode", choices=DSPMode, type=DSPMode) + # --- Dataset --- + parser.add_argument("--audio_dir", type=str) + parser.add_argument("--ext", type=str, default="wav") + parser.add_argument("--input_dirs", nargs="+") + parser.add_argument("--buffer_reload_rate", type=int, default=1000) + parser.add_argument("--buffer_size_gb", type=float, default=1.0) + parser.add_argument("--sample_rate", type=int, default=24000) + parser.add_argument("--dsp_sample_rate", type=int, default=24000) + parser.add_argument("--shuffle", type=bool, default=True) + parser.add_argument("--random_scale_input", action="store_true") + parser.add_argument("--random_scale_target", action="store_true") + parser.add_argument("--freq_corrupt", action="store_true") + parser.add_argument("--drc_corrupt", action="store_true") + parser.add_argument("--train_length", type=int, default=65536) + parser.add_argument("--train_frac", type=float, default=0.8) + parser.add_argument("--half", action="store_true") + parser.add_argument("--train_examples_per_epoch", type=int, default=10000) + parser.add_argument("--val_length", type=int, default=131072) + parser.add_argument("--val_examples_per_epoch", type=int, default=1000) + parser.add_argument("--num_workers", type=int, default=16) + + return parser diff --git a/deepafx_st/utils.py b/deepafx_st/utils.py new file mode 100755 index 0000000..d60acb1 --- /dev/null +++ b/deepafx_st/utils.py @@ -0,0 +1,277 @@ +# Adapted from: +# https://github.com/csteinmetz1/micro-tcn/blob/main/microtcn/utils.py +import os +import csv +import torch +import fnmatch +import numpy as np +import random +from enum import Enum +import pyloudnorm as pyln + + +class DSPMode(Enum): + NONE = "none" + TRAIN_INFER = "train_infer" + INFER = "infer" + + def __str__(self): + return self.value + + +def loudness_normalize(x, sample_rate, target_loudness=-24.0): + x = x.view(1, -1) + stereo_audio = x.repeat(2, 1).permute(1, 0).numpy() + meter = pyln.Meter(sample_rate) + loudness = meter.integrated_loudness(stereo_audio) + norm_x = pyln.normalize.loudness( + stereo_audio, + loudness, + target_loudness, + ) + x = torch.tensor(norm_x).permute(1, 0) + x = x[0, :].view(1, -1) + + return x + + +def get_random_file_id(keys): + # generate a random index into the keys of the input files + rand_input_idx = torch.randint(0, len(keys) - 1, [1])[0] + # find the key (file_id) correponding to the random index + rand_input_file_id = list(keys)[rand_input_idx] + + return rand_input_file_id + + +def get_random_patch(audio_file, length, check_silence=True): + silent = True + while silent: + start_idx = int(torch.rand(1) * (audio_file.num_frames - length)) + stop_idx = start_idx + length + patch = audio_file.audio[:, start_idx:stop_idx].clone().detach() + if (patch ** 2).mean() > 1e-4 or not check_silence: + silent = False + + return start_idx, stop_idx + + +def seed_worker(worker_id): + worker_seed = torch.initial_seed() % 2 ** 32 + np.random.seed(worker_seed) + random.seed(worker_seed) + + +def getFilesPath(directory, extension): + + n_path = [] + for path, subdirs, files in os.walk(directory): + for name in files: + if fnmatch.fnmatch(name, extension): + n_path.append(os.path.join(path, name)) + n_path.sort() + + return n_path + + +def count_parameters(model, trainable_only=True): + + if trainable_only: + if len(list(model.parameters())) > 0: + params = sum(p.numel() for p in model.parameters() if p.requires_grad) + else: + params = 0 + else: + if len(list(model.parameters())) > 0: + params = sum(p.numel() for p in model.parameters()) + else: + params = 0 + + return params + + +def system_summary(system): + print(f"Encoder: {count_parameters(system.encoder)/1e6:0.2f} M") + print(f"Processor: {count_parameters(system.processor)/1e6:0.2f} M") + + if hasattr(system, "adv_loss_fn"): + for idx, disc in enumerate(system.adv_loss_fn.discriminators): + print(f"Discriminator {idx+1}: {count_parameters(disc)/1e6:0.2f} M") + + +def center_crop(x, length: int): + if x.shape[-1] != length: + start = (x.shape[-1] - length) // 2 + stop = start + length + x = x[..., start:stop] + return x + + +def causal_crop(x, length: int): + if x.shape[-1] != length: + stop = x.shape[-1] - 1 + start = stop - length + x = x[..., start:stop] + return x + + +def denormalize(norm_val, max_val, min_val): + return (norm_val * (max_val - min_val)) + min_val + + +def normalize(denorm_val, max_val, min_val): + return (denorm_val - min_val) / (max_val - min_val) + + +def get_random_patch(audio_file, length, energy_treshold=1e-4): + """Produce sample indicies for a random patch of size `length`. + + This function will check the energy of the selected patch to + ensure that it is not complete silence. If silence is found, + it will continue searching for a non-silent patch. + + Args: + audio_file (AudioFile): Audio file object. + length (int): Number of samples in random patch. + + Returns: + start_idx (int): Starting sample index + stop_idx (int): Stop sample index + """ + + silent = True + while silent: + start_idx = int(torch.rand(1) * (audio_file.num_frames - length)) + stop_idx = start_idx + length + patch = audio_file.audio[:, start_idx:stop_idx] + if (patch ** 2).mean() > energy_treshold: + silent = False + + return start_idx, stop_idx + + +def split_dataset(file_list, subset, train_frac): + """Given a list of files, split into train/val/test sets. + + Args: + file_list (list): List of audio files. + subset (str): One of "train", "val", or "test". + train_frac (float): Fraction of the dataset to use for training. + + Returns: + file_list (list): List of audio files corresponding to subset. + """ + assert train_frac > 0.1 and train_frac < 1.0 + + total_num_examples = len(file_list) + + train_num_examples = int(total_num_examples * train_frac) + val_num_examples = int(total_num_examples * (1 - train_frac) / 2) + test_num_examples = total_num_examples - (train_num_examples + val_num_examples) + + if train_num_examples < 0: + raise ValueError( + f"No examples in training set. Try increasing train_frac: {train_frac}." + ) + elif val_num_examples < 0: + raise ValueError( + f"No examples in validation set. Try decreasing train_frac: {train_frac}." + ) + elif test_num_examples < 0: + raise ValueError( + f"No examples in test set. Try decreasing train_frac: {train_frac}." + ) + + if subset == "train": + start_idx = 0 + stop_idx = train_num_examples + elif subset == "val": + start_idx = train_num_examples + stop_idx = start_idx + val_num_examples + elif subset == "test": + start_idx = train_num_examples + val_num_examples + stop_idx = start_idx + test_num_examples + 1 + else: + raise ValueError("Invalid subset: {subset}.") + + return file_list[start_idx:stop_idx] + + +def rademacher(size): + """Generates random samples from a Rademacher distribution +-1 + + Args: + size (int): + + """ + m = torch.distributions.binomial.Binomial(1, 0.5) + x = m.sample(size) + x[x == 0] = -1 + return x + + +def get_subset(csv_file): + subset_files = [] + with open(csv_file) as fp: + reader = csv.DictReader(fp) + for row in reader: + subset_files.append(row["filepath"]) + + return list(set(subset_files)) + + +def conform_length(x: torch.Tensor, length: int): + """Crop or pad input on last dim to match `length`.""" + if x.shape[-1] < length: + padsize = length - x.shape[-1] + x = torch.nn.functional.pad(x, (0, padsize)) + elif x.shape[-1] > length: + x = x[..., :length] + + return x + + +def linear_fade( + x: torch.Tensor, + fade_ms: float = 50.0, + sample_rate: float = 22050, +): + """Apply fade in and fade out to last dim.""" + fade_samples = int(fade_ms * 1e-3 * 22050) + + fade_in = torch.linspace(0.0, 1.0, steps=fade_samples) + fade_out = torch.linspace(1.0, 0.0, steps=fade_samples) + + # fade in + x[..., :fade_samples] *= fade_in + + # fade out + x[..., -fade_samples:] *= fade_out + + return x + + +# def get_random_patch(x, sample_rate, length_samples): +# length = length_samples +# silent = True +# while silent: +# start_idx = np.random.randint(0, x.shape[-1] - length - 1) +# stop_idx = start_idx + length +# x_crop = x[0:1, start_idx:stop_idx] + +# # check for silence +# frames = length // sample_rate +# silent_frames = [] +# for n in range(frames): +# start_idx = n * sample_rate +# stop_idx = start_idx + sample_rate +# x_frame = x_crop[0:1, start_idx:stop_idx] +# if (x_frame ** 2).mean() > 3e-4: +# silent_frames.append(False) +# else: +# silent_frames.append(True) +# silent = True if any(silent_frames) else False + +# x_crop /= x_crop.abs().max() + +# return x_crop diff --git a/deepafx_st/version.py b/deepafx_st/version.py new file mode 100755 index 0000000..ed6b02e --- /dev/null +++ b/deepafx_st/version.py @@ -0,0 +1,6 @@ +# !/usr/bin/env python +# -*- coding: utf-8 -*- +'''Version info''' + +short_version = '0.0' +version = '0.0.1' diff --git a/docs/color-generic-style-transfer-headline.svg b/docs/color-generic-style-transfer-headline.svg new file mode 100644 index 0000000..e6760be --- /dev/null +++ b/docs/color-generic-style-transfer-headline.svg @@ -0,0 +1,4 @@ + + + +
Controller
Controller
Encoder
Encoder
Encoder
Encoder
Reference
Reference
Output
Output
Input
Input
Audio Effect 1
Audio Effect 1
Audio Effect N
Audio Effect N
...
...
Audio Effect 2
Audio Effect 2
Parameters
Parameters
Text is not SVG - cannot display
\ No newline at end of file diff --git a/docs/deepafx-st-headline.png b/docs/deepafx-st-headline.png new file mode 100644 index 0000000..3e07e7d Binary files /dev/null and b/docs/deepafx-st-headline.png differ diff --git a/docs/generic-style-transfer-headline.svg b/docs/generic-style-transfer-headline.svg new file mode 100644 index 0000000..82a0334 --- /dev/null +++ b/docs/generic-style-transfer-headline.svg @@ -0,0 +1,4 @@ + + + +
Controller
Controller
Encoder
Encoder
Encoder
Encoder
Reference
Reference
Output
Output
Input
Input
Audio Effect 1
Audio Effect 1
Audio Effect N
Audio Effect N
...
...
Audio Effect 2
Audio Effect 2
Parameters
Parameters
Text is not SVG - cannot display
\ No newline at end of file diff --git a/docs/new-generic-style-transfer-headline.svg b/docs/new-generic-style-transfer-headline.svg new file mode 100644 index 0000000..1c0a48b --- /dev/null +++ b/docs/new-generic-style-transfer-headline.svg @@ -0,0 +1,4 @@ + + + +
Controller
Controller
Encoder
Encoder
Encoder
Encoder
Reference
Reference
Output
Output
Input
Input
Audio Effect 1
Audio Effect 1
Audio Effect N
Audio Effect N
...
...
Audio Effect 2
Audio Effect 2
Parameters
Parameters
Text is not SVG - cannot display
\ No newline at end of file diff --git a/docs/training.svg b/docs/training.svg new file mode 100644 index 0000000..85fde0d --- /dev/null +++ b/docs/training.svg @@ -0,0 +1,4 @@ + + + +
x_i
x_r
f_\the...
f_\the...
x_{i,a}
\mathcal{L...
p
DDSP
DDSP
e_{ i}
e_{...
g_\phi
e_{ ir...
\mat...
x_{r,...
x_{r,b}
x_r
\hat{x}_{r...
DSP
DSP
Split
Split
Split
Split
DSP
DSP
h(x, p)
Augment
Augment
Encoder
Encoder
Controller
Controller
Differentiable Audio Effects
Differentiable Au...
x_{r,a...
\nabla
x
Self-Supervised Data Generation
Self-Supervised Data Generation
Model Architecture
Model Architecture
x_{i,a...
x_{i,b}
x_{r,b...
x_{r,a}
Audio
Audio
Parameters
Parameters
Gradients
Gradients
STFT
STFT
STFT
STFT
Text is not SVG - cannot display
\ No newline at end of file diff --git a/results/eval.md b/results/eval.md new file mode 100644 index 0000000..750ebdd --- /dev/null +++ b/results/eval.md @@ -0,0 +1,1611 @@ +## LibriTTS @ 24kHz (test) + +CUDA_VISIBLE_DEVICES=4 python scripts/eval.py \ +/import/c4dm-datasets/deepafx_st/logs/style/libritts/ \ +--root_dir /import/c4dm-datasets/deepafx_st/ \ +--gpu \ +--spsa_version 2 \ +--tcn1_version 1 \ +--autodiff_version 1 \ +--tcn2_version 1 \ +--subset test \ +--output /import/c4dm-datasets/deepafx_st/ +--save \ + +1000 + + Corrupt PESQ: 3.765 MRSTFT: 1.187 MSD: 2.180 SCE: 687.534 CFE: 6.261 RMS: 6.983 LUFS: 2.426 + TCN1 (libritts) PESQ: 4.258 MRSTFT: 0.405 MSD: 0.887 SCE: 128.408 CFE: 2.582 RMS: 2.237 LUFS: 1.066 + TCN2 (libritts) PESQ: 4.281 MRSTFT: 0.372 MSD: 0.833 SCE: 117.496 CFE: 2.460 RMS: 1.927 LUFS: 0.925 + SPSA (libritts) PESQ: 4.180 MRSTFT: 0.635 MSD: 1.406 SCE: 219.409 CFE: 5.734 RMS: 3.263 LUFS: 1.600 + proxy0 (libritts) PESQ: 3.643 MRSTFT: 0.676 MSD: 1.405 SCE: 264.970 CFE: 4.291 RMS: 2.812 LUFS: 1.340 + Proxy1 (libritts) PESQ: 3.999 MRSTFT: 1.038 MSD: 2.179 SCE: 440.159 CFE: 5.283 RMS: 5.472 LUFS: 2.679 + Proxy2 (libritts) PESQ: 3.945 MRSTFT: 1.058 MSD: 2.088 SCE: 404.867 CFE: 5.328 RMS: 6.820 LUFS: 3.197 + Autodiff (libritts) PESQ: 4.310 MRSTFT: 0.388 MSD: 0.882 SCE: 111.549 CFE: 4.079 RMS: 1.828 LUFS: 0.823 + Baseline (libritts) PESQ: 3.856 MRSTFT: 0.943 MSD: 1.955 SCE: 410.330 CFE: 4.013 RMS: 4.204 LUFS: 1.674 +-------------------------------- + + Corrupt 3.765 & 1.187 & 2.180 & 687.534 & 6.983 & 2.426 + Baseline 3.856 & 0.943 & 1.955 & 410.330 & 4.204 & 1.674 + TCN1 4.258 & 0.405 & 0.887 & 128.408 & 2.237 & 1.066 + TCN2 4.281 & 0.372 & 0.833 & 117.496 & 1.927 & 0.925 + SPSA 4.180 & 0.635 & 1.406 & 219.409 & 3.263 & 1.600 + proxy0 3.643 & 0.676 & 1.405 & 264.970 & 2.812 & 1.340 + Proxy1 3.999 & 1.038 & 2.179 & 440.159 & 5.472 & 2.679 + Proxy2 3.945 & 1.058 & 2.088 & 404.867 & 6.820 & 3.197 + Autodiff 4.310 & 0.388 & 0.882 & 111.549 & 1.828 & 0.823 +-------------------------------- + +1000 NEW MSD Scores + + Corrupt PESQ: 3.765 MRSTFT: 1.187 MSD: 5.311 SCE: 687.534 CFE: 6.261 RMS: 6.983 LUFS: 2.426 + TCN1 (libritts) PESQ: 4.258 MRSTFT: 0.405 MSD: 1.647 SCE: 128.400 CFE: 2.582 RMS: 2.237 LUFS: 1.066 + TCN2 (libritts) PESQ: 4.281 MRSTFT: 0.372 MSD: 1.529 SCE: 117.493 CFE: 2.460 RMS: 1.927 LUFS: 0.925 + SPSA (libritts) PESQ: 4.180 MRSTFT: 0.635 MSD: 2.894 SCE: 219.409 CFE: 5.734 RMS: 3.263 LUFS: 1.600 + proxy0 (libritts) PESQ: 3.643 MRSTFT: 0.676 MSD: 2.483 SCE: 264.947 CFE: 4.291 RMS: 2.811 LUFS: 1.340 + Proxy1 (libritts) PESQ: 3.999 MRSTFT: 1.038 MSD: 4.766 SCE: 440.159 CFE: 5.283 RMS: 5.472 LUFS: 2.679 + Proxy2 (libritts) PESQ: 3.945 MRSTFT: 1.058 MSD: 4.858 SCE: 404.866 CFE: 5.328 RMS: 6.820 LUFS: 3.197 + Autodiff (libritts) PESQ: 4.310 MRSTFT: 0.388 MSD: 1.692 SCE: 111.507 CFE: 4.079 RMS: 1.828 LUFS: 0.823 + Baseline (libritts) PESQ: 3.856 MRSTFT: 0.943 MSD: 4.002 SCE: 410.330 CFE: 4.013 RMS: 4.204 LUFS: 1.674 +-------------------------------- +Evaluation complete. + +## DAPS @ 24kHz (train) (we don't train with train set) + +CUDA_VISIBLE_DEVICES=0 python scripts/eval.py \ +/import/c4dm-datasets/deepafx_st/logs/style/libritts/ \ +--root_dir /import/c4dm-datasets/deepafx_st/ \ +--gpu \ +--dataset daps \ +--dataset_dir daps_24000/cleanraw \ +--spsa_version 2 \ +--tcn1_version 1 \ +--autodiff_version 1 \ +--tcn2_version 1 \ +--subset train \ +--output /import/c4dm-datasets/deepafx_st/ +#--save \ + +1000 + + Corrupt PESQ: 3.684 MRSTFT: 1.179 MSD: 2.151 SCE: 641.683 CFE: 6.133 RMS: 6.900 LUFS: 2.314 + TCN1 (daps) PESQ: 4.185 MRSTFT: 0.419 MSD: 0.884 SCE: 124.609 CFE: 2.473 RMS: 2.098 LUFS: 1.006 + TCN2 (daps) PESQ: 4.224 MRSTFT: 0.391 MSD: 0.841 SCE: 113.863 CFE: 2.352 RMS: 1.886 LUFS: 0.913 + SPSA (daps) PESQ: 4.099 MRSTFT: 0.645 MSD: 1.379 SCE: 213.596 CFE: 5.166 RMS: 2.989 LUFS: 1.511 + proxy0 (daps) PESQ: 3.605 MRSTFT: 0.685 MSD: 1.362 SCE: 249.159 CFE: 4.222 RMS: 2.732 LUFS: 1.350 + Proxy1 (daps) PESQ: 3.903 MRSTFT: 1.022 MSD: 2.113 SCE: 451.879 CFE: 4.927 RMS: 5.104 LUFS: 2.535 + Proxy2 (daps) PESQ: 3.891 MRSTFT: 1.037 MSD: 2.045 SCE: 395.421 CFE: 5.112 RMS: 6.754 LUFS: 3.117 + Autodiff (daps) PESQ: 4.222 MRSTFT: 0.416 MSD: 0.895 SCE: 109.004 CFE: 4.290 RMS: 1.758 LUFS: 0.799 + Baseline (daps) PESQ: 3.787 MRSTFT: 0.917 MSD: 1.882 SCE: 399.714 CFE: 3.742 RMS: 3.705 LUFS: 1.481 +-------------------------------- +Evaluation complete. + + + Corrupt & 3.684 & 1.179 & 2.151 & 641.683 & 6.900 & 2.314 + TCN1 (daps) & 4.185 & 0.419 & 0.884 & 124.609 & 2.098 & 1.006 + TCN2 (daps) & 4.224 & 0.391 & 0.841 & 113.863 & 1.886 & 0.913 + SPSA (daps) & 4.099 & 0.645 & 1.379 & 213.596 & 2.989 & 1.511 + proxy0 (daps) & 3.605 & 0.685 & 1.362 & 249.159 & 2.732 & 1.350 + Proxy1 (daps) & 3.903 & 1.022 & 2.113 & 451.879 & 5.104 & 2.535 + Proxy2 (daps) & 3.891 & 1.037 & 2.045 & 395.421 & 6.754 & 3.117 + Autodiff (daps) & 4.222 & 0.416 & 0.895 & 109.004 & 1.758 & 0.799 + Baseline (daps) & 3.787 & 0.917 & 1.882 & 399.714 & 3.705 & 1.481 + +1000 New MSD + + Corrupt PESQ: 3.684 MRSTFT: 1.179 MSD: 5.318 SCE: 641.683 CFE: 6.133 RMS: 6.900 LUFS: 2.314 + TCN1 (daps) PESQ: 4.185 MRSTFT: 0.419 MSD: 1.714 SCE: 124.609 CFE: 2.473 RMS: 2.098 LUFS: 1.006 + TCN2 (daps) PESQ: 4.224 MRSTFT: 0.391 MSD: 1.621 SCE: 113.863 CFE: 2.352 RMS: 1.886 LUFS: 0.913 + SPSA (daps) PESQ: 4.099 MRSTFT: 0.645 MSD: 3.007 SCE: 213.596 CFE: 5.166 RMS: 2.989 LUFS: 1.511 + proxy0 (daps) PESQ: 3.605 MRSTFT: 0.685 MSD: 2.520 SCE: 249.159 CFE: 4.222 RMS: 2.732 LUFS: 1.350 + Proxy1 (daps) PESQ: 3.903 MRSTFT: 1.022 MSD: 4.749 SCE: 451.879 CFE: 4.927 RMS: 5.104 LUFS: 2.535 + Proxy2 (daps) PESQ: 3.891 MRSTFT: 1.037 MSD: 4.999 SCE: 395.421 CFE: 5.112 RMS: 6.754 LUFS: 3.117 + Autodiff (daps) PESQ: 4.222 MRSTFT: 0.416 MSD: 1.802 SCE: 109.004 CFE: 4.290 RMS: 1.758 LUFS: 0.799 + Baseline (daps) PESQ: 3.787 MRSTFT: 0.917 MSD: 4.065 SCE: 399.714 CFE: 3.742 RMS: 3.705 LUFS: 1.481 +-------------------------------- +Evaluation complete. + + +## VCTK @ 24kHz (train) (we don't train with train set) + +CUDA_VISIBLE_DEVICES=4 python scripts/eval.py \ +/import/c4dm-datasets/deepafx_st/logs/style/libritts/ \ +--root_dir /import/c4dm-datasets/deepafx_st/ \ +--gpu \ +--dataset vctk \ +--dataset_dir vctk_24000 \ +--spsa_version 2 \ +--tcn1_version 1 \ +--autodiff_version 1 \ +--tcn2_version 1 \ +--subset train \ +--output /import/c4dm-datasets/deepafx_st/ + +1000 + + Corrupt PESQ: 3.672 MRSTFT: 1.254 MSD: 2.008 SCE: 815.422 CFE: 6.686 RMS: 7.783 LUFS: 2.532 + TCN1 (vctk) PESQ: 4.181 MRSTFT: 0.467 MSD: 0.891 SCE: 173.751 CFE: 2.712 RMS: 2.651 LUFS: 1.165 + TCN2 (vctk) PESQ: 4.201 MRSTFT: 0.441 MSD: 0.856 SCE: 163.839 CFE: 2.583 RMS: 2.431 LUFS: 1.086 + SPSA (vctk) PESQ: 4.023 MRSTFT: 0.730 MSD: 1.359 SCE: 301.608 CFE: 5.477 RMS: 3.535 LUFS: 1.737 + proxy0 (vctk) PESQ: 3.651 MRSTFT: 0.737 MSD: 1.300 SCE: 321.701 CFE: 4.591 RMS: 3.166 LUFS: 1.453 + Proxy1 (vctk) PESQ: 3.951 MRSTFT: 1.044 MSD: 1.930 SCE: 591.476 CFE: 5.293 RMS: 5.194 LUFS: 2.651 + Proxy2 (vctk) PESQ: 3.894 MRSTFT: 1.087 MSD: 1.934 SCE: 514.048 CFE: 5.544 RMS: 7.065 LUFS: 3.363 + Autodiff (vctk) PESQ: 4.218 MRSTFT: 0.481 MSD: 0.924 SCE: 152.748 CFE: 5.169 RMS: 2.317 LUFS: 1.006 + Baseline (vctk) PESQ: 3.709 MRSTFT: 1.101 MSD: 1.911 SCE: 657.608 CFE: 4.647 RMS: 5.039 LUFS: 2.018 +-------------------------------- +Evaluation complete. + + Corrupt & 3.672 & 1.254 & 2.008 & 815.422 & 7.783 & 2.532 + TCN1 (vctk) & 4.181 & 0.467 & 0.891 & 173.751 & 2.651 & 1.165 + TCN2 (vctk) & 4.201 & 0.441 & 0.856 & 163.839 & 2.431 & 1.086 + SPSA (vctk) & 4.023 & 0.730 & 1.359 & 301.608 & 3.535 & 1.737 + proxy0 (vctk) & 3.651 & 0.737 & 1.300 & 321.701 & 3.166 & 1.453 + Proxy1 (vctk) & 3.951 & 1.044 & 1.930 & 591.476 & 5.194 & 2.651 + Proxy2 (vctk) & 3.894 & 1.087 & 1.934 & 514.048 & 7.065 & 3.363 + Autodiff (vctk) & 4.218 & 0.481 & 0.924 & 152.748 & 2.317 & 1.006 + Baseline (vctk) & 3.709 & 1.101 & 1.911 & 657.608 & 5.039 & 2.018 + +1000 + +New MSD scores + Corrupt PESQ: 3.672 MRSTFT: 1.254 MSD: 4.373 SCE: 815.422 CFE: 6.686 RMS: 7.783 LUFS: 2.532 + TCN1 (vctk) PESQ: 4.181 MRSTFT: 0.467 MSD: 1.620 SCE: 173.751 CFE: 2.712 RMS: 2.651 LUFS: 1.165 + TCN2 (vctk) PESQ: 4.201 MRSTFT: 0.441 MSD: 1.569 SCE: 163.839 CFE: 2.583 RMS: 2.431 LUFS: 1.086 + SPSA (vctk) PESQ: 4.023 MRSTFT: 0.730 MSD: 2.759 SCE: 301.608 CFE: 5.477 RMS: 3.535 LUFS: 1.737 + proxy0 (vctk) PESQ: 3.651 MRSTFT: 0.737 MSD: 2.254 SCE: 321.701 CFE: 4.591 RMS: 3.166 LUFS: 1.453 + Proxy1 (vctk) PESQ: 3.951 MRSTFT: 1.044 MSD: 3.948 SCE: 591.476 CFE: 5.293 RMS: 5.194 LUFS: 2.651 + Proxy2 (vctk) PESQ: 3.894 MRSTFT: 1.087 MSD: 4.248 SCE: 514.048 CFE: 5.544 RMS: 7.065 LUFS: 3.363 + Autodiff (vctk) PESQ: 4.218 MRSTFT: 0.481 MSD: 1.757 SCE: 152.748 CFE: 5.169 RMS: 2.317 LUFS: 1.006 + Baseline (vctk) PESQ: 3.709 MRSTFT: 1.101 MSD: 4.005 SCE: 657.608 CFE: 4.647 RMS: 5.039 LUFS: 2.018 +-------------------------------- +Evaluation complete. + + +## Style case study (SPSA) +CUDA_VISIBLE_DEVICES=7 python scripts/style_case_study.py \ +--ckpt_path "/import/c4dm-datasets/deepafx_st/logs/style/libritts/spsa/lightning_logs/version_2/checkpoints/epoch=367-step=1226911-val-libritts-spsa.ckpt" \ +--input_audio "/import/c4dm-datasets/deepafx_st/vctk_24000" \ +--style_audio "/import/c4dm-datasets/deepafx_st/daps_24000_styles/val" \ +--gpu \ + +10 inputs x 10 examples per style + +telephone 9 +System: & 1.178 MRSTFT 1.814 MSD 2.434 SCE 193.695 CFE 5.946 RMS 15.071 LUFS 6.835 +Baseline: PESQ 1.235 MRSTFT 1.890 MSD 2.532 SCE 212.969 CFE 2.732 RMS 2.732 LUFS 0.989 +Corrupt: PESQ 1.153 MRSTFT 2.757 MSD 3.827 SCE 384.676 CFE 9.193 RMS 9.193 LUFS 3.187 + +bright 9 +System: PESQ 1.150 MRSTFT 2.298 MSD 3.011 SCE 524.796 CFE 4.008 RMS 7.016 LUFS 3.618 +Baseline: PESQ 1.132 MRSTFT 2.364 MSD 3.047 SCE 797.717 CFE 4.258 RMS 4.258 LUFS 1.825 +Corrupt: PESQ 1.143 MRSTFT 2.626 MSD 3.755 SCE 2296.843 CFE 9.971 RMS 9.971 LUFS 2.716 + +radio 9 +System: PESQ 1.103 MRSTFT 2.064 MSD 2.793 SCE 239.602 CFE 15.518 RMS 2.521 LUFS 1.181 +Baseline: PESQ 1.139 MRSTFT 2.350 MSD 3.292 SCE 772.898 CFE 11.836 RMS 11.836 LUFS 5.307 +Corrupt: PESQ 1.175 MRSTFT 2.391 MSD 3.296 SCE 451.134 CFE 11.745 RMS 11.745 LUFS 5.414 + +podcast 9 +System: PESQ 1.145 MRSTFT 2.330 MSD 3.311 SCE 247.877 CFE 4.556 RMS 2.490 LUFS 1.003 +Baseline: PESQ 1.165 MRSTFT 2.417 MSD 3.415 SCE 597.773 CFE 4.114 RMS 4.114 LUFS 1.465 +Corrupt: PESQ 1.149 MRSTFT 2.445 MSD 3.484 SCE 335.067 CFE 5.127 RMS 5.127 LUFS 2.079 + +warm 9 +System: PESQ 1.124 MRSTFT 2.348 MSD 3.492 SCE 282.160 CFE 7.555 RMS 5.295 LUFS 3.060 +Baseline: PESQ 1.110 MRSTFT 2.528 MSD 3.804 SCE 790.690 CFE 8.481 RMS 8.481 LUFS 3.364 +Corrupt: PESQ 1.138 MRSTFT 2.530 MSD 3.703 SCE 565.930 CFE 14.402 RMS 14.402 LUFS 5.193 + +## Style case study +CUDA_VISIBLE_DEVICES=4 python scripts/style_case_study.py \ +--ckpt_paths \ +"/import/c4dm-datasets/deepafx_st/logs/style/libritts/spsa/lightning_logs/version_2/checkpoints/epoch=367-step=1226911-val-libritts-spsa.ckpt" \ +"/import/c4dm-datasets/deepafx_st/logs/style/libritts/autodiff/lightning_logs/version_1/checkpoints/epoch=367-step=1226911-val-libritts-autodiff.ckpt" \ +"/import/c4dm-datasets/deepafx_st/logs/style/libritts/proxy0/lightning_logs/version_0/checkpoints/epoch=327-step=1093551-val-libritts-proxy0.ckpt" \ +--style_audio "/import/c4dm-datasets/deepafx_st/daps_24000_styles_1000_diverse/train" \ +--output_dir "/import/c4dm-datasets/deepafx_st/style_case_study" \ +--gpu \ +#--save \ +#--plot \ + +broadcast-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:17<00:00, 1.72s/it] +spsa MSD: 4.522 SCE: 182.120 RMS: 3.554 LUFS: 1.509 +autodiff MSD: 4.031 SCE: 352.741 RMS: 2.788 LUFS: 1.148 +proxy0 MSD: 4.615 SCE: 269.970 RMS: 3.619 LUFS: 1.784 +Baseline MSD: 6.575 SCE: 396.523 RMS: 11.710 LUFS: 5.149 +Corrupt MSD: 6.950 SCE: 411.573 RMS: 12.435 LUFS: 5.258 + +broadcast-->telephone +100%|███████████████████████████████████████████| 10/10 [00:46<00:00, 4.61s/it] +spsa MSD: 6.581 SCE: 130.320 RMS: 14.865 LUFS: 6.979 +autodiff MSD: 5.966 SCE: 87.473 RMS: 6.932 LUFS: 3.022 +proxy0 MSD: 8.802 SCE: 223.404 RMS: 11.616 LUFS: 5.182 +Baseline MSD: 6.782 SCE: 283.199 RMS: 5.826 LUFS: 3.391 +Corrupt MSD: 11.492 SCE: 461.633 RMS: 5.259 LUFS: 2.276 + +broadcast-->neutral +100%|███████████████████████████████████████████| 10/10 [00:41<00:00, 4.12s/it] +spsa MSD: 8.776 SCE: 284.227 RMS: 3.605 LUFS: 1.535 +autodiff MSD: 8.765 SCE: 375.015 RMS: 8.036 LUFS: 3.435 +proxy0 MSD: 8.891 SCE: 299.929 RMS: 6.967 LUFS: 2.908 +Baseline MSD: 8.653 SCE: 294.783 RMS: 8.922 LUFS: 4.117 +Corrupt MSD: 9.496 SCE: 458.657 RMS: 9.055 LUFS: 4.152 + +broadcast-->bright +100%|███████████████████████████████████████████| 10/10 [00:57<00:00, 5.80s/it] +spsa MSD: 5.041 SCE: 632.066 RMS: 12.098 LUFS: 6.022 +autodiff MSD: 5.274 SCE: 518.414 RMS: 13.832 LUFS: 6.562 +proxy0 MSD: 6.408 SCE: 585.818 RMS: 7.310 LUFS: 4.727 +Baseline MSD: 5.414 SCE: 782.185 RMS: 20.304 LUFS: 10.463 +Corrupt MSD: 6.707 SCE: 2252.014 RMS: 11.961 LUFS: 7.429 + +broadcast-->warm +100%|███████████████████████████████████████████| 10/10 [00:56<00:00, 5.68s/it] +spsa MSD: 11.578 SCE: 167.850 RMS: 5.142 LUFS: 3.489 +autodiff MSD: 10.112 SCE: 247.329 RMS: 12.229 LUFS: 6.786 +proxy0 MSD: 11.477 SCE: 295.929 RMS: 22.112 LUFS: 11.750 +Baseline MSD: 10.507 SCE: 408.553 RMS: 26.613 LUFS: 12.465 +Corrupt MSD: 11.337 SCE: 789.713 RMS: 30.505 LUFS: 12.952 + +telephone-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:10<00:00, 1.03s/it] +spsa MSD: 8.697 SCE: 256.463 RMS: 2.159 LUFS: 1.360 +autodiff MSD: 7.018 SCE: 268.331 RMS: 1.101 LUFS: 0.977 +proxy0 MSD: 8.816 SCE: 1561.290 RMS: 6.828 LUFS: 1.942 +Baseline MSD: 9.524 SCE: 582.040 RMS: 8.052 LUFS: 4.040 +Corrupt MSD: 11.522 SCE: 357.191 RMS: 8.650 LUFS: 3.928 + +telephone-->telephone +100%|███████████████████████████████████████████| 10/10 [00:11<00:00, 1.15s/it] +spsa MSD: 5.716 SCE: 99.469 RMS: 6.316 LUFS: 2.624 +autodiff MSD: 6.058 SCE: 57.992 RMS: 3.994 LUFS: 1.865 +proxy0 MSD: 6.660 SCE: 69.429 RMS: 3.669 LUFS: 1.641 +Baseline MSD: 6.246 SCE: 134.453 RMS: 2.541 LUFS: 0.724 +Corrupt MSD: 6.477 SCE: 145.124 RMS: 3.164 LUFS: 1.102 + +telephone-->neutral +100%|███████████████████████████████████████████| 10/10 [00:22<00:00, 2.22s/it] +spsa MSD: 10.208 SCE: 260.075 RMS: 4.410 LUFS: 1.704 +autodiff MSD: 10.647 SCE: 267.470 RMS: 2.601 LUFS: 1.424 +proxy0 MSD: 11.776 SCE: 782.558 RMS: 14.902 LUFS: 5.210 +Baseline MSD: 10.858 SCE: 462.504 RMS: 4.944 LUFS: 1.820 +Corrupt MSD: 11.803 SCE: 359.048 RMS: 11.537 LUFS: 3.876 + +telephone-->bright +100%|███████████████████████████████████████████| 10/10 [00:09<00:00, 1.04it/s] +spsa MSD: 6.466 SCE: 1290.766 RMS: 2.904 LUFS: 1.350 +autodiff MSD: 6.892 SCE: 359.603 RMS: 4.223 LUFS: 2.198 +proxy0 MSD: 7.223 SCE: 1230.582 RMS: 4.967 LUFS: 2.182 +Baseline MSD: 7.135 SCE: 601.520 RMS: 3.659 LUFS: 1.522 +Corrupt MSD: 8.123 SCE: 2198.995 RMS: 4.774 LUFS: 1.853 + +telephone-->warm +100%|███████████████████████████████████████████| 10/10 [00:36<00:00, 3.60s/it] +spsa MSD: 14.016 SCE: 187.455 RMS: 11.562 LUFS: 4.883 +autodiff MSD: 11.033 SCE: 368.469 RMS: 8.113 LUFS: 4.508 +proxy0 MSD: 15.291 SCE: 1533.699 RMS: 44.513 LUFS: 18.723 +Baseline MSD: 15.129 SCE: 286.571 RMS: 5.691 LUFS: 1.939 +Corrupt MSD: 17.063 SCE: 401.732 RMS: 22.030 LUFS: 7.518 + +neutral-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:07<00:00, 1.29it/s] +spsa MSD: 5.379 SCE: 186.673 RMS: 1.992 LUFS: 0.990 +autodiff MSD: 5.450 SCE: 184.832 RMS: 0.910 LUFS: 0.353 +proxy0 MSD: 5.696 SCE: 426.500 RMS: 2.029 LUFS: 1.023 +Baseline MSD: 8.856 SCE: 707.260 RMS: 13.779 LUFS: 6.409 +Corrupt MSD: 9.007 SCE: 253.347 RMS: 12.566 LUFS: 5.949 + +neutral-->telephone +100%|███████████████████████████████████████████| 10/10 [00:21<00:00, 2.11s/it] +spsa MSD: 6.586 SCE: 137.444 RMS: 11.907 LUFS: 5.679 +autodiff MSD: 6.775 SCE: 79.489 RMS: 7.413 LUFS: 2.940 +proxy0 MSD: 8.618 SCE: 273.167 RMS: 8.310 LUFS: 3.709 +Baseline MSD: 7.167 SCE: 163.260 RMS: 3.488 LUFS: 1.508 +Corrupt MSD: 12.148 SCE: 266.549 RMS: 9.099 LUFS: 2.703 + +neutral-->neutral +100%|███████████████████████████████████████████| 10/10 [00:17<00:00, 1.79s/it] +spsa MSD: 10.705 SCE: 227.437 RMS: 3.666 LUFS: 1.459 +autodiff MSD: 10.884 SCE: 227.260 RMS: 6.592 LUFS: 2.734 +proxy0 MSD: 10.923 SCE: 233.776 RMS: 5.590 LUFS: 2.262 +Baseline MSD: 10.688 SCE: 445.606 RMS: 3.532 LUFS: 1.476 +Corrupt MSD: 11.592 SCE: 264.716 RMS: 5.298 LUFS: 2.222 + +neutral-->bright +100%|███████████████████████████████████████████| 10/10 [00:17<00:00, 1.79s/it] +spsa MSD: 6.496 SCE: 602.073 RMS: 3.655 LUFS: 1.890 +autodiff MSD: 7.098 SCE: 290.838 RMS: 5.375 LUFS: 2.539 +proxy0 MSD: 8.008 SCE: 830.619 RMS: 10.439 LUFS: 6.364 +Baseline MSD: 8.187 SCE: 940.809 RMS: 6.405 LUFS: 2.563 +Corrupt MSD: 9.876 SCE: 2139.984 RMS: 9.523 LUFS: 2.784 + +neutral-->warm +100%|███████████████████████████████████████████| 10/10 [00:54<00:00, 5.50s/it] +spsa MSD: 14.845 SCE: 279.135 RMS: 3.300 LUFS: 2.232 +autodiff MSD: 14.592 SCE: 423.990 RMS: 3.939 LUFS: 1.885 +proxy0 MSD: 13.307 SCE: 515.236 RMS: 16.671 LUFS: 8.101 +Baseline MSD: 14.943 SCE: 548.626 RMS: 11.837 LUFS: 4.657 +Corrupt MSD: 17.023 SCE: 861.202 RMS: 14.336 LUFS: 5.142 + +bright-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:06<00:00, 1.62it/s] +spsa MSD: 6.772 SCE: 532.672 RMS: 4.613 LUFS: 2.300 +autodiff MSD: 5.774 SCE: 341.599 RMS: 2.149 LUFS: 1.043 +proxy0 MSD: 6.704 SCE: 522.506 RMS: 1.888 LUFS: 1.093 +Baseline MSD: 8.718 SCE: 983.950 RMS: 10.220 LUFS: 4.666 +Corrupt MSD: 8.120 SCE: 2187.024 RMS: 6.306 LUFS: 4.037 + +bright-->telephone +100%|███████████████████████████████████████████| 10/10 [00:17<00:00, 1.75s/it] +spsa MSD: 5.855 SCE: 270.561 RMS: 8.932 LUFS: 4.175 +autodiff MSD: 6.628 SCE: 94.206 RMS: 4.192 LUFS: 1.253 +proxy0 MSD: 9.108 SCE: 71.258 RMS: 8.973 LUFS: 4.219 +Baseline MSD: 6.980 SCE: 230.584 RMS: 4.646 LUFS: 1.684 +Corrupt MSD: 8.623 SCE: 2354.555 RMS: 4.436 LUFS: 1.660 + +bright-->neutral +100%|███████████████████████████████████████████| 10/10 [00:24<00:00, 2.49s/it] +spsa MSD: 9.165 SCE: 398.710 RMS: 2.955 LUFS: 1.443 +autodiff MSD: 9.007 SCE: 289.767 RMS: 4.499 LUFS: 2.343 +proxy0 MSD: 9.422 SCE: 306.264 RMS: 10.855 LUFS: 4.741 +Baseline MSD: 9.376 SCE: 393.232 RMS: 2.709 LUFS: 1.195 +Corrupt MSD: 10.215 SCE: 2303.824 RMS: 10.234 LUFS: 2.903 + +bright-->bright +100%|███████████████████████████████████████████| 10/10 [00:19<00:00, 1.98s/it] +spsa MSD: 6.685 SCE: 411.594 RMS: 5.279 LUFS: 1.701 +autodiff MSD: 6.916 SCE: 375.956 RMS: 6.980 LUFS: 2.641 +proxy0 MSD: 7.404 SCE: 978.521 RMS: 4.706 LUFS: 2.541 +Baseline MSD: 6.885 SCE: 488.962 RMS: 3.045 LUFS: 1.165 +Corrupt MSD: 7.663 SCE: 433.188 RMS: 5.412 LUFS: 2.207 + +bright-->warm +100%|███████████████████████████████████████████| 10/10 [00:41<00:00, 4.16s/it] +spsa MSD: 14.430 SCE: 702.406 RMS: 7.840 LUFS: 3.216 +autodiff MSD: 13.001 SCE: 304.595 RMS: 6.781 LUFS: 3.275 +proxy0 MSD: 14.754 SCE: 1581.939 RMS: 43.075 LUFS: 17.552 +Baseline MSD: 15.660 SCE: 1029.709 RMS: 8.135 LUFS: 2.869 +Corrupt MSD: 15.150 SCE: 2758.329 RMS: 24.744 LUFS: 8.197 + +warm-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:07<00:00, 1.32it/s] +spsa MSD: 5.225 SCE: 183.420 RMS: 6.149 LUFS: 2.679 +autodiff MSD: 4.704 SCE: 245.411 RMS: 1.272 LUFS: 0.566 +proxy0 MSD: 5.890 SCE: 717.873 RMS: 2.599 LUFS: 1.820 +Baseline MSD: 8.872 SCE: 839.614 RMS: 15.062 LUFS: 7.262 +Corrupt MSD: 12.639 SCE: 694.704 RMS: 29.492 LUFS: 12.191 + +warm-->telephone +100%|███████████████████████████████████████████| 10/10 [00:11<00:00, 1.18s/it] +spsa MSD: 7.045 SCE: 102.885 RMS: 13.941 LUFS: 6.782 +autodiff MSD: 7.008 SCE: 95.444 RMS: 4.295 LUFS: 2.152 +proxy0 MSD: 8.251 SCE: 163.098 RMS: 12.524 LUFS: 5.308 +Baseline MSD: 8.355 SCE: 363.053 RMS: 3.255 LUFS: 1.194 +Corrupt MSD: 15.149 SCE: 508.455 RMS: 20.462 LUFS: 7.136 + +warm-->neutral +100%|███████████████████████████████████████████| 10/10 [00:13<00:00, 1.39s/it] +spsa MSD: 10.097 SCE: 240.049 RMS: 4.867 LUFS: 1.852 +autodiff MSD: 9.771 SCE: 365.662 RMS: 6.075 LUFS: 2.589 +proxy0 MSD: 22.146 SCE: 667.698 RMS: 12.068 LUFS: 5.956 +Baseline MSD: 10.539 SCE: 671.609 RMS: 8.868 LUFS: 3.426 +Corrupt MSD: 13.590 SCE: 711.717 RMS: 14.475 LUFS: 5.441 + +warm-->bright +100%|███████████████████████████████████████████| 10/10 [00:09<00:00, 1.01it/s] +spsa MSD: 6.968 SCE: 621.412 RMS: 3.168 LUFS: 2.040 +autodiff MSD: 6.544 SCE: 403.204 RMS: 4.901 LUFS: 2.014 +proxy0 MSD: 12.103 SCE: 746.076 RMS: 16.569 LUFS: 9.295 +Baseline MSD: 8.659 SCE: 1147.820 RMS: 7.271 LUFS: 1.948 +Corrupt MSD: 15.391 SCE: 2795.258 RMS: 24.478 LUFS: 7.965 + +warm-->warm +100%|███████████████████████████████████████████| 10/10 [00:23<00:00, 2.30s/it] +spsa MSD: 12.112 SCE: 202.358 RMS: 3.484 LUFS: 1.735 +autodiff MSD: 11.914 SCE: 270.304 RMS: 3.797 LUFS: 1.755 +proxy0 MSD: 13.057 SCE: 790.326 RMS: 10.564 LUFS: 4.503 +Baseline MSD: 12.075 SCE: 356.166 RMS: 3.608 LUFS: 1.892 +Corrupt MSD: 12.278 SCE: 366.978 RMS: 2.604 LUFS: 1.345 + + + +## Jamendo @ 24kHz (test) + +CUDA_VISIBLE_DEVICES=4 python scripts/eval.py \ +/import/c4dm-datasets/deepafx_st/logs_jamendo/style/jamendo/ \ +--root_dir /import/c4dm-datasets/mtg-jamendo-raw/mtg-jamendo-dataset/ \ +--gpu \ +--dataset jamendo \ +--dataset_dir mtg-jamendo_24000/ \ +--spsa_version 0 \ +--tcn1_version 0 \ +--autodiff_version 0 \ +--tcn2_version 0 \ +--subset test \ +--save \ +--ext flac \ +--output /import/c4dm-datasets/deepafx_st/ + +1000 + + Corrupt PESQ: 2.849 MRSTFT: 1.175 MSD: 2.269 SCE: 669.298 CFE: 5.435 RMS: 6.667 LUFS: 2.510 + TCN1 (jamendo) PESQ: 3.351 MRSTFT: 0.547 MSD: 1.164 SCE: 166.483 CFE: 2.382 RMS: 3.495 LUFS: 1.609 + TCN2 (jamendo) PESQ: 3.323 MRSTFT: 0.559 MSD: 1.181 SCE: 164.578 CFE: 2.360 RMS: 3.136 LUFS: 1.501 + SPSA (jamendo) PESQ: 3.123 MRSTFT: 0.712 MSD: 1.520 SCE: 220.400 CFE: 4.087 RMS: 3.003 LUFS: 1.396 + proxy0 (jamendo) PESQ: 2.941 MRSTFT: 0.773 MSD: 1.611 SCE: 221.463 CFE: 4.070 RMS: 3.118 LUFS: 1.482 + Proxy1 (jamendo) PESQ: 2.828 MRSTFT: 1.074 MSD: 2.425 SCE: 423.571 CFE: 5.936 RMS: 6.143 LUFS: 2.961 + Proxy2 (jamendo) PESQ: 2.792 MRSTFT: 1.002 MSD: 2.121 SCE: 291.825 CFE: 6.212 RMS: 3.480 LUFS: 1.702 + Autodiff (jamendo) PESQ: 3.348 MRSTFT: 0.500 MSD: 1.145 SCE: 154.312 CFE: 4.247 RMS: 2.451 LUFS: 1.098 + Baseline (jamendo) PESQ: 2.841 MRSTFT: 0.878 MSD: 1.995 SCE: 254.154 CFE: 3.615 RMS: 3.750 LUFS: 1.531 +-------------------------------- +Evaluation complete. + +New Mel STFT for MSD +1000 + + Corrupt PESQ: 2.849 MRSTFT: 1.175 MSD: 6.186 SCE: 669.298 CFE: 5.435 RMS: 6.667 LUFS: 2.510 + TCN1 (jamendo) PESQ: 3.351 MRSTFT: 0.547 MSD: 2.480 SCE: 166.480 CFE: 2.382 RMS: 3.495 LUFS: 1.609 + TCN2 (jamendo) PESQ: 3.320 MRSTFT: 0.559 MSD: 2.485 SCE: 164.577 CFE: 2.360 RMS: 3.136 LUFS: 1.501 + SPSA (jamendo) PESQ: 3.123 MRSTFT: 0.712 MSD: 3.203 SCE: 220.400 CFE: 4.087 RMS: 3.003 LUFS: 1.396 + proxy0 (jamendo) PESQ: 2.941 MRSTFT: 0.773 MSD: 2.965 SCE: 221.462 CFE: 4.070 RMS: 3.118 LUFS: 1.482 + Proxy1 (jamendo) PESQ: 2.828 MRSTFT: 1.074 MSD: 7.014 SCE: 423.571 CFE: 5.936 RMS: 6.143 LUFS: 2.961 + Proxy2 (jamendo) PESQ: 2.793 MRSTFT: 1.002 MSD: 5.180 SCE: 291.825 CFE: 6.212 RMS: 3.480 LUFS: 1.702 + Autodiff (jamendo) PESQ: 3.348 MRSTFT: 0.500 MSD: 2.426 SCE: 154.307 CFE: 4.247 RMS: 2.451 LUFS: 1.098 + Baseline (jamendo) PESQ: 2.839 MRSTFT: 0.878 MSD: 4.285 SCE: 254.154 CFE: 3.615 RMS: 3.750 LUFS: 1.531 +-------------------------------- +Evaluation complete. + + Corrupt & 2.849 & 1.175 & 6.186 & 669.298 & 6.667 & 2.510 + TCN1 (jamendo) & 3.351 & 0.547 & 2.480 & 166.480 & 3.495 & 1.609 + TCN2 (jamendo) & 3.320 & 0.559 & 2.485 & 164.577 & 3.136 & 1.501 + SPSA (jamendo) & 3.123 & 0.712 & 3.203 & 220.400 & 3.003 & 1.396 + proxy0 (jamendo) & 2.941 & 0.773 & 2.965 & 221.462 & 3.118 & 1.482 + Proxy1 (jamendo) & 2.828 & 1.074 & 7.014 & 423.571 & 6.143 & 2.961 + Proxy2 (jamendo) & 2.793 & 1.002 & 5.180 & 291.825 & 3.480 & 1.702 + Autodiff (jamendo) & 3.348 & 0.500 & 2.426 & 154.307 & 2.451 & 1.098 + Baseline (jamendo) & 2.839 & 0.878 & 4.285 & 254.154 & 3.750 & 1.531 +-------------------------------- +Evaluation complete. + + +-------------------------------- with music proxies +1000 + + Corrupt PESQ: 2.927 MRSTFT: 1.198 MSD: 6.088 SCE: 646.464 CFE: 5.754 RMS: 6.695 LUFS: 2.518 + TCN1 (jamendo) PESQ: 3.402 MRSTFT: 0.547 MSD: 2.294 SCE: 160.940 CFE: 2.598 RMS: 3.261 LUFS: 1.483 + TCN2 (jamendo) PESQ: 3.390 MRSTFT: 0.548 MSD: 2.278 SCE: 152.314 CFE: 2.494 RMS: 2.951 LUFS: 1.397 + SPSA (jamendo) PESQ: 3.173 MRSTFT: 0.716 MSD: 3.024 SCE: 210.077 CFE: 4.549 RMS: 2.809 LUFS: 1.344 + proxy0 (jamendo) PESQ: 2.926 MRSTFT: 0.787 MSD: 2.838 SCE: 221.322 CFE: 4.426 RMS: 2.785 LUFS: 1.390 + proxy1 (jamendo) PESQ: 2.819 MRSTFT: 1.092 MSD: 6.791 SCE: 395.326 CFE: 6.543 RMS: 6.276 LUFS: 3.032 + proxy2 (jamendo) PESQ: 2.833 MRSTFT: 1.016 MSD: 5.005 SCE: 280.831 CFE: 6.675 RMS: 3.377 LUFS: 1.634 + proxy0m (jamendo) PESQ: 2.765 MRSTFT: 0.845 MSD: 3.211 SCE: 255.230 CFE: 4.972 RMS: 3.227 LUFS: 1.608 + proxy1m (jamendo) PESQ: 2.532 MRSTFT: 1.166 MSD: 7.070 SCE: 591.900 CFE: 11.163 RMS: 5.660 LUFS: 2.593 + proxy2m (jamendo) PESQ: 2.648 MRSTFT: 1.137 MSD: 6.368 SCE: 605.618 CFE: 10.520 RMS: 5.903 LUFS: 2.587 + Autodiff (jamendo) PESQ: 3.355 MRSTFT: 0.488 MSD: 2.149 SCE: 144.740 CFE: 4.373 RMS: 2.167 LUFS: 1.005 + Baseline (jamendo) PESQ: 2.849 MRSTFT: 0.925 MSD: 4.422 SCE: 263.193 CFE: 4.148 RMS: 4.254 LUFS: 1.706 +-------------------------------- +Evaluation complete. + + Corrupt & 2.927 & 1.198 & 6.088 & 646.464 & 6.695 & 2.518 + TCN1 (jamendo) & 3.402 & 0.547 & 2.294 & 160.940 & 3.261 & 1.483 + TCN2 (jamendo) & 3.390 & 0.548 & 2.278 & 152.314 & 2.951 & 1.397 + SPSA (jamendo) & 3.173 & 0.716 & 3.024 & 210.077 & 2.809 & 1.344 + proxy0 (jamendo) & 2.926 & 0.787 & 2.838 & 221.322 & 2.785 & 1.390 + proxy1 (jamendo) & 2.819 & 1.092 & 6.791 & 395.326 & 6.276 & 3.032 + proxy2 (jamendo) & 2.833 & 1.016 & 5.005 & 280.831 & 3.377 & 1.634 + proxy0m (jamendo) & 2.765 & 0.845 & 3.211 & 255.230 & 3.227 & 1.608 + proxy1m (jamendo) & 2.532 & 1.166 & 7.070 & 591.900 & 5.660 & 2.593 + proxy2m (jamendo) & 2.648 & 1.137 & 6.368 & 605.618 & 5.903 & 2.587 + Autodiff (jamendo) & 3.355 & 0.488 & 2.149 & 144.740 & 2.167 & 1.005 + Baseline (jamendo) & 2.849 & 0.925 & 4.422 & 263.193 & 4.254 & 1.706 + +## MUSDB18 @ 44.1kHz (train) + +CUDA_VISIBLE_DEVICES=4 python scripts/eval.py \ +/import/c4dm-datasets/deepafx_st/logs_jamendo/style/jamendo/ \ +--root_dir /import/c4dm-datasets/deepafx_st \ +--gpu \ +--dataset musdb18_44100 \ +--dataset_dir musdb18_44100/ \ +--spsa_version 0 \ +--tcn1_version 0 \ +--autodiff_version 0 \ +--tcn2_version 0 \ +--subset train \ +--length 262144 \ +--save \ +--ext wav \ +--output /import/c4dm-datasets/deepafx_st/ + +1000 + + Corrupt PESQ: 2.900 MRSTFT: 1.252 MSD: 4.342 SCE: 1088.327 CFE: 5.158 RMS: 5.940 LUFS: 2.312 + TCN1 (musdb18_44100) PESQ: 3.121 MRSTFT: 0.896 MSD: 2.956 SCE: 730.446 CFE: 3.594 RMS: 4.548 LUFS: 2.231 + TCN2 (musdb18_44100) PESQ: 3.107 MRSTFT: 0.917 MSD: 2.986 SCE: 749.454 CFE: 3.584 RMS: 4.208 LUFS: 2.061 + SPSA (musdb18_44100) PESQ: 3.126 MRSTFT: 0.789 MSD: 2.321 SCE: 574.392 CFE: 4.198 RMS: 2.925 LUFS: 1.394 + proxy0 (musdb18_44100) PESQ: 2.804 MRSTFT: 0.950 MSD: 2.778 SCE: 742.068 CFE: 4.561 RMS: 3.835 LUFS: 1.921 + proxy1 (musdb18_44100) PESQ: 2.853 MRSTFT: 1.165 MSD: 4.852 SCE: 1005.729 CFE: 7.319 RMS: 6.451 LUFS: 3.269 + proxy2 (musdb18_44100) PESQ: 2.857 MRSTFT: 1.045 MSD: 3.809 SCE: 617.259 CFE: 6.184 RMS: 3.932 LUFS: 1.971 + proxy0m (musdb18_44100) PESQ: 2.791 MRSTFT: 0.946 MSD: 2.800 SCE: 757.082 CFE: 4.798 RMS: 4.209 LUFS: 2.127 + proxy1m (musdb18_44100) PESQ: 2.493 MRSTFT: 1.198 MSD: 5.090 SCE: 1021.863 CFE: 11.674 RMS: 5.585 LUFS: 2.731 + proxy2m (musdb18_44100) PESQ: 2.575 MRSTFT: 1.225 MSD: 5.450 SCE: 1172.493 CFE: 12.245 RMS: 5.973 LUFS: 2.913 + Autodiff (musdb18_44100) PESQ: 3.396 MRSTFT: 0.608 MSD: 1.695 SCE: 456.131 CFE: 4.170 RMS: 2.559 LUFS: 1.197 + Baseline (musdb18_44100) PESQ: 2.994 MRSTFT: 0.821 MSD: 3.052 SCE: 379.400 CFE: 3.871 RMS: 4.078 LUFS: 1.665 +-------------------------------- +Evaluation complete. + + Corrupt & 2.900 & 1.252 & 4.342 & 1088.327 & 5.940 & 2.312 + TCN1 (musdb18_44100) & 3.121 & 0.896 & 2.956 & 730.446 & 4.548 & 2.231 + TCN2 (musdb18_44100) & 3.107 & 0.917 & 2.986 & 749.454 & 4.208 & 2.061 + SPSA (musdb18_44100) & 3.126 & 0.789 & 2.321 & 574.392 & 2.925 & 1.394 + proxy0 (musdb18_44100) & 2.804 & 0.950 & 2.778 & 742.068 & 3.835 & 1.921 + proxy1 (musdb18_44100) & 2.853 & 1.165 & 4.852 & 1005.729 & 6.451 & 3.269 + proxy2 (musdb18_44100) & 2.857 & 1.045 & 3.809 & 617.259 & 3.932 & 1.971 + proxy0m (musdb18_44100) & 2.791 & 0.946 & 2.800 & 757.082 & 4.209 & 2.127 + proxy1m (musdb18_44100) & 2.493 & 1.198 & 5.090 & 1021.863 & 5.585 & 2.731 + proxy2m (musdb18_44100) & 2.575 & 1.225 & 5.450 & 1172.493 & 5.973 & 2.913 + Autodiff (musdb18_44100) & 3.396 & 0.608 & 1.695 & 456.131 & 2.559 & 1.197 + Baseline (musdb18_44100) & 2.994 & 0.821 & 3.052 & 379.400 & 4.078 & 1.665 + + +## MUSDB18 @ 24kHz (train) +quota +CUDA_VISIBLE_DEVICES=4 python scripts/eval.py \ +/import/c4dm-datasets/deepafx_st/logs_jamendo/style/jamendo/ \ +--root_dir /import/c4dm-datasets/deepafx_st \ +--gpu \ +--dataset musdb18_24000 \ +--dataset_dir musdb18_24000/ \ +--spsa_version 0 \ +--tcn1_version 0 \ +--autodiff_version 0 \ +--tcn2_version 0 \ +--subset train \ +--length 131072 \ +--save \ +--ext wav \ +--output /import/c4dm-datasets/deepafx_st/ + +1000 + + Corrupt PESQ: 2.925 MRSTFT: 1.237 MSD: 6.452 SCE: 865.043 CFE: 4.969 RMS: 5.658 LUFS: 2.119 + TCN1 (musdb18_24000) PESQ: 3.506 MRSTFT: 0.501 MSD: 2.564 SCE: 205.675 CFE: 2.251 RMS: 3.483 LUFS: 1.657 + TCN2 (musdb18_24000) PESQ: 3.474 MRSTFT: 0.520 MSD: 2.588 SCE: 194.676 CFE: 2.338 RMS: 3.098 LUFS: 1.533 + SPSA (musdb18_24000) PESQ: 3.290 MRSTFT: 0.690 MSD: 3.340 SCE: 258.120 CFE: 3.590 RMS: 2.851 LUFS: 1.417 + proxy0 (musdb18_24000) PESQ: 3.083 MRSTFT: 0.709 MSD: 3.075 SCE: 280.187 CFE: 3.672 RMS: 3.188 LUFS: 1.577 + Proxy1 (musdb18_24000) PESQ: 2.936 MRSTFT: 1.085 MSD: 7.042 SCE: 536.682 CFE: 5.264 RMS: 5.913 LUFS: 2.865 + Proxy2 (musdb18_24000) PESQ: 2.972 MRSTFT: 1.036 MSD: 5.544 SCE: 361.326 CFE: 5.247 RMS: 3.547 LUFS: 1.780 + Autodiff (musdb18_24000) PESQ: 3.522 MRSTFT: 0.460 MSD: 2.269 SCE: 194.226 CFE: 3.316 RMS: 2.309 LUFS: 1.116 + Baseline (musdb18_24000) PESQ: 3.059 MRSTFT: 0.811 MSD: 4.069 SCE: 261.277 CFE: 3.133 RMS: 3.309 LUFS: 1.352 +-------------------------------- +Evaluation complete. + + +## Jamendo @ 44.1kHz (test) + +CUDA_VISIBLE_DEVICES=4 python scripts/eval.py \ +/import/c4dm-datasets/deepafx_st/logs_jamendo/style/jamendo/ \ +--root_dir /import/c4dm-datasets/mtg-jamendo-raw/mtg-jamendo-dataset/ \ +--gpu \ +--dataset jamendo_44100 \ +--dataset_dir mtg-jamendo_44100/ \ +--spsa_version 0 \ +--tcn1_version 0 \ +--autodiff_version 0 \ +--tcn2_version 0 \ +--subset test \ +--length 262144 \ +--save \ +--ext wav \ +--output /import/c4dm-datasets/deepafx_st/ + +1000 + + Corrupt PESQ: 2.874 MRSTFT: 1.109 MSD: 4.454 SCE: 767.664 CFE: 5.587 RMS: 6.793 LUFS: inf + TCN1 (jamendo_44100) PESQ: 3.168 MRSTFT: 0.876 MSD: 2.921 SCE: 494.368 CFE: 3.459 RMS: 4.372 LUFS: 2.070 + TCN2 (jamendo_44100) PESQ: 3.123 MRSTFT: 0.903 MSD: 2.973 SCE: 517.350 CFE: 3.417 RMS: 4.084 LUFS: 1.899 + SPSA (jamendo_44100) PESQ: 3.172 MRSTFT: 0.759 MSD: 2.458 SCE: 386.229 CFE: 4.634 RMS: 2.839 LUFS: 1.311 + proxy0 (jamendo_44100) PESQ: 2.764 MRSTFT: 1.033 MSD: 2.869 SCE: 488.273 CFE: 4.398 RMS: 3.710 LUFS: 1.824 + proxy1 (jamendo_44100) PESQ: 2.865 MRSTFT: 1.101 MSD: 5.194 SCE: 689.232 CFE: 7.534 RMS: 6.792 LUFS: 3.365 + proxy2 (jamendo_44100) PESQ: 2.888 MRSTFT: 0.977 MSD: 3.883 SCE: 429.862 CFE: 7.022 RMS: 3.480 LUFS: 1.709 + proxy0m (jamendo_44100) PESQ: 2.699 MRSTFT: 1.042 MSD: 2.928 SCE: 497.022 CFE: 4.798 RMS: 3.942 LUFS: 1.961 + proxy1m (jamendo_44100) PESQ: 2.512 MRSTFT: 1.148 MSD: 5.385 SCE: 854.932 CFE: 12.615 RMS: 5.940 LUFS: 2.740 + proxy2m (jamendo_44100) PESQ: 2.625 MRSTFT: 1.133 MSD: 5.512 SCE: 844.595 CFE: 12.405 RMS: 6.417 LUFS: 2.876 + Autodiff (jamendo_44100) PESQ: 3.400 MRSTFT: 0.585 MSD: 1.824 SCE: 304.393 CFE: 4.126 RMS: 2.425 LUFS: 1.106 + Baseline (jamendo_44100) PESQ: 2.931 MRSTFT: 0.887 MSD: 3.355 SCE: 341.957 CFE: 4.474 RMS: 4.749 LUFS: 1.882 +-------------------------------- +Evaluation complete. + + + Corrupt & 2.874 & 1.109 & 4.454 & 767.664 & 6.793 & inf + TCN1 (jamendo_44100) & 3.168 & 0.876 & 2.921 & 494.368 & 4.372 & 2.070 + TCN2 (jamendo_44100) & 3.123 & 0.903 & 2.973 & 517.350 & 4.084 & 1.899 + SPSA (jamendo_44100) & 3.172 & 0.759 & 2.458 & 386.229 & 2.839 & 1.311 + proxy0 (jamendo_44100) & 2.764 & 1.033 & 2.869 & 488.273 & 3.710 & 1.824 + proxy1 (jamendo_44100) & 2.865 & 1.101 & 5.194 & 689.232 & 6.792 & 3.365 + proxy2 (jamendo_44100) & 2.888 & 0.977 & 3.883 & 429.862 & 3.480 & 1.709 + proxy0m (jamendo_44100) & 2.699 & 1.042 & 2.928 & 497.022 & 3.942 & 1.961 + proxy1m (jamendo_44100) & 2.512 & 1.148 & 5.385 & 854.932 & 5.940 & 2.740 + proxy2m (jamendo_44100) & 2.625 & 1.133 & 5.512 & 844.595 & 6.417 & 2.876 + Autodiff (jamendo_44100) & 3.400 & 0.585 & 1.824 & 304.393 & 2.425 & 1.106 + Baseline (jamendo_44100) & 2.931 & 0.887 & 3.355 & 341.957 & 4.749 & 1.882 + + +## Probes +CUDA_VISIBLE_DEVICES=0 python scripts/eval_probes.py \ +--ckpt_dir /import/c4dm-datasets/deepafx_st/logs/probes_new/ \ +--eval_dataset /import/c4dm-datasets/deepafx_st/daps_24000_styles_1000_diverse/ \ +--subset test \ +--output_dir probes \ +--gpu \ + +true acc: 100.00% f1: 1.00 + precision recall f1-score support + + broadcast 1.00 1.00 1.00 100 + telephone 1.00 1.00 1.00 100 + neutral 1.00 1.00 1.00 100 + bright 1.00 1.00 1.00 100 + warm 1.00 1.00 1.00 100 + + accuracy 1.00 500 + macro avg 1.00 1.00 1.00 500 +weighted avg 1.00 1.00 1.00 500 + +cdpam-mlp acc: 100.00% f1: 1.00 + precision recall f1-score support + + broadcast 1.00 1.00 1.00 100 + telephone 1.00 1.00 1.00 100 + neutral 1.00 1.00 1.00 100 + bright 1.00 1.00 1.00 100 + warm 1.00 1.00 1.00 100 + + accuracy 1.00 500 + macro avg 1.00 1.00 1.00 500 +weighted avg 1.00 1.00 1.00 500 + +deepafx_st-linear acc: 97.60% f1: 0.98 + precision recall f1-score support + + broadcast 0.90 0.99 0.94 100 + telephone 1.00 1.00 1.00 100 + neutral 0.99 0.89 0.94 100 + bright 1.00 1.00 1.00 100 + warm 1.00 1.00 1.00 100 + + accuracy 0.98 500 + macro avg 0.98 0.98 0.98 500 +weighted avg 0.98 0.98 0.98 500 + +openl3-mlp acc: 45.60% f1: 0.40 + precision recall f1-score support + + broadcast 0.38 0.30 0.33 100 + telephone 0.36 0.65 0.46 100 + neutral 0.30 0.27 0.28 100 + bright 0.71 1.00 0.83 100 + warm 1.00 0.06 0.11 100 + + accuracy 0.46 500 + macro avg 0.55 0.46 0.40 500 +weighted avg 0.55 0.46 0.40 500 + +random_mel-mlp acc: 81.40% f1: 0.79 + precision recall f1-score support + + broadcast 0.53 0.89 0.66 100 + telephone 1.00 1.00 1.00 100 + neutral 0.69 0.22 0.33 100 + bright 0.97 1.00 0.99 100 + warm 0.99 0.96 0.97 100 + + accuracy 0.81 500 + macro avg 0.84 0.81 0.79 500 +weighted avg 0.84 0.81 0.79 500 + +openl3-linear acc: 42.00% f1: 0.37 + precision recall f1-score support + + broadcast 0.28 0.18 0.22 100 + telephone 0.33 0.58 0.42 100 + neutral 0.26 0.27 0.26 100 + bright 0.67 1.00 0.80 100 + warm 1.00 0.07 0.13 100 + + accuracy 0.42 500 + macro avg 0.51 0.42 0.37 500 +weighted avg 0.51 0.42 0.37 500 + +random_mel-linear acc: 40.00% f1: 0.23 + precision recall f1-score support + + broadcast 0.00 0.00 0.00 100 + telephone 0.46 1.00 0.63 100 + neutral 0.00 0.00 0.00 100 + bright 0.36 1.00 0.52 100 + warm 0.00 0.00 0.00 100 + + accuracy 0.40 500 + macro avg 0.16 0.40 0.23 500 +weighted avg 0.16 0.40 0.23 500 + +cdpam-linear acc: 64.20% f1: 0.58 + precision recall f1-score support + + broadcast 0.49 0.57 0.53 100 + telephone 1.00 1.00 1.00 100 + neutral 0.41 0.65 0.50 100 + bright 0.80 0.99 0.88 100 + warm 0.00 0.00 0.00 100 + + accuracy 0.64 500 + macro avg 0.54 0.64 0.58 500 +weighted avg 0.54 0.64 0.58 500 + +deepafx_st-mlp acc: 98.20% f1: 0.98 + precision recall f1-score support + + broadcast 0.92 1.00 0.96 100 + telephone 1.00 1.00 1.00 100 + neutral 1.00 0.91 0.95 100 + bright 1.00 1.00 1.00 100 + warm 1.00 1.00 1.00 100 + + accuracy 0.98 500 + macro avg 0.98 0.98 0.98 500 +weighted avg 0.98 0.98 0.98 500 + + + +# Updated style case study with averages + +broadcast-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:38<00:00, 3.83s/it] +spsa MSD: 7.570 SCE: 119.476 RMS: 1.916 LUFS: 0.530 +autodiff MSD: 7.330 SCE: 203.758 RMS: 2.988 LUFS: 1.229 +proxy0 MSD: 7.854 SCE: 373.396 RMS: 4.277 LUFS: 1.942 +Baseline MSD: 7.783 SCE: 359.935 RMS: 4.892 LUFS: 2.014 +Corrupt MSD: 8.454 SCE: 282.362 RMS: 4.441 LUFS: 1.709 + +broadcast-->telephone +100%|███████████████████████████████████████████| 10/10 [00:42<00:00, 4.28s/it] +spsa MSD: 6.390 SCE: 153.945 RMS: 12.168 LUFS: 5.746 +autodiff MSD: 5.990 SCE: 103.907 RMS: 7.123 LUFS: 3.099 +proxy0 MSD: 7.952 SCE: 250.072 RMS: 4.795 LUFS: 2.135 +Baseline MSD: 6.179 SCE: 204.214 RMS: 5.164 LUFS: 2.596 +Corrupt MSD: 10.323 SCE: 423.743 RMS: 5.038 LUFS: 2.358 + +broadcast-->neutral +100%|███████████████████████████████████████████| 10/10 [00:30<00:00, 3.05s/it] +spsa MSD: 8.632 SCE: 189.139 RMS: 3.200 LUFS: 1.323 +autodiff MSD: 8.419 SCE: 261.488 RMS: 5.080 LUFS: 1.973 +proxy0 MSD: 8.577 SCE: 398.681 RMS: 3.362 LUFS: 1.203 +Baseline MSD: 8.358 SCE: 455.740 RMS: 6.112 LUFS: 2.810 +Corrupt MSD: 8.877 SCE: 326.650 RMS: 4.517 LUFS: 2.076 + +broadcast-->bright +100%|███████████████████████████████████████████| 10/10 [00:28<00:00, 2.85s/it] +spsa MSD: 4.498 SCE: 1007.458 RMS: 7.614 LUFS: 3.487 +autodiff MSD: 4.369 SCE: 766.035 RMS: 12.674 LUFS: 5.192 +proxy0 MSD: 5.549 SCE: 1264.155 RMS: 12.701 LUFS: 7.485 +Baseline MSD: 4.642 SCE: 793.576 RMS: 4.322 LUFS: 2.077 +Corrupt MSD: 9.014 SCE: 2366.369 RMS: 10.363 LUFS: 2.558 + +broadcast-->warm +100%|███████████████████████████████████████████| 10/10 [00:36<00:00, 3.65s/it] +spsa MSD: 9.387 SCE: 150.020 RMS: 2.981 LUFS: 3.146 +autodiff MSD: 8.551 SCE: 237.002 RMS: 6.258 LUFS: 3.585 +proxy0 MSD: 9.239 SCE: 327.282 RMS: 20.747 LUFS: 12.137 +Baseline MSD: 11.344 SCE: 428.374 RMS: 1.965 LUFS: 0.915 +Corrupt MSD: 10.688 SCE: 911.725 RMS: 13.515 LUFS: 5.340 + +telephone-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:24<00:00, 2.42s/it] +spsa MSD: 8.833 SCE: 247.258 RMS: 2.139 LUFS: 0.943 +autodiff MSD: 7.513 SCE: 215.320 RMS: 1.897 LUFS: 0.671 +proxy0 MSD: 9.566 SCE: 1528.270 RMS: 6.313 LUFS: 1.872 +Baseline MSD: 9.376 SCE: 477.766 RMS: 8.829 LUFS: 3.934 +Corrupt MSD: 10.533 SCE: 340.476 RMS: 6.092 LUFS: 1.597 + +telephone-->telephone +100%|███████████████████████████████████████████| 10/10 [00:23<00:00, 2.34s/it] +spsa MSD: 5.719 SCE: 96.262 RMS: 7.819 LUFS: 3.660 +autodiff MSD: 6.300 SCE: 93.168 RMS: 3.613 LUFS: 1.572 +proxy0 MSD: 7.124 SCE: 116.400 RMS: 7.718 LUFS: 3.172 +Baseline MSD: 6.289 SCE: 140.480 RMS: 3.183 LUFS: 1.599 +Corrupt MSD: 6.701 SCE: 164.703 RMS: 2.821 LUFS: 1.302 + +telephone-->neutral +100%|███████████████████████████████████████████| 10/10 [00:22<00:00, 2.22s/it] +spsa MSD: 9.668 SCE: 268.287 RMS: 1.866 LUFS: 0.936 +autodiff MSD: 10.733 SCE: 257.832 RMS: 3.852 LUFS: 1.688 +proxy0 MSD: 10.094 SCE: 920.902 RMS: 13.909 LUFS: 5.293 +Baseline MSD: 10.038 SCE: 222.337 RMS: 5.335 LUFS: 2.283 +Corrupt MSD: 11.023 SCE: 394.824 RMS: 7.220 LUFS: 2.392 + +telephone-->bright +100%|███████████████████████████████████████████| 10/10 [00:27<00:00, 2.74s/it] +spsa MSD: 5.705 SCE: 1768.279 RMS: 2.928 LUFS: 1.505 +autodiff MSD: 6.130 SCE: 731.311 RMS: 6.528 LUFS: 2.418 +proxy0 MSD: 6.674 SCE: 1471.950 RMS: 12.486 LUFS: 6.438 +Baseline MSD: 6.124 SCE: 1207.425 RMS: 5.672 LUFS: 2.707 +Corrupt MSD: 8.100 SCE: 2675.410 RMS: 7.645 LUFS: 2.821 + +telephone-->warm +100%|███████████████████████████████████████████| 10/10 [00:35<00:00, 3.50s/it] +spsa MSD: 11.877 SCE: 289.413 RMS: 9.670 LUFS: 4.235 +autodiff MSD: 9.313 SCE: 337.220 RMS: 11.359 LUFS: 5.737 +proxy0 MSD: 11.498 SCE: 1026.376 RMS: 50.331 LUFS: 23.492 +Baseline MSD: 13.685 SCE: 574.240 RMS: 5.507 LUFS: 2.282 +Corrupt MSD: 14.749 SCE: 413.867 RMS: 15.833 LUFS: 4.805 + +neutral-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:18<00:00, 1.87s/it] +spsa MSD: 6.754 SCE: 164.936 RMS: 2.478 LUFS: 1.107 +autodiff MSD: 6.303 SCE: 178.091 RMS: 2.242 LUFS: 0.990 +proxy0 MSD: 7.266 SCE: 681.963 RMS: 2.686 LUFS: 1.476 +Baseline MSD: 8.260 SCE: 481.361 RMS: 7.533 LUFS: 3.396 +Corrupt MSD: 8.817 SCE: 343.222 RMS: 9.012 LUFS: 3.951 + +neutral-->telephone +100%|███████████████████████████████████████████| 10/10 [00:26<00:00, 2.68s/it] +spsa MSD: 6.912 SCE: 126.802 RMS: 11.842 LUFS: 5.800 +autodiff MSD: 6.790 SCE: 103.199 RMS: 5.143 LUFS: 1.898 +proxy0 MSD: 9.230 SCE: 219.290 RMS: 9.954 LUFS: 4.780 +Baseline MSD: 7.051 SCE: 183.061 RMS: 3.425 LUFS: 1.520 +Corrupt MSD: 11.376 SCE: 362.114 RMS: 8.370 LUFS: 2.681 + +neutral-->neutral +100%|███████████████████████████████████████████| 10/10 [00:24<00:00, 2.47s/it] +spsa MSD: 9.350 SCE: 337.903 RMS: 3.897 LUFS: 1.398 +autodiff MSD: 9.818 SCE: 424.667 RMS: 4.416 LUFS: 1.761 +proxy0 MSD: 10.904 SCE: 374.466 RMS: 8.422 LUFS: 3.045 +Baseline MSD: 9.088 SCE: 494.032 RMS: 3.246 LUFS: 1.274 +Corrupt MSD: 9.717 SCE: 479.381 RMS: 4.625 LUFS: 2.025 + +neutral-->bright +100%|███████████████████████████████████████████| 10/10 [00:19<00:00, 1.95s/it] +spsa MSD: 5.332 SCE: 674.385 RMS: 5.128 LUFS: 2.732 +autodiff MSD: 5.217 SCE: 382.314 RMS: 8.524 LUFS: 3.459 +proxy0 MSD: 6.210 SCE: 572.359 RMS: 9.692 LUFS: 4.994 +Baseline MSD: 6.106 SCE: 739.821 RMS: 4.730 LUFS: 1.907 +Corrupt MSD: 10.707 SCE: 2353.734 RMS: 14.699 LUFS: 4.520 + +neutral-->warm +100%|███████████████████████████████████████████| 10/10 [00:28<00:00, 2.84s/it] +spsa MSD: 9.096 SCE: 229.472 RMS: 4.478 LUFS: 2.479 +autodiff MSD: 8.925 SCE: 523.876 RMS: 10.179 LUFS: 4.529 +proxy0 MSD: 8.739 SCE: 439.980 RMS: 16.191 LUFS: 7.227 +Baseline MSD: 10.488 SCE: 563.679 RMS: 4.369 LUFS: 1.626 +Corrupt MSD: 11.037 SCE: 546.397 RMS: 9.678 LUFS: 3.205 + +bright-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:20<00:00, 2.06s/it] +spsa MSD: 9.104 SCE: 419.588 RMS: 2.438 LUFS: 1.230 +autodiff MSD: 8.564 SCE: 289.857 RMS: 1.341 LUFS: 0.803 +proxy0 MSD: 10.200 SCE: 484.384 RMS: 5.106 LUFS: 2.056 +Baseline MSD: 9.807 SCE: 845.671 RMS: 2.994 LUFS: 1.114 +Corrupt MSD: 10.234 SCE: 2210.149 RMS: 11.535 LUFS: 2.551 + +bright-->telephone +100%|███████████████████████████████████████████| 10/10 [00:32<00:00, 3.21s/it] +spsa MSD: 6.794 SCE: 352.609 RMS: 8.139 LUFS: 3.858 +autodiff MSD: 6.149 SCE: 84.044 RMS: 4.881 LUFS: 2.044 +proxy0 MSD: 7.840 SCE: 129.550 RMS: 5.039 LUFS: 1.975 +Baseline MSD: 6.521 SCE: 242.590 RMS: 3.510 LUFS: 1.372 +Corrupt MSD: 8.814 SCE: 2333.702 RMS: 7.755 LUFS: 2.493 + +bright-->neutral +100%|███████████████████████████████████████████| 10/10 [00:24<00:00, 2.47s/it] +spsa MSD: 9.321 SCE: 273.250 RMS: 2.907 LUFS: 0.970 +autodiff MSD: 9.645 SCE: 157.400 RMS: 5.455 LUFS: 2.232 +proxy0 MSD: 9.179 SCE: 184.585 RMS: 18.288 LUFS: 7.764 +Baseline MSD: 9.479 SCE: 375.173 RMS: 2.159 LUFS: 0.696 +Corrupt MSD: 9.293 SCE: 1956.294 RMS: 12.616 LUFS: 3.883 + +bright-->bright +100%|███████████████████████████████████████████| 10/10 [00:21<00:00, 2.15s/it] +spsa MSD: 4.704 SCE: 328.348 RMS: 4.520 LUFS: 1.959 +autodiff MSD: 5.305 SCE: 226.538 RMS: 8.824 LUFS: 4.115 +proxy0 MSD: 5.780 SCE: 727.537 RMS: 9.035 LUFS: 4.352 +Baseline MSD: 5.243 SCE: 774.760 RMS: 4.432 LUFS: 1.875 +Corrupt MSD: 5.752 SCE: 691.583 RMS: 4.774 LUFS: 1.930 + +bright-->warm +100%|███████████████████████████████████████████| 10/10 [00:29<00:00, 2.97s/it] +spsa MSD: 10.198 SCE: 539.897 RMS: 7.841 LUFS: 3.296 +autodiff MSD: 8.774 SCE: 305.210 RMS: 8.212 LUFS: 4.072 +proxy0 MSD: 11.001 SCE: 1614.079 RMS: 47.078 LUFS: 19.311 +Baseline MSD: 11.247 SCE: 701.682 RMS: 4.272 LUFS: 1.730 +Corrupt MSD: 11.401 SCE: 2321.016 RMS: 21.029 LUFS: 6.913 + +warm-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:18<00:00, 1.86s/it] +spsa MSD: 8.570 SCE: 191.909 RMS: 4.233 LUFS: 1.809 +autodiff MSD: 7.915 SCE: 166.873 RMS: 1.694 LUFS: 0.883 +proxy0 MSD: 15.288 SCE: 1290.383 RMS: 6.962 LUFS: 4.256 +Baseline MSD: 9.103 SCE: 527.103 RMS: 6.431 LUFS: 2.825 +Corrupt MSD: 11.239 SCE: 555.232 RMS: 12.235 LUFS: 5.123 + +warm-->telephone +100%|███████████████████████████████████████████| 10/10 [00:24<00:00, 2.42s/it] +spsa MSD: 6.636 SCE: 124.104 RMS: 13.568 LUFS: 7.087 +autodiff MSD: 6.827 SCE: 90.656 RMS: 4.065 LUFS: 1.877 +proxy0 MSD: 7.621 SCE: 297.015 RMS: 12.667 LUFS: 5.786 +Baseline MSD: 8.053 SCE: 358.172 RMS: 6.659 LUFS: 2.548 +Corrupt MSD: 13.733 SCE: 533.204 RMS: 17.164 LUFS: 5.107 + +warm-->neutral +100%|███████████████████████████████████████████| 10/10 [00:20<00:00, 2.10s/it] +spsa MSD: 8.977 SCE: 153.949 RMS: 4.014 LUFS: 1.645 +autodiff MSD: 10.304 SCE: 331.416 RMS: 5.010 LUFS: 2.185 +proxy0 MSD: 19.354 SCE: 1478.394 RMS: 10.675 LUFS: 6.258 +Baseline MSD: 10.615 SCE: 546.162 RMS: 7.670 LUFS: 2.573 +Corrupt MSD: 11.586 SCE: 328.435 RMS: 9.091 LUFS: 3.133 + +warm-->bright +100%|███████████████████████████████████████████| 10/10 [00:21<00:00, 2.14s/it] +spsa MSD: 6.275 SCE: 771.123 RMS: 3.469 LUFS: 2.794 +autodiff MSD: 6.093 SCE: 625.806 RMS: 7.622 LUFS: 3.542 +proxy0 MSD: 10.266 SCE: 767.358 RMS: 15.424 LUFS: 8.371 +Baseline MSD: 8.410 SCE: 1063.015 RMS: 12.834 LUFS: 4.710 +Corrupt MSD: 12.126 SCE: 2773.822 RMS: 21.162 LUFS: 6.825 + +warm-->warm +100%|███████████████████████████████████████████| 10/10 [00:27<00:00, 2.75s/it] +spsa MSD: 8.804 SCE: 228.748 RMS: 2.723 LUFS: 1.755 +autodiff MSD: 8.988 SCE: 350.737 RMS: 3.947 LUFS: 1.725 +proxy0 MSD: 8.415 SCE: 289.968 RMS: 5.926 LUFS: 2.945 +Baseline MSD: 9.734 SCE: 314.220 RMS: 3.452 LUFS: 1.352 +Corrupt MSD: 9.752 SCE: 395.311 RMS: 5.167 LUFS: 2.343 + +----- Averages ---- DAPS +autodiff MSD: 7.611 SCE: 297.909 RMS: 5.717 LUFS: 2.531 +spsa MSD: 7.804 SCE: 368.262 RMS: 5.359 LUFS: 2.617 +proxy0 MSD: 9.257 SCE: 689.152 RMS: 12.791 LUFS: 5.991 +Baseline MSD: 8.521 SCE: 522.984 RMS: 5.148 LUFS: 2.149 +Corrupt MSD: 10.162 SCE: 1059.349 RMS: 9.856 LUFS: 3.346 + +Global seed set to 16 + +broadcast-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:33<00:00, 3.33s/it] +spsa MSD: 7.371 SCE: 172.457 RMS: 1.926 LUFS: 1.073 +autodiff MSD: 7.017 SCE: 218.527 RMS: 3.325 LUFS: 1.112 +proxy0 MSD: 7.861 SCE: 398.562 RMS: 2.983 LUFS: 1.336 +Baseline MSD: 8.148 SCE: 402.896 RMS: 6.979 LUFS: 3.053 +Corrupt MSD: 9.118 SCE: 214.410 RMS: 9.975 LUFS: 4.143 + +broadcast-->telephone +100%|███████████████████████████████████████████| 10/10 [00:33<00:00, 3.39s/it] +spsa MSD: 5.549 SCE: 104.162 RMS: 9.593 LUFS: 4.933 +autodiff MSD: 7.373 SCE: 132.311 RMS: 7.241 LUFS: 2.998 +proxy0 MSD: 11.958 SCE: 244.265 RMS: 15.485 LUFS: 7.202 +Baseline MSD: 7.585 SCE: 166.084 RMS: 3.754 LUFS: 1.647 +Corrupt MSD: 10.272 SCE: 491.861 RMS: 5.794 LUFS: 2.461 + +broadcast-->neutral +100%|███████████████████████████████████████████| 10/10 [00:32<00:00, 3.25s/it] +spsa MSD: 8.019 SCE: 204.897 RMS: 2.908 LUFS: 1.362 +autodiff MSD: 8.441 SCE: 264.969 RMS: 3.701 LUFS: 1.647 +proxy0 MSD: 8.328 SCE: 368.036 RMS: 4.054 LUFS: 1.705 +Baseline MSD: 7.784 SCE: 270.785 RMS: 7.047 LUFS: 3.248 +Corrupt MSD: 7.904 SCE: 371.207 RMS: 6.007 LUFS: 2.700 + +broadcast-->bright +100%|███████████████████████████████████████████| 10/10 [00:26<00:00, 2.63s/it] +spsa MSD: 6.225 SCE: 490.014 RMS: 2.952 LUFS: 1.385 +autodiff MSD: 5.833 SCE: 318.524 RMS: 5.973 LUFS: 2.489 +proxy0 MSD: 7.602 SCE: 630.675 RMS: 9.428 LUFS: 5.563 +Baseline MSD: 6.505 SCE: 892.677 RMS: 4.263 LUFS: 2.216 +Corrupt MSD: 9.665 SCE: 1928.327 RMS: 8.414 LUFS: 2.670 + +broadcast-->warm +100%|███████████████████████████████████████████| 10/10 [00:45<00:00, 4.58s/it] +spsa MSD: 11.026 SCE: 170.762 RMS: 4.956 LUFS: 3.221 +autodiff MSD: 10.065 SCE: 147.897 RMS: 9.530 LUFS: 5.482 +proxy0 MSD: 10.475 SCE: 246.926 RMS: 15.965 LUFS: 9.458 +Baseline MSD: 10.470 SCE: 667.165 RMS: 9.631 LUFS: 4.567 +Corrupt MSD: 10.871 SCE: 471.331 RMS: 12.628 LUFS: 5.300 + +telephone-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:19<00:00, 1.99s/it] +spsa MSD: 8.465 SCE: 279.417 RMS: 1.311 LUFS: 0.910 +autodiff MSD: 7.611 SCE: 218.598 RMS: 2.633 LUFS: 0.895 +proxy0 MSD: 9.983 SCE: 972.418 RMS: 8.302 LUFS: 2.970 +Baseline MSD: 10.099 SCE: 452.687 RMS: 7.886 LUFS: 3.718 +Corrupt MSD: 10.465 SCE: 413.987 RMS: 6.029 LUFS: 2.001 + +telephone-->telephone +100%|███████████████████████████████████████████| 10/10 [00:20<00:00, 2.05s/it] +spsa MSD: 4.506 SCE: 138.731 RMS: 4.951 LUFS: 2.380 +autodiff MSD: 4.956 SCE: 61.731 RMS: 3.884 LUFS: 1.293 +proxy0 MSD: 6.587 SCE: 132.168 RMS: 8.864 LUFS: 3.503 +Baseline MSD: 6.005 SCE: 119.214 RMS: 5.546 LUFS: 2.056 +Corrupt MSD: 6.553 SCE: 280.447 RMS: 4.834 LUFS: 1.827 + +telephone-->neutral +100%|███████████████████████████████████████████| 10/10 [00:21<00:00, 2.11s/it] +spsa MSD: 9.945 SCE: 313.686 RMS: 3.363 LUFS: 1.403 +autodiff MSD: 9.247 SCE: 308.933 RMS: 4.673 LUFS: 1.794 +proxy0 MSD: 10.991 SCE: 864.742 RMS: 19.190 LUFS: 7.352 +Baseline MSD: 10.158 SCE: 460.452 RMS: 5.010 LUFS: 2.011 +Corrupt MSD: 10.700 SCE: 425.533 RMS: 9.457 LUFS: 3.232 + +telephone-->bright +100%|███████████████████████████████████████████| 10/10 [00:27<00:00, 2.79s/it] +spsa MSD: 5.898 SCE: 1389.797 RMS: 3.033 LUFS: 1.591 +autodiff MSD: 5.869 SCE: 429.690 RMS: 4.644 LUFS: 2.356 +proxy0 MSD: 7.527 SCE: 1012.915 RMS: 13.229 LUFS: 5.968 +Baseline MSD: 5.564 SCE: 818.532 RMS: 5.503 LUFS: 2.504 +Corrupt MSD: 8.339 SCE: 2389.284 RMS: 4.828 LUFS: 1.808 + +telephone-->warm +100%|███████████████████████████████████████████| 10/10 [00:28<00:00, 2.85s/it] +spsa MSD: 10.833 SCE: 312.101 RMS: 9.283 LUFS: 3.787 +autodiff MSD: 9.216 SCE: 374.449 RMS: 8.670 LUFS: 4.076 +proxy0 MSD: 12.483 SCE: 1148.659 RMS: 49.559 LUFS: 21.345 +Baseline MSD: 13.266 SCE: 470.901 RMS: 4.924 LUFS: 1.677 +Corrupt MSD: 13.197 SCE: 522.449 RMS: 18.464 LUFS: 5.796 + +neutral-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:18<00:00, 1.87s/it] +spsa MSD: 6.641 SCE: 231.272 RMS: 2.405 LUFS: 1.172 +autodiff MSD: 6.455 SCE: 303.714 RMS: 1.542 LUFS: 0.726 +proxy0 MSD: 6.944 SCE: 446.685 RMS: 2.408 LUFS: 1.497 +Baseline MSD: 8.370 SCE: 677.089 RMS: 5.512 LUFS: 2.459 +Corrupt MSD: 8.919 SCE: 283.764 RMS: 7.543 LUFS: 3.408 + +neutral-->telephone +100%|███████████████████████████████████████████| 10/10 [00:23<00:00, 2.37s/it] +spsa MSD: 6.548 SCE: 125.255 RMS: 10.246 LUFS: 5.055 +autodiff MSD: 6.691 SCE: 140.547 RMS: 5.201 LUFS: 2.269 +proxy0 MSD: 9.147 SCE: 200.192 RMS: 8.658 LUFS: 3.882 +Baseline MSD: 6.413 SCE: 176.637 RMS: 3.044 LUFS: 1.157 +Corrupt MSD: 11.548 SCE: 253.729 RMS: 10.898 LUFS: 3.087 + +neutral-->neutral +100%|███████████████████████████████████████████| 10/10 [00:21<00:00, 2.11s/it] +spsa MSD: 8.229 SCE: 397.836 RMS: 4.964 LUFS: 2.163 +autodiff MSD: 8.446 SCE: 547.878 RMS: 4.720 LUFS: 2.147 +proxy0 MSD: 9.892 SCE: 583.782 RMS: 10.960 LUFS: 4.535 +Baseline MSD: 8.831 SCE: 546.597 RMS: 3.727 LUFS: 1.621 +Corrupt MSD: 9.035 SCE: 628.128 RMS: 5.517 LUFS: 2.584 + +neutral-->bright +100%|███████████████████████████████████████████| 10/10 [00:17<00:00, 1.77s/it] +spsa MSD: 5.138 SCE: 773.308 RMS: 5.914 LUFS: 3.552 +autodiff MSD: 5.038 SCE: 454.673 RMS: 9.380 LUFS: 4.732 +proxy0 MSD: 6.190 SCE: 747.949 RMS: 8.446 LUFS: 4.713 +Baseline MSD: 6.132 SCE: 777.004 RMS: 5.666 LUFS: 1.866 +Corrupt MSD: 10.296 SCE: 2148.154 RMS: 13.907 LUFS: 3.789 + +neutral-->warm +100%|███████████████████████████████████████████| 10/10 [00:40<00:00, 4.05s/it] +spsa MSD: 12.036 SCE: 303.386 RMS: 6.242 LUFS: 3.284 +autodiff MSD: 11.478 SCE: 468.910 RMS: 12.385 LUFS: 5.592 +proxy0 MSD: 12.177 SCE: 326.679 RMS: 20.926 LUFS: 9.999 +Baseline MSD: 12.819 SCE: 929.818 RMS: 9.187 LUFS: 3.835 +Corrupt MSD: 12.108 SCE: 535.255 RMS: 13.302 LUFS: 5.392 + +bright-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:22<00:00, 2.25s/it] +spsa MSD: 8.609 SCE: 425.804 RMS: 2.318 LUFS: 1.060 +autodiff MSD: 7.712 SCE: 273.757 RMS: 2.447 LUFS: 0.752 +proxy0 MSD: 9.258 SCE: 650.227 RMS: 5.816 LUFS: 1.639 +Baseline MSD: 9.889 SCE: 658.042 RMS: 7.105 LUFS: 3.023 +Corrupt MSD: 9.141 SCE: 2054.307 RMS: 10.704 LUFS: 2.987 + +bright-->telephone +100%|███████████████████████████████████████████| 10/10 [00:25<00:00, 2.54s/it] +spsa MSD: 6.847 SCE: 297.782 RMS: 7.416 LUFS: 3.305 +autodiff MSD: 6.949 SCE: 80.701 RMS: 4.010 LUFS: 1.586 +proxy0 MSD: 8.813 SCE: 100.100 RMS: 6.955 LUFS: 3.087 +Baseline MSD: 6.842 SCE: 158.974 RMS: 3.036 LUFS: 1.131 +Corrupt MSD: 9.380 SCE: 2468.890 RMS: 5.970 LUFS: 2.489 + +bright-->neutral +100%|███████████████████████████████████████████| 10/10 [00:21<00:00, 2.14s/it] +spsa MSD: 11.284 SCE: 419.588 RMS: 2.265 LUFS: 1.078 +autodiff MSD: 12.047 SCE: 301.471 RMS: 2.768 LUFS: 1.481 +proxy0 MSD: 10.578 SCE: 389.600 RMS: 9.342 LUFS: 3.756 +Baseline MSD: 10.776 SCE: 445.448 RMS: 2.500 LUFS: 1.274 +Corrupt MSD: 11.276 SCE: 2091.734 RMS: 10.198 LUFS: 2.402 + +bright-->bright +100%|███████████████████████████████████████████| 10/10 [00:18<00:00, 1.89s/it] +spsa MSD: 5.786 SCE: 219.477 RMS: 2.651 LUFS: 1.053 +autodiff MSD: 6.340 SCE: 233.348 RMS: 4.010 LUFS: 1.949 +proxy0 MSD: 6.906 SCE: 654.335 RMS: 8.258 LUFS: 4.230 +Baseline MSD: 6.289 SCE: 474.055 RMS: 2.687 LUFS: 0.994 +Corrupt MSD: 7.115 SCE: 601.199 RMS: 5.821 LUFS: 1.852 + +bright-->warm +100%|███████████████████████████████████████████| 10/10 [00:26<00:00, 2.63s/it] +spsa MSD: 9.844 SCE: 752.189 RMS: 9.814 LUFS: 4.016 +autodiff MSD: 8.481 SCE: 267.815 RMS: 8.754 LUFS: 4.293 +proxy0 MSD: 10.927 SCE: 1941.049 RMS: 50.826 LUFS: 20.030 +Baseline MSD: 11.895 SCE: 437.921 RMS: 3.092 LUFS: 1.207 +Corrupt MSD: 12.235 SCE: 2533.021 RMS: 20.768 LUFS: 6.394 + +warm-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:16<00:00, 1.66s/it] +spsa MSD: 8.253 SCE: 144.919 RMS: 4.567 LUFS: 1.956 +autodiff MSD: 7.501 SCE: 255.792 RMS: 2.247 LUFS: 0.889 +proxy0 MSD: 10.815 SCE: 1238.939 RMS: 3.606 LUFS: 2.852 +Baseline MSD: 10.156 SCE: 567.506 RMS: 12.300 LUFS: 5.272 +Corrupt MSD: 11.231 SCE: 549.288 RMS: 14.310 LUFS: 6.126 + +warm-->telephone +100%|███████████████████████████████████████████| 10/10 [00:18<00:00, 1.86s/it] +spsa MSD: 6.870 SCE: 166.390 RMS: 11.550 LUFS: 6.152 +autodiff MSD: 6.178 SCE: 114.869 RMS: 2.939 LUFS: 1.617 +proxy0 MSD: 7.863 SCE: 203.042 RMS: 10.660 LUFS: 4.645 +Baseline MSD: 9.449 SCE: 471.936 RMS: 8.080 LUFS: 2.670 +Corrupt MSD: 14.019 SCE: 252.479 RMS: 19.388 LUFS: 6.550 + +warm-->neutral +100%|███████████████████████████████████████████| 10/10 [00:19<00:00, 1.96s/it] +spsa MSD: 9.106 SCE: 166.963 RMS: 4.564 LUFS: 1.977 +autodiff MSD: 9.523 SCE: 302.502 RMS: 4.357 LUFS: 1.775 +proxy0 MSD: 20.448 SCE: 1060.644 RMS: 10.820 LUFS: 6.420 +Baseline MSD: 10.240 SCE: 496.880 RMS: 4.583 LUFS: 1.665 +Corrupt MSD: 11.977 SCE: 709.600 RMS: 6.823 LUFS: 2.232 + +warm-->bright +100%|███████████████████████████████████████████| 10/10 [00:17<00:00, 1.72s/it] +spsa MSD: 6.002 SCE: 751.138 RMS: 5.238 LUFS: 3.321 +autodiff MSD: 5.313 SCE: 334.639 RMS: 8.967 LUFS: 3.842 +proxy0 MSD: 9.097 SCE: 914.759 RMS: 12.348 LUFS: 6.958 +Baseline MSD: 8.472 SCE: 1110.614 RMS: 9.309 LUFS: 3.231 +Corrupt MSD: 13.787 SCE: 2904.344 RMS: 23.853 LUFS: 7.778 + +warm-->warm +100%|███████████████████████████████████████████| 10/10 [00:30<00:00, 3.03s/it] +spsa MSD: 10.273 SCE: 266.991 RMS: 3.600 LUFS: 1.592 +autodiff MSD: 10.117 SCE: 396.560 RMS: 2.964 LUFS: 1.410 +proxy0 MSD: 11.400 SCE: 437.774 RMS: 5.956 LUFS: 2.307 +Baseline MSD: 10.142 SCE: 279.552 RMS: 5.683 LUFS: 2.484 +Corrupt MSD: 10.802 SCE: 518.633 RMS: 4.915 LUFS: 2.210 + +----- Averages ---- DAPS +spsa MSD: 7.972 SCE: 360.733 RMS: 5.121 LUFS: 2.511 +autodiff MSD: 7.756 SCE: 278.112 RMS: 5.239 LUFS: 2.368 +proxy0 MSD: 9.770 SCE: 636.605 RMS: 12.922 LUFS: 5.878 +Baseline MSD: 8.892 SCE: 517.179 RMS: 5.842 LUFS: 2.423 +Corrupt MSD: 10.398 SCE: 1041.654 RMS: 10.414 LUFS: 3.649 + +## Style case study on MUSDB18 @ 44.1 kHz + +CUDA_VISIBLE_DEVICES=1 python scripts/style_case_study.py \ +--ckpt_paths \ +"/import/c4dm-datasets/deepafx_st/logs_jamendo/style/jamendo/autodiff/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-autodiff.ckpt" \ +"/import/c4dm-datasets/deepafx_st/logs_jamendo/style/jamendo/spsa/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-spsa.ckpt" \ +"/import/c4dm-datasets/deepafx_st/logs_jamendo/style/jamendo/proxy0/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-proxy0.ckpt" \ +--style_audio "/import/c4dm-datasets/deepafx_st/musdb18_44100_styles_100/train" \ +--output_dir "/import/c4dm-datasets/deepafx_st/style_case_study_musdb18" \ +--sample_rate 44100 \ +--gpu \ +--save \ +--plot \ +broadcast-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:38<00:00, 3.87s/it] +autodiff MSD: 5.268 SCE: 764.455 RMS: 1.486 LUFS: 0.620 +spsa MSD: 5.717 SCE: 742.577 RMS: 2.132 LUFS: 1.101 +proxy0 MSD: 7.515 SCE: 954.857 RMS: 6.580 LUFS: 3.552 +Baseline MSD: 6.712 SCE: 324.917 RMS: 6.104 LUFS: 2.953 +Corrupt MSD: 7.248 SCE: 857.312 RMS: 7.288 LUFS: 3.347 + +broadcast-->telephone +100%|███████████████████████████████████████████| 10/10 [01:28<00:00, 8.88s/it] +autodiff MSD: 6.153 SCE: 165.554 RMS: 3.281 LUFS: 1.483 +spsa MSD: 7.145 SCE: 414.354 RMS: 9.122 LUFS: 4.119 +proxy0 MSD: 7.849 SCE: 303.187 RMS: 6.100 LUFS: 2.475 +Baseline MSD: 8.953 SCE: 261.259 RMS: 15.381 LUFS: 6.724 +Corrupt MSD: 11.741 SCE: 1618.816 RMS: 11.667 LUFS: 5.312 + +broadcast-->neutral +100%|███████████████████████████████████████████| 10/10 [01:11<00:00, 7.14s/it] +autodiff MSD: 6.254 SCE: 721.382 RMS: 2.862 LUFS: 1.141 +spsa MSD: 6.460 SCE: 825.238 RMS: 2.443 LUFS: 1.190 +proxy0 MSD: 8.555 SCE: 1317.455 RMS: 6.520 LUFS: 3.426 +Baseline MSD: 7.031 SCE: 238.390 RMS: 7.928 LUFS: 3.514 +Corrupt MSD: 7.839 SCE: 1029.224 RMS: 5.390 LUFS: 2.837 + +broadcast-->bright +100%|███████████████████████████████████████████| 10/10 [00:38<00:00, 3.83s/it] +autodiff MSD: 2.377 SCE: 1834.532 RMS: 4.029 LUFS: 1.704 +spsa MSD: 3.088 SCE: 2325.082 RMS: 6.319 LUFS: 2.218 +proxy0 MSD: 3.582 SCE: 1312.444 RMS: 12.568 LUFS: 5.890 +Baseline MSD: 3.043 SCE: 985.595 RMS: 8.910 LUFS: 3.665 +Corrupt MSD: 7.679 SCE: 4955.671 RMS: 23.504 LUFS: 7.104 + +broadcast-->warm +100%|███████████████████████████████████████████| 10/10 [00:49<00:00, 4.92s/it] +autodiff MSD: 3.891 SCE: 748.354 RMS: 3.421 LUFS: 1.569 +spsa MSD: 4.541 SCE: 769.316 RMS: 3.495 LUFS: 1.453 +proxy0 MSD: 4.705 SCE: 1007.030 RMS: 6.546 LUFS: 3.328 +Baseline MSD: 5.870 SCE: 705.780 RMS: 6.706 LUFS: 2.960 +Corrupt MSD: 7.728 SCE: 1744.439 RMS: 6.603 LUFS: 2.843 + +telephone-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:42<00:00, 4.29s/it] +autodiff MSD: 5.457 SCE: 795.807 RMS: 1.699 LUFS: 0.687 +spsa MSD: 6.593 SCE: 589.769 RMS: 2.581 LUFS: 1.151 +proxy0 MSD: 8.515 SCE: 3452.242 RMS: 15.419 LUFS: 6.475 +Baseline MSD: 9.751 SCE: 1464.298 RMS: 11.973 LUFS: 5.026 +Corrupt MSD: 11.900 SCE: 1627.915 RMS: 8.630 LUFS: 4.750 + +telephone-->telephone +100%|███████████████████████████████████████████| 10/10 [01:06<00:00, 6.63s/it] +autodiff MSD: 5.490 SCE: 122.907 RMS: 3.295 LUFS: 1.602 +spsa MSD: 6.032 SCE: 91.678 RMS: 3.282 LUFS: 1.391 +proxy0 MSD: 6.820 SCE: 357.135 RMS: 7.806 LUFS: 3.530 +Baseline MSD: 6.814 SCE: 253.513 RMS: 6.370 LUFS: 2.765 +Corrupt MSD: 7.155 SCE: 328.757 RMS: 8.424 LUFS: 3.690 + +telephone-->neutral +100%|███████████████████████████████████████████| 10/10 [00:48<00:00, 4.84s/it] +autodiff MSD: 6.228 SCE: 563.005 RMS: 2.564 LUFS: 1.157 +spsa MSD: 7.018 SCE: 622.279 RMS: 2.742 LUFS: 1.198 +proxy0 MSD: 8.846 SCE: 2406.702 RMS: 10.261 LUFS: 4.415 +Baseline MSD: 8.775 SCE: 1237.468 RMS: 4.171 LUFS: 2.020 +Corrupt MSD: 9.914 SCE: 1318.137 RMS: 6.742 LUFS: 2.769 + +telephone-->bright +100%|███████████████████████████████████████████| 10/10 [02:26<00:00, 14.63s/it] +autodiff MSD: 3.270 SCE: 3387.736 RMS: 8.958 LUFS: 4.675 +spsa MSD: 3.858 SCE: 4975.919 RMS: 5.703 LUFS: 3.585 +proxy0 MSD: 4.210 SCE: 2814.990 RMS: 5.121 LUFS: 2.485 +Baseline MSD: 4.113 SCE: 2005.442 RMS: 25.054 LUFS: 11.366 +Corrupt MSD: 8.285 SCE: 7409.718 RMS: 15.362 LUFS: 5.134 + +telephone-->warm +100%|███████████████████████████████████████████| 10/10 [00:30<00:00, 3.07s/it] +autodiff MSD: 4.053 SCE: 798.839 RMS: 6.723 LUFS: 3.331 +spsa MSD: 5.598 SCE: 652.353 RMS: 11.710 LUFS: 4.834 +proxy0 MSD: 5.526 SCE: 2542.052 RMS: 48.147 LUFS: 18.399 +Baseline MSD: 10.660 SCE: 1727.024 RMS: 4.498 LUFS: 2.267 +Corrupt MSD: 12.201 SCE: 1363.306 RMS: 8.944 LUFS: 3.653 + +neutral-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:51<00:00, 5.14s/it] +autodiff MSD: 5.633 SCE: 473.112 RMS: 1.688 LUFS: 0.884 +spsa MSD: 6.232 SCE: 620.590 RMS: 2.096 LUFS: 1.177 +proxy0 MSD: 7.190 SCE: 973.757 RMS: 8.096 LUFS: 3.795 +Baseline MSD: 6.703 SCE: 408.505 RMS: 7.695 LUFS: 3.240 +Corrupt MSD: 7.213 SCE: 697.631 RMS: 5.424 LUFS: 2.277 + +neutral-->telephone +100%|███████████████████████████████████████████| 10/10 [01:12<00:00, 7.22s/it] +autodiff MSD: 5.293 SCE: 170.508 RMS: 2.995 LUFS: 1.343 +spsa MSD: 6.183 SCE: 302.104 RMS: 5.546 LUFS: 2.798 +proxy0 MSD: 6.489 SCE: 206.780 RMS: 5.344 LUFS: 2.495 +Baseline MSD: 7.031 SCE: 291.637 RMS: 6.276 LUFS: 3.036 +Corrupt MSD: 9.642 SCE: 1073.369 RMS: 7.456 LUFS: 3.109 + +neutral-->neutral +100%|███████████████████████████████████████████| 10/10 [01:13<00:00, 7.34s/it] +autodiff MSD: 6.086 SCE: 596.216 RMS: 2.165 LUFS: 0.993 +spsa MSD: 6.659 SCE: 817.107 RMS: 2.225 LUFS: 0.850 +proxy0 MSD: 8.871 SCE: 1231.226 RMS: 4.540 LUFS: 2.858 +Baseline MSD: 6.545 SCE: 396.184 RMS: 5.075 LUFS: 2.517 +Corrupt MSD: 7.287 SCE: 1037.534 RMS: 3.076 LUFS: 2.261 + +neutral-->bright +100%|███████████████████████████████████████████| 10/10 [02:10<00:00, 13.01s/it] +autodiff MSD: 3.766 SCE: 2586.719 RMS: 5.256 LUFS: 2.830 +spsa MSD: 3.988 SCE: 2741.861 RMS: 5.056 LUFS: 2.633 +proxy0 MSD: 5.062 SCE: 1127.514 RMS: 3.931 LUFS: 2.326 +Baseline MSD: 4.613 SCE: 1388.867 RMS: 20.508 LUFS: 9.528 +Corrupt MSD: 9.180 SCE: 5913.314 RMS: 10.544 LUFS: 3.077 + +neutral-->warm +100%|███████████████████████████████████████████| 10/10 [00:32<00:00, 3.26s/it] +autodiff MSD: 3.895 SCE: 839.517 RMS: 5.932 LUFS: 2.306 +spsa MSD: 4.252 SCE: 777.918 RMS: 3.545 LUFS: 1.911 +proxy0 MSD: 4.595 SCE: 1163.291 RMS: 9.814 LUFS: 5.204 +Baseline MSD: 6.711 SCE: 421.702 RMS: 6.794 LUFS: 3.228 +Corrupt MSD: 10.006 SCE: 1059.995 RMS: 6.357 LUFS: 4.262 + +bright-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:29<00:00, 2.94s/it] +autodiff MSD: 5.007 SCE: 1785.158 RMS: 1.634 LUFS: 0.619 +spsa MSD: 5.712 SCE: 3068.999 RMS: 3.311 LUFS: 1.212 +proxy0 MSD: 7.551 SCE: 3772.967 RMS: 7.368 LUFS: 3.364 +Baseline MSD: 7.888 SCE: 918.202 RMS: 4.133 LUFS: 1.998 +Corrupt MSD: 7.071 SCE: 5855.776 RMS: 12.743 LUFS: 3.426 + +bright-->telephone +100%|███████████████████████████████████████████| 10/10 [01:36<00:00, 9.63s/it] +autodiff MSD: 6.150 SCE: 245.520 RMS: 2.700 LUFS: 1.190 +spsa MSD: 7.275 SCE: 715.758 RMS: 6.034 LUFS: 2.993 +proxy0 MSD: 9.355 SCE: 269.956 RMS: 5.673 LUFS: 2.097 +Baseline MSD: 8.040 SCE: 231.228 RMS: 9.257 LUFS: 4.368 +Corrupt MSD: 10.568 SCE: 4923.880 RMS: 14.471 LUFS: 5.918 + +bright-->neutral +100%|███████████████████████████████████████████| 10/10 [00:34<00:00, 3.50s/it] +autodiff MSD: 6.739 SCE: 1748.574 RMS: 6.367 LUFS: 2.415 +spsa MSD: 8.106 SCE: 2257.716 RMS: 4.058 LUFS: 1.124 +proxy0 MSD: 10.557 SCE: 2358.070 RMS: 7.158 LUFS: 3.306 +Baseline MSD: 9.158 SCE: 903.038 RMS: 2.445 LUFS: 1.018 +Corrupt MSD: 8.662 SCE: 5363.465 RMS: 14.566 LUFS: 3.676 + +bright-->bright +100%|███████████████████████████████████████████| 10/10 [00:53<00:00, 5.39s/it] +autodiff MSD: 2.583 SCE: 842.237 RMS: 3.639 LUFS: 1.679 +spsa MSD: 2.942 SCE: 1069.083 RMS: 3.446 LUFS: 1.603 +proxy0 MSD: 3.404 SCE: 1263.463 RMS: 7.770 LUFS: 3.352 +Baseline MSD: 3.194 SCE: 1182.327 RMS: 8.312 LUFS: 3.665 +Corrupt MSD: 3.611 SCE: 1865.160 RMS: 7.540 LUFS: 3.165 + +bright-->warm +100%|███████████████████████████████████████████| 10/10 [00:41<00:00, 4.11s/it] +autodiff MSD: 3.720 SCE: 918.005 RMS: 5.844 LUFS: 2.801 +spsa MSD: 4.430 SCE: 1835.436 RMS: 11.297 LUFS: 4.988 +proxy0 MSD: 5.261 SCE: 4984.333 RMS: 29.548 LUFS: 10.549 +Baseline MSD: 8.866 SCE: 911.462 RMS: 4.441 LUFS: 2.750 +Corrupt MSD: 8.584 SCE: 5182.931 RMS: 16.600 LUFS: 4.017 + +warm-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:36<00:00, 3.66s/it] +autodiff MSD: 4.536 SCE: 462.483 RMS: 1.149 LUFS: 0.397 +spsa MSD: 5.336 SCE: 1403.954 RMS: 2.841 LUFS: 1.359 +proxy0 MSD: 9.801 SCE: 857.491 RMS: 8.365 LUFS: 5.247 +Baseline MSD: 5.814 SCE: 1245.149 RMS: 7.553 LUFS: 2.836 +Corrupt MSD: 7.006 SCE: 1214.527 RMS: 10.749 LUFS: 4.726 + +warm-->telephone +100%|███████████████████████████████████████████| 10/10 [00:56<00:00, 5.65s/it] +autodiff MSD: 6.247 SCE: 306.690 RMS: 7.223 LUFS: 3.547 +spsa MSD: 7.720 SCE: 976.353 RMS: 8.565 LUFS: 3.589 +proxy0 MSD: 8.222 SCE: 800.051 RMS: 9.571 LUFS: 4.325 +Baseline MSD: 10.837 SCE: 814.371 RMS: 13.657 LUFS: 4.271 +Corrupt MSD: 12.381 SCE: 1682.943 RMS: 11.264 LUFS: 4.622 + +warm-->neutral +100%|███████████████████████████████████████████| 10/10 [00:44<00:00, 4.45s/it] +autodiff MSD: 6.247 SCE: 759.577 RMS: 1.195 LUFS: 0.692 +spsa MSD: 6.635 SCE: 1124.939 RMS: 2.152 LUFS: 0.646 +proxy0 MSD: 12.783 SCE: 910.850 RMS: 7.942 LUFS: 4.809 +Baseline MSD: 7.825 SCE: 1237.193 RMS: 4.960 LUFS: 1.748 +Corrupt MSD: 8.902 SCE: 1197.048 RMS: 4.863 LUFS: 1.927 + +warm-->bright +100%|███████████████████████████████████████████| 10/10 [00:20<00:00, 2.04s/it] +autodiff MSD: 2.768 SCE: 1369.470 RMS: 3.846 LUFS: 2.007 +spsa MSD: 3.445 SCE: 1910.909 RMS: 1.579 LUFS: 1.279 +proxy0 MSD: 4.475 SCE: 802.250 RMS: 5.425 LUFS: 1.860 +Baseline MSD: 5.136 SCE: 2548.782 RMS: 13.753 LUFS: 2.340 +Corrupt MSD: 7.634 SCE: 6192.568 RMS: 21.637 LUFS: 4.702 + +warm-->warm +100%|███████████████████████████████████████████| 10/10 [00:44<00:00, 4.49s/it] +autodiff MSD: 3.388 SCE: 683.876 RMS: 4.456 LUFS: 1.592 +spsa MSD: 3.400 SCE: 794.539 RMS: 3.175 LUFS: 0.996 +proxy0 MSD: 4.359 SCE: 626.161 RMS: 4.856 LUFS: 2.238 +Baseline MSD: 3.495 SCE: 782.423 RMS: 4.227 LUFS: 1.738 +Corrupt MSD: 3.910 SCE: 1368.821 RMS: 5.118 LUFS: 2.271 + +----- Averages ---- MUSDB18 +autodiff MSD: 4.820 SCE: 947.609 RMS: 3.776 LUFS: 1.731 +spsa MSD: 5.535 SCE: 1297.033 RMS: 4.578 LUFS: 2.056 +proxy0 MSD: 6.964 SCE: 1512.650 RMS: 10.019 LUFS: 4.472 +Baseline MSD: 6.943 SCE: 915.390 RMS: 8.647 LUFS: 3.662 +Corrupt MSD: 8.534 SCE: 2675.290 RMS: 10.199 LUFS: 3.795 + + ------ + +Global seed set to 16 +Proxy Processor: peq @ fs=24000 Hz +TCN receptive field: 7021 samples or 292.542 ms +Proxy Processor: comp @ fs=24000 Hz +TCN receptive field: 7021 samples or 292.542 ms +broadcast-->broadcast +100%|███████████████████████████████████████████| 10/10 [01:18<00:00, 7.87s/it] +autodiff MSD: 5.214 SCE: 539.719 RMS: 1.776 LUFS: 0.896 +spsa MSD: 5.499 SCE: 532.192 RMS: 3.119 LUFS: 1.426 +proxy0 MSD: 6.473 SCE: 1390.071 RMS: 7.491 LUFS: 4.578 +Baseline MSD: 6.449 SCE: 611.495 RMS: 13.743 LUFS: 5.791 +Corrupt MSD: 6.739 SCE: 1188.642 RMS: 9.744 LUFS: 4.377 + +broadcast-->telephone +100%|███████████████████████████████████████████| 10/10 [01:00<00:00, 6.00s/it] +autodiff MSD: 5.974 SCE: 151.923 RMS: 2.261 LUFS: 1.189 +spsa MSD: 6.501 SCE: 260.794 RMS: 5.734 LUFS: 2.850 +proxy0 MSD: 7.920 SCE: 273.202 RMS: 4.547 LUFS: 1.867 +Baseline MSD: 7.449 SCE: 280.922 RMS: 7.111 LUFS: 2.715 +Corrupt MSD: 10.834 SCE: 1312.504 RMS: 8.379 LUFS: 2.868 + +broadcast-->neutral +100%|███████████████████████████████████████████| 10/10 [01:55<00:00, 11.54s/it] +autodiff MSD: 5.757 SCE: 625.424 RMS: 2.377 LUFS: 1.079 +spsa MSD: 6.238 SCE: 598.486 RMS: 1.932 LUFS: 0.808 +proxy0 MSD: 7.995 SCE: 1080.859 RMS: 5.854 LUFS: 3.147 +Baseline MSD: 6.528 SCE: 305.291 RMS: 15.089 LUFS: 6.711 +Corrupt MSD: 6.470 SCE: 1340.071 RMS: 7.338 LUFS: 3.598 + +broadcast-->bright +100%|███████████████████████████████████████████| 10/10 [01:28<00:00, 8.86s/it] +autodiff MSD: 3.124 SCE: 1826.629 RMS: 6.154 LUFS: 3.618 +spsa MSD: 3.612 SCE: 2296.033 RMS: 4.474 LUFS: 2.919 +proxy0 MSD: 4.519 SCE: 1239.892 RMS: 5.006 LUFS: 2.656 +Baseline MSD: 3.695 SCE: 968.008 RMS: 11.576 LUFS: 5.373 +Corrupt MSD: 8.048 SCE: 5699.976 RMS: 13.067 LUFS: 3.319 + +broadcast-->warm +100%|███████████████████████████████████████████| 10/10 [00:25<00:00, 2.52s/it] +autodiff MSD: 3.129 SCE: 665.939 RMS: 3.023 LUFS: 0.865 +spsa MSD: 3.759 SCE: 966.383 RMS: 3.144 LUFS: 1.248 +proxy0 MSD: 4.048 SCE: 1058.204 RMS: 6.544 LUFS: 3.036 +Baseline MSD: 5.674 SCE: 566.713 RMS: 6.871 LUFS: 3.654 +Corrupt MSD: 8.552 SCE: 1016.717 RMS: 6.865 LUFS: 4.339 + +telephone-->broadcast +100%|███████████████████████████████████████████| 10/10 [01:54<00:00, 11.46s/it] +autodiff MSD: 6.452 SCE: 780.362 RMS: 2.015 LUFS: 1.112 +spsa MSD: 7.557 SCE: 869.791 RMS: 2.794 LUFS: 1.316 +proxy0 MSD: 9.715 SCE: 2663.854 RMS: 11.838 LUFS: 4.143 +Baseline MSD: 9.111 SCE: 1393.222 RMS: 17.848 LUFS: 7.498 +Corrupt MSD: 11.864 SCE: 1874.800 RMS: 7.548 LUFS: 2.650 + +telephone-->telephone +100%|███████████████████████████████████████████| 10/10 [01:14<00:00, 7.49s/it] +autodiff MSD: 5.115 SCE: 112.800 RMS: 2.275 LUFS: 1.213 +spsa MSD: 5.915 SCE: 155.356 RMS: 3.879 LUFS: 1.523 +proxy0 MSD: 6.797 SCE: 232.531 RMS: 9.969 LUFS: 4.578 +Baseline MSD: 7.631 SCE: 180.769 RMS: 11.585 LUFS: 4.900 +Corrupt MSD: 7.546 SCE: 269.839 RMS: 10.740 LUFS: 4.552 + +telephone-->neutral +100%|███████████████████████████████████████████| 10/10 [00:44<00:00, 4.42s/it] +autodiff MSD: 6.357 SCE: 429.069 RMS: 2.349 LUFS: 0.944 +spsa MSD: 7.289 SCE: 294.847 RMS: 2.050 LUFS: 0.745 +proxy0 MSD: 9.393 SCE: 2912.440 RMS: 12.004 LUFS: 5.530 +Baseline MSD: 9.369 SCE: 1105.188 RMS: 2.573 LUFS: 1.102 +Corrupt MSD: 10.772 SCE: 1039.346 RMS: 4.271 LUFS: 1.896 + +telephone-->bright +100%|███████████████████████████████████████████| 10/10 [02:11<00:00, 13.12s/it] +autodiff MSD: 3.781 SCE: 2569.188 RMS: 7.647 LUFS: 3.893 +spsa MSD: 4.613 SCE: 3651.535 RMS: 3.470 LUFS: 1.885 +proxy0 MSD: 5.099 SCE: 1862.924 RMS: 4.515 LUFS: 1.621 +Baseline MSD: 5.135 SCE: 1852.895 RMS: 29.542 LUFS: 13.255 +Corrupt MSD: 8.587 SCE: 5692.969 RMS: 13.567 LUFS: 4.993 + +telephone-->warm +100%|███████████████████████████████████████████| 10/10 [00:35<00:00, 3.57s/it] +autodiff MSD: 3.965 SCE: 947.003 RMS: 11.701 LUFS: 4.596 +spsa MSD: 5.228 SCE: 1044.573 RMS: 15.173 LUFS: 5.606 +proxy0 MSD: 6.266 SCE: 2187.456 RMS: 45.108 LUFS: 17.355 +Baseline MSD: 9.049 SCE: 1669.726 RMS: 6.980 LUFS: 3.745 +Corrupt MSD: 11.205 SCE: 1524.191 RMS: 5.889 LUFS: 4.271 + +neutral-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:38<00:00, 3.83s/it] +autodiff MSD: 5.919 SCE: 930.650 RMS: 1.039 LUFS: 0.554 +spsa MSD: 6.176 SCE: 1030.785 RMS: 3.445 LUFS: 1.300 +proxy0 MSD: 7.642 SCE: 955.156 RMS: 6.284 LUFS: 3.115 +Baseline MSD: 6.963 SCE: 584.665 RMS: 6.691 LUFS: 2.750 +Corrupt MSD: 7.364 SCE: 1393.375 RMS: 5.646 LUFS: 2.405 + +neutral-->telephone +100%|███████████████████████████████████████████| 10/10 [01:15<00:00, 7.55s/it] +autodiff MSD: 5.892 SCE: 174.921 RMS: 1.639 LUFS: 0.751 +spsa MSD: 6.844 SCE: 205.956 RMS: 6.119 LUFS: 2.741 +proxy0 MSD: 7.938 SCE: 220.270 RMS: 4.373 LUFS: 1.958 +Baseline MSD: 7.668 SCE: 243.061 RMS: 7.975 LUFS: 3.246 +Corrupt MSD: 10.047 SCE: 922.323 RMS: 5.995 LUFS: 3.302 + +neutral-->neutral +100%|███████████████████████████████████████████| 10/10 [00:42<00:00, 4.24s/it] +autodiff MSD: 4.376 SCE: 479.564 RMS: 1.821 LUFS: 0.728 +spsa MSD: 5.174 SCE: 700.180 RMS: 2.980 LUFS: 1.256 +proxy0 MSD: 6.352 SCE: 410.978 RMS: 6.181 LUFS: 2.945 +Baseline MSD: 5.412 SCE: 588.237 RMS: 4.344 LUFS: 2.115 +Corrupt MSD: 6.826 SCE: 944.537 RMS: 3.581 LUFS: 2.232 + +neutral-->bright +100%|███████████████████████████████████████████| 10/10 [01:39<00:00, 9.97s/it] +autodiff MSD: 3.813 SCE: 2155.787 RMS: 5.397 LUFS: 2.366 +spsa MSD: 4.466 SCE: 2451.703 RMS: 3.662 LUFS: 1.580 +proxy0 MSD: 5.212 SCE: 1388.196 RMS: 4.916 LUFS: 2.789 +Baseline MSD: 4.552 SCE: 856.215 RMS: 13.886 LUFS: 6.292 +Corrupt MSD: 8.861 SCE: 5333.085 RMS: 13.093 LUFS: 4.045 + +neutral-->warm +100%|███████████████████████████████████████████| 10/10 [00:26<00:00, 2.64s/it] +autodiff MSD: 3.852 SCE: 1447.159 RMS: 3.055 LUFS: 1.282 +spsa MSD: 4.310 SCE: 1533.351 RMS: 3.364 LUFS: 1.599 +proxy0 MSD: 4.953 SCE: 1366.485 RMS: 8.061 LUFS: 3.729 +Baseline MSD: 6.034 SCE: 851.415 RMS: 5.795 LUFS: 2.972 +Corrupt MSD: 7.863 SCE: 1160.896 RMS: 5.189 LUFS: 2.998 + +bright-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:48<00:00, 4.82s/it] +autodiff MSD: 5.281 SCE: 1287.706 RMS: 2.544 LUFS: 0.997 +spsa MSD: 6.023 SCE: 2018.120 RMS: 3.074 LUFS: 1.097 +proxy0 MSD: 7.068 SCE: 2807.467 RMS: 6.503 LUFS: 3.095 +Baseline MSD: 7.483 SCE: 465.684 RMS: 6.261 LUFS: 2.503 +Corrupt MSD: 7.669 SCE: 4878.829 RMS: 12.448 LUFS: 4.104 + +bright-->telephone +100%|███████████████████████████████████████████| 10/10 [01:31<00:00, 9.16s/it] +autodiff MSD: 5.727 SCE: 294.151 RMS: 2.993 LUFS: 1.275 +spsa MSD: 6.517 SCE: 822.264 RMS: 5.998 LUFS: 2.810 +proxy0 MSD: 8.026 SCE: 314.646 RMS: 6.618 LUFS: 2.634 +Baseline MSD: 6.978 SCE: 297.031 RMS: 5.284 LUFS: 2.069 +Corrupt MSD: 9.277 SCE: 5712.550 RMS: 12.289 LUFS: 4.404 + +bright-->neutral +100%|███████████████████████████████████████████| 10/10 [01:02<00:00, 6.28s/it] +autodiff MSD: 5.694 SCE: 1402.436 RMS: 2.791 LUFS: 1.241 +spsa MSD: 6.566 SCE: 2049.418 RMS: 2.237 LUFS: 0.772 +proxy0 MSD: 7.959 SCE: 2064.933 RMS: 6.404 LUFS: 3.087 +Baseline MSD: 7.848 SCE: 706.934 RMS: 1.997 LUFS: 0.614 +Corrupt MSD: 8.448 SCE: 4981.422 RMS: 15.911 LUFS: 4.215 + +bright-->bright +100%|███████████████████████████████████████████| 10/10 [01:30<00:00, 9.01s/it] +autodiff MSD: 3.347 SCE: 1051.300 RMS: 2.393 LUFS: 1.379 +spsa MSD: 3.728 SCE: 909.130 RMS: 3.405 LUFS: 1.348 +proxy0 MSD: 4.121 SCE: 836.498 RMS: 6.039 LUFS: 2.327 +Baseline MSD: 4.790 SCE: 980.283 RMS: 11.279 LUFS: 4.587 +Corrupt MSD: 5.055 SCE: 2172.028 RMS: 12.907 LUFS: 4.884 + +bright-->warm +100%|███████████████████████████████████████████| 10/10 [00:36<00:00, 3.65s/it] +autodiff MSD: 3.569 SCE: 1351.627 RMS: 5.278 LUFS: 2.889 +spsa MSD: 4.743 SCE: 2563.007 RMS: 13.225 LUFS: 5.640 +proxy0 MSD: 5.911 SCE: 5251.111 RMS: 27.244 LUFS: 10.150 +Baseline MSD: 9.030 SCE: 846.539 RMS: 4.628 LUFS: 2.327 +Corrupt MSD: 8.576 SCE: 5245.425 RMS: 18.363 LUFS: 3.916 + +warm-->broadcast +100%|███████████████████████████████████████████| 10/10 [00:58<00:00, 5.90s/it] +autodiff MSD: 5.288 SCE: 510.650 RMS: 1.278 LUFS: 0.772 +spsa MSD: 5.692 SCE: 1178.561 RMS: 2.900 LUFS: 1.194 +proxy0 MSD: 9.389 SCE: 988.039 RMS: 7.070 LUFS: 4.741 +Baseline MSD: 6.917 SCE: 1349.983 RMS: 11.752 LUFS: 4.370 +Corrupt MSD: 7.793 SCE: 1454.567 RMS: 8.747 LUFS: 3.901 + +warm-->telephone +100%|███████████████████████████████████████████| 10/10 [01:18<00:00, 7.82s/it] +autodiff MSD: 4.924 SCE: 127.741 RMS: 3.423 LUFS: 1.673 +spsa MSD: 6.156 SCE: 592.691 RMS: 8.652 LUFS: 4.304 +proxy0 MSD: 5.924 SCE: 564.229 RMS: 12.786 LUFS: 5.267 +Baseline MSD: 9.343 SCE: 813.276 RMS: 10.653 LUFS: 3.503 +Corrupt MSD: 11.138 SCE: 983.834 RMS: 5.029 LUFS: 2.949 + +warm-->neutral +100%|███████████████████████████████████████████| 10/10 [01:10<00:00, 7.07s/it] +autodiff MSD: 5.059 SCE: 474.697 RMS: 1.627 LUFS: 0.606 +spsa MSD: 5.559 SCE: 979.341 RMS: 2.158 LUFS: 0.808 +proxy0 MSD: 8.685 SCE: 860.294 RMS: 4.988 LUFS: 3.829 +Baseline MSD: 7.061 SCE: 1128.742 RMS: 13.217 LUFS: 5.004 +Corrupt MSD: 7.815 SCE: 1242.012 RMS: 5.997 LUFS: 2.878 + +warm-->bright +100%|███████████████████████████████████████████| 10/10 [01:16<00:00, 7.67s/it] +autodiff MSD: 3.362 SCE: 1751.219 RMS: 5.462 LUFS: 2.800 +spsa MSD: 4.090 SCE: 2119.844 RMS: 5.119 LUFS: 2.240 +proxy0 MSD: 5.014 SCE: 1057.485 RMS: 4.544 LUFS: 1.985 +Baseline MSD: 5.832 SCE: 2804.064 RMS: 19.218 LUFS: 7.381 +Corrupt MSD: 8.620 SCE: 6288.224 RMS: 17.063 LUFS: 4.609 + +warm-->warm +100%|███████████████████████████████████████████| 10/10 [00:40<00:00, 4.07s/it] +autodiff MSD: 3.549 SCE: 646.522 RMS: 4.088 LUFS: 2.016 +spsa MSD: 3.863 SCE: 699.310 RMS: 4.146 LUFS: 1.984 +proxy0 MSD: 4.896 SCE: 857.195 RMS: 7.609 LUFS: 3.748 +Baseline MSD: 4.317 SCE: 323.523 RMS: 6.576 LUFS: 3.416 +Corrupt MSD: 4.534 SCE: 1513.056 RMS: 4.237 LUFS: 2.441 + +----- Averages ---- MUSDB18 +autodiff MSD: 4.741 SCE: 909.368 RMS: 3.456 LUFS: 1.629 +spsa MSD: 5.445 SCE: 1220.946 RMS: 4.650 LUFS: 2.040 +proxy0 MSD: 6.693 SCE: 1393.777 RMS: 9.300 LUFS: 4.156 +Baseline MSD: 6.813 SCE: 870.955 RMS: 10.099 LUFS: 4.316 +Corrupt MSD: 8.420 SCE: 2607.409 RMS: 9.356 LUFS: 3.606 \ No newline at end of file diff --git a/results/eval_probes.md b/results/eval_probes.md new file mode 100644 index 0000000..cd656ea --- /dev/null +++ b/results/eval_probes.md @@ -0,0 +1,224 @@ +# all probes +CUDA_VISIBLE_DEVICES=0 python scripts/eval_probes.py \ +--ckpt_dir /import/c4dm-datasets/deepafx_st/logs/probes_100/speech \ +--eval_dataset /import/c4dm-datasets/deepafx_st/daps_24000_styles_100/ \ +--subset test \ +--audio_type speech \ +--output_dir probes \ +--gpu \ + +CUDA_VISIBLE_DEVICES=0 python scripts/eval_probes.py \ +--ckpt_dir /import/c4dm-datasets/deepafx_st/logs/probes_100/music \ +--eval_dataset /import/c4dm-datasets/deepafx_st/musdb18_44100_styles_100/ \ +--audio_type music \ +--subset test \ +--output_dir probes \ +--gpu \ + +------------------------------------------------------- +-------------------speech--------------------- +------------------------------------------------------- +true acc: 100.00% f1: 1.00 + precision recall f1-score support + + broadcast 1.00 1.00 1.00 20 + telephone 1.00 1.00 1.00 20 + neutral 1.00 1.00 1.00 20 + bright 1.00 1.00 1.00 20 + warm 1.00 1.00 1.00 20 + + accuracy 1.00 100 + macro avg 1.00 1.00 1.00 100 +weighted avg 1.00 1.00 1.00 100 + +deepafx_st_spsa-linear acc: 96.00% f1: 0.96 + precision recall f1-score support + + broadcast 0.94 0.85 0.89 20 + telephone 0.95 1.00 0.98 20 + neutral 0.95 0.95 0.95 20 + bright 0.95 1.00 0.98 20 + warm 1.00 1.00 1.00 20 + + accuracy 0.96 100 + macro avg 0.96 0.96 0.96 100 +weighted avg 0.96 0.96 0.96 100 + +deepafx_st_proxy0-linear acc: 100.00% f1: 1.00 + precision recall f1-score support + + broadcast 1.00 1.00 1.00 20 + telephone 1.00 1.00 1.00 20 + neutral 1.00 1.00 1.00 20 + bright 1.00 1.00 1.00 20 + warm 1.00 1.00 1.00 20 + + accuracy 1.00 100 + macro avg 1.00 1.00 1.00 100 +weighted avg 1.00 1.00 1.00 100 + +openl3-linear acc: 30.00% f1: 0.23 + precision recall f1-score support + + broadcast 1.00 0.05 0.10 20 + telephone 0.25 0.15 0.19 20 + neutral 0.14 0.25 0.18 20 + bright 0.43 1.00 0.61 20 + warm 0.20 0.05 0.08 20 + + accuracy 0.30 100 + macro avg 0.40 0.30 0.23 100 +weighted avg 0.40 0.30 0.23 100 + +random_mel-linear acc: 64.00% f1: 0.55 + precision recall f1-score support + + broadcast 0.55 0.30 0.39 20 + telephone 0.77 1.00 0.87 20 + neutral 0.00 0.00 0.00 20 + bright 0.69 0.90 0.78 20 + warm 0.57 1.00 0.73 20 + + accuracy 0.64 100 + macro avg 0.52 0.64 0.55 100 +weighted avg 0.52 0.64 0.55 100 + +cdpam-linear acc: 76.00% f1: 0.73 + precision recall f1-score support + + broadcast 0.75 0.15 0.25 20 + telephone 1.00 1.00 1.00 20 + neutral 0.47 1.00 0.63 20 + bright 1.00 1.00 1.00 20 + warm 1.00 0.65 0.79 20 + + accuracy 0.76 100 + macro avg 0.84 0.76 0.73 100 +weighted avg 0.84 0.76 0.73 100 + +deepafx_st_autodiff-linear acc: 100.00% f1: 1.00 + precision recall f1-score support + + broadcast 1.00 1.00 1.00 20 + telephone 1.00 1.00 1.00 20 + neutral 1.00 1.00 1.00 20 + bright 1.00 1.00 1.00 20 + warm 1.00 1.00 1.00 20 + + accuracy 1.00 100 + macro avg 1.00 1.00 1.00 100 +weighted avg 1.00 1.00 1.00 100 + +------------------------------------------------------- +epoch=4-step=94-val-deepafx_st_spsa-linear.ckpt +epoch=399-step=7599-val-deepafx_st_proxy0-linear.ckpt +Proxy Processor: peq @ fs=24000 Hz +TCN receptive field: 7021 samples or 292.542 ms +Proxy Processor: comp @ fs=24000 Hz +TCN receptive field: 7021 samples or 292.542 ms +epoch=86-step=1652-val-openl3-linear.ckpt +epoch=399-step=7599-val-random_mel-linear.ckpt +epoch=398-step=7580-val-cdpam-linear.ckpt +epoch=399-step=7599-val-deepafx_st_autodiff-linear.ckpt +100%|███████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 40.57it/s] +100%|███████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 45.07it/s] +100%|███████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 44.43it/s] +100%|███████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 42.75it/s] +100%|███████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 42.04it/s] +Loaded 100 examples for test subset. +100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:09<00:00, 11.10it/s] +------------------------------------------------------- +-------------------music--------------------- +------------------------------------------------------- +true acc: 100.00% f1: 1.00 + precision recall f1-score support + + broadcast 1.00 1.00 1.00 20 + telephone 1.00 1.00 1.00 20 + neutral 1.00 1.00 1.00 20 + bright 1.00 1.00 1.00 20 + warm 1.00 1.00 1.00 20 + + accuracy 1.00 100 + macro avg 1.00 1.00 1.00 100 +weighted avg 1.00 1.00 1.00 100 + +deepafx_st_spsa-linear acc: 64.00% f1: 0.63 + precision recall f1-score support + + broadcast 0.23 0.30 0.26 20 + telephone 0.95 1.00 0.98 20 + neutral 0.00 0.00 0.00 20 + bright 1.00 1.00 1.00 20 + warm 0.90 0.90 0.90 20 + + accuracy 0.64 100 + macro avg 0.62 0.64 0.63 100 +weighted avg 0.62 0.64 0.63 100 + +deepafx_st_proxy0-linear acc: 82.00% f1: 0.82 + precision recall f1-score support + + broadcast 0.65 0.55 0.59 20 + telephone 0.95 1.00 0.98 20 + neutral 0.57 0.65 0.60 20 + bright 1.00 1.00 1.00 20 + warm 0.95 0.90 0.92 20 + + accuracy 0.82 100 + macro avg 0.82 0.82 0.82 100 +weighted avg 0.82 0.82 0.82 100 + +openl3-linear acc: 38.00% f1: 0.33 + precision recall f1-score support + + broadcast 0.67 0.10 0.17 20 + telephone 0.25 0.45 0.32 20 + neutral 0.38 0.25 0.30 20 + bright 0.50 0.95 0.66 20 + warm 0.30 0.15 0.20 20 + + accuracy 0.38 100 + macro avg 0.42 0.38 0.33 100 +weighted avg 0.42 0.38 0.33 100 + +random_mel-linear acc: 62.00% f1: 0.51 + precision recall f1-score support + + broadcast 0.50 0.10 0.17 20 + telephone 0.67 1.00 0.80 20 + neutral 0.00 0.00 0.00 20 + bright 0.95 1.00 0.98 20 + warm 0.44 1.00 0.62 20 + + accuracy 0.62 100 + macro avg 0.51 0.62 0.51 100 +weighted avg 0.51 0.62 0.51 100 + +cdpam-linear acc: 60.00% f1: 0.51 + precision recall f1-score support + + broadcast 0.00 0.00 0.00 20 + telephone 0.80 1.00 0.89 20 + neutral 0.07 0.05 0.06 20 + bright 0.91 1.00 0.95 20 + warm 0.50 0.95 0.66 20 + + accuracy 0.60 100 + macro avg 0.46 0.60 0.51 100 +weighted avg 0.46 0.60 0.51 100 + +deepafx_st_autodiff-linear acc: 79.00% f1: 0.79 + precision recall f1-score support + + broadcast 0.52 0.55 0.54 20 + telephone 0.95 1.00 0.98 20 + neutral 0.50 0.50 0.50 20 + bright 1.00 1.00 1.00 20 + warm 1.00 0.90 0.95 20 + + accuracy 0.79 100 + macro avg 0.80 0.79 0.79 100 +weighted avg 0.80 0.79 0.79 100 + +------------------------------------------------------- \ No newline at end of file diff --git a/results/eval_time.md b/results/eval_time.md new file mode 100644 index 0000000..d832a65 --- /dev/null +++ b/results/eval_time.md @@ -0,0 +1,39 @@ +# Machine +sandle +Intel(R) Xeon(R) CPU E5-2623 v3 @ 3.00GHz (16 core) +GeForce GTX 1080 Ti + +# 100 + +dsp_infer : sec/step 0.0177 0.0035 RTF +autodiff_cpu_infer : sec/step 0.0256 0.0051 RTF +autodiff_gpu_infer : sec/step 0.0047 0.0009 RTF +tcn1_cpu_infer : sec/step 0.7828 0.1566 RTF +tcn2_cpu_infer : sec/step 1.3870 0.2774 RTF +tcn1_gpu_infer : sec/step 0.0116 0.0023 RTF +tcn2_gpu_infer : sec/step 0.0222 0.0044 RTF +autodiff_gpu_grad : sec/step 0.3009 0.0602 RTF +np_norm_gpu_grad : sec/step 0.3880 0.0776 RTF +np_hh_gpu_grad : sec/step 0.4226 0.0845 RTF +np_fh_gpu_grad : sec/step 0.4319 0.0864 RTF +tcn1_gpu_grad : sec/step 0.4323 0.0865 RTF +tcn2_gpu_grad : sec/step 0.6371 0.1274 RTF +spsa_gpu_grad : sec/step 0.3945 0.0789 RTF + +# 1000 + +rb_infer : sec/step 0.0186 0.0037 RTF +dsp_infer : sec/step 0.0172 0.0034 RTF +autodiff_cpu_infer : sec/step 0.0295 0.0059 RTF +autodiff_gpu_infer : sec/step 0.0049 0.0010 RTF +tcn1_cpu_infer : sec/step 0.6580 0.1316 RTF +tcn2_cpu_infer : sec/step 1.3409 0.2682 RTF +tcn1_gpu_infer : sec/step 0.0114 0.0023 RTF +tcn2_gpu_infer : sec/step 0.0223 0.0045 RTF +autodiff_gpu_grad : sec/step 0.3086 0.0617 RTF +np_norm_gpu_grad : sec/step 0.4346 0.0869 RTF +np_hh_gpu_grad : sec/step 0.4379 0.0876 RTF +np_fh_gpu_grad : sec/step 0.4339 0.0868 RTF +tcn1_gpu_grad : sec/step 0.4382 0.0876 RTF +tcn2_gpu_grad : sec/step 0.6424 0.1285 RTF +spsa_gpu_grad : sec/step 0.4132 0.0826 RTF \ No newline at end of file diff --git a/scripts/download.py b/scripts/download.py new file mode 100755 index 0000000..ed30cf9 --- /dev/null +++ b/scripts/download.py @@ -0,0 +1,458 @@ +#!/usr/bin/env python3 +# ************************************************************************* +# Copyright 2021 Adobe Systems Incorporated. +# +# Please see the attached LICENSE file for more information. +# +# **************************************************************************/ + +import os +import sox +import wget +import glob +import torch +import shutil +import resampy +import hashlib +import itertools +import subprocess +import torchaudio +import numpy as np +import multiprocessing +import soundfile as sf + +from tqdm import tqdm +from deepafx_st import utils +from argparse import ArgumentParser +from joblib import Parallel, delayed + + +def resample_file(spkr_file, sr): + x, sr_orig = sf.read(spkr_file) + if sr_orig != sr: + x = resampy.resample(x, sr_orig, sr, axis=0) + return x + + +def ffmpeg_resample(input_avfile, output_audiofile, sr, channels=None): + cmd = ["ffmpeg", "-y", "-i", input_avfile, + "-ar", str(sr), "-ac", str(1), output_audiofile, + "-hide_banner", "-loglevel", "error", ] + completed_process = subprocess.run(cmd) + return completed_process + + +def resample_dir(input_dir, output_dir, target_sr, channels=1, num_cpus=33): + + # files = get_audio_file_list(input_dir) + files = glob.glob(os.path.join(input_dir, "*.mp3")) + + os.makedirs(output_dir, exist_ok=True) + + file_pairs = [] + for file in files: + new_file = os.path.join(output_dir, os.path.basename(file)[:-4] + ".wav") + file_pairs.append((file, new_file, target_sr)) + + def par_resample(item): + orig, new, sr = item + ffmpeg_resample(orig, new, sr, channels=1) + return True + + # FFMPEG seems to have issue when multi-threaded + # results = Parallel(n_jobs=num_cpus)( + # delayed(par_resample)(i) for i in tqdm(file_pairs) + # ) + for item in tqdm(file_pairs): + par_resample(item) + + + +def resample_file_torchaudio(spkr_file, sr): + x, sr_orig = torchaudio.load(spkr_file) + x = x.numpy() + if sr_orig != sr: + x = resampy.resample(x, sr_orig, sr) + x = torch.tensor(x) + return x + + +def download_daps_dataset(output_dir): + """Download and resample the DAPS dataset to a given sample rate.""" + + archive_path = os.path.join(output_dir, "daps.tar.gz") + + cmd = f"wget -O {archive_path} https://zenodo.org/record/4660670/files/daps.tar.gz?download=1" + os.system(cmd) + + # Untar + print("Extracting tar...") + cmd = f"tar -xvf {archive_path} -C {output_dir}" + os.system(cmd) + + +def process_daps_dataset(output_dir): + + set_dirs = glob.glob(os.path.join(output_dir, "daps", "*")) + + for sr in [16000, 24000, 44100]: + resampled_output_dir = os.path.join(output_dir, f"daps_{sr}") + if not os.path.isdir(resampled_output_dir): + os.makedirs(resampled_output_dir) + + for set_dir in set_dirs: + print(set_dir) + if "produced" in set_dir or "cleanraw" in set_dir: + # get all files in speaker directory + spkr_files = glob.glob(os.path.join(set_dir, "*.wav")) + + with multiprocessing.Pool(16) as pool: + audios = pool.starmap( + resample_file, + zip(spkr_files, itertools.repeat(sr)), + ) + + for spkr_file, audio in tqdm(zip(spkr_files, audios)): + spkr_id = os.path.basename(spkr_file) + + if not os.path.isdir( + os.path.join( + resampled_output_dir, f"{os.path.basename(set_dir)}" + ) + ): + os.makedirs( + os.path.join( + resampled_output_dir, f"{os.path.basename(set_dir)}" + ) + ) + + out_filepath = os.path.join( + resampled_output_dir, + f"{os.path.basename(set_dir)}", + f"{spkr_id}", + ) + sf.write(out_filepath, audio, sr) + +def download_vctk_dataset(output_dir): + + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + + archive_path = os.path.join(output_dir, "vctk.zip") + + cmd = ( + f"wget -O {archive_path} https://datashare.ed.ac.uk/download/DS_10283_3443.zip" + ) + os.system(cmd) + + # Untar + print("Extracting zip...") + cmd = f"unzip {archive_path} -C {output_dir}" + os.system(cmd) + + +def process_vctk_dataset(output_dir, num_workers=16): + + spkr_dirs = glob.glob( + os.path.join( + output_dir, + "VCTK-Corpus-0.92", + "wav48_silence_trimmed", + "*", + ) + ) + + for sr in [16000, 24000, 44100]: + + resampled_output_dir = os.path.join(output_dir, f"vctk_{sr}") + if not os.path.isdir(resampled_output_dir): + os.makedirs(resampled_output_dir) + + for spkr_dir in tqdm(spkr_dirs, ncols=80): + print(spkr_dir) + # get all files in speaker directory + spkr_files = glob.glob(os.path.join(spkr_dir, "*.flac")) + spkr_id = os.path.basename(spkr_dir) + + if len(spkr_files) > 0: + with multiprocessing.Pool(num_workers) as pool: + audios = pool.starmap( + resample_file, + zip(spkr_files, itertools.repeat(sr)), + ) + # combine all audio files into one long file + x = np.concatenate(audios, axis=-1) + + out_filepath = os.path.join(resampled_output_dir, f"{spkr_id}.wav") + sf.write(out_filepath, x, sr) + else: + print(f"{spkr_dir} contained no audio files.") + + +def download_libritts_dataset(output_dir, sr=24000): + + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + + archive_path = os.path.join(output_dir, "train-clean-360.tar.gz") + + cmd = f"wget -O {archive_path} https://www.openslr.org/resources/60/train-clean-360.tar.gz" + os.system(cmd) + + # Untar + print("Extracting tar...") + cmd = f"tar -xvf {archive_path} -C {output_dir}" + os.system(cmd) + + +def process_libritts_dataset(output_dir): + + spkr_dirs = glob.glob( + os.path.join( + output_dir, + "LibriTTS", + "train-clean-360", + "*", + ) + ) + + for sr in [16000, 24000]: + + resampled_output_dir = os.path.join( + output_dir, + "LibriTTS", + f"train_clean_360_{sr}c", + ) + if not os.path.isdir(resampled_output_dir): + os.makedirs(resampled_output_dir) + + for spkr_dir in tqdm(spkr_dirs, ncols=80): + # get all book directories + spkr_id = os.path.basename(spkr_dir) + book_dirs = glob.glob(os.path.join(spkr_dir, "*")) + + spkr_files = [] + for book_dir in book_dirs: + # get all files in speaker directory + spkr_files += glob.glob(os.path.join(book_dir, "*.wav")) + print( + f"Found {len(book_dirs)} books with {len(spkr_files)} files by {spkr_id}" + ) + + if len(spkr_files) > 0: + with multiprocessing.Pool(16) as pool: + audios = pool.starmap( + resample_file, + zip(spkr_files, itertools.repeat(sr)), + ) + # combine all audio files into one long file + x = np.concatenate(audios, axis=-1) + # print(x.shape, (x.shape[0] / sr) / 60) + + out_filepath = os.path.join(resampled_output_dir, f"{spkr_id}.wav") + sf.write(out_filepath, x, sr) + else: + print(f"{spkr_dir} contained no audio files.") + + +def download_jamendo_dataset(output_dir): + + hash_url = "https://essentia.upf.edu/datasets/mtg-jamendo/autotagging_moodtheme/audio/checksums_sha256.txt" + cmd = f"""wget -O {os.path.join(output_dir, "checksums_sha256.txt")} {hash_url}""" + os.system(cmd) + + with open(os.path.join(output_dir, "checksums_sha256.txt"), "r") as fp: + hashes = fp.readlines() + + hash_dict = {} + for sha256_hash in hashes: + value = sha256_hash.split(" ")[0] + fname = sha256_hash.split(" ")[1].strip("\n") + hash_dict[fname] = value + + for n in range(100): + base_url = ( + "https://essentia.upf.edu/datasets/mtg-jamendo/autotagging_moodtheme/audio/" + ) + fname = f"autotagging_moodtheme_audio-{n:02}.tar" + url = base_url + fname + # check if file has been downloaded + if os.path.isfile(os.path.join(output_dir, fname)): + + # comute hash for downloaded file + sha256_hash = check_sha256(os.path.join(output_dir, fname)) + + # check this against out dictionary + if sha256_hash == hash_dict[fname]: + print(f"Checksum PASSED. Skipping {fname}...") + continue + else: + print("Checksum FAILED. Re-downloading...") + + cmd = f"wget -O {os.path.join(output_dir, fname)} {url}" + os.system(cmd) + + for n in range(100): + fname = f"autotagging_moodtheme_audio-{n:02}.tar" + # Untar + print(f"Extracting {fname}...") + cmd = f"tar -xvf {os.path.join(output_dir, fname)} -C {output_dir}" + os.system(cmd) + +def process_jamendo_dataset(output_dir): + + num_cpus = multiprocessing.cpu_count() + set_dirs = [] + for n in range(100): + set_dirs.append(os.path.join(output_dir, str(n))) + + for sr in [24000]: + resampled_output_dir = os.path.join(output_dir, f"mtg-jamendo_{sr}") + if not os.path.isdir(resampled_output_dir): + os.makedirs(resampled_output_dir) + + for set_dir in set_dirs: + # get all files in speaker directory + resample_dir(set_dir, resampled_output_dir, sr, channels=1, num_cpus=num_cpus) + + +def download_musdb_dataset(output_dir): + # from https://zenodo.org/record/3338373. + cmd = 'wget https://zenodo.org/record/3338373/files/musdb18hq.zip?download=1 -O ' + os.path.join(output_dir, 'musdb18.zip') + os.system(cmd) + + cmd = 'unzip ' + os.path.join(output_dir, 'musdb18.zip') + ' -d ' + os.path.join(output_dir, 'musdb18') + os.system(cmd) + + +def process_musdb_dataset(output_dir): + + def resample_file(item): + orig, new, sr = item + x, sr_orig = sf.read(orig) + if sr_orig != sr: + x = resampy.resample(x, sr_orig, sr, axis=0) + + sf.write(new, x, sr) + return True + + + mix_files = glob.glob(os.path.join(output_dir, 'musdb18', "**", "*.wav"), recursive=True) + mix_files = [mix_file for mix_file in mix_files if "mix" in mix_file] + + items = [] + for sr in [24000, 44100]: + resampled_output_dir = os.path.join(output_dir, f"musdb18_{sr}") + if not os.path.isdir(resampled_output_dir): + os.makedirs(resampled_output_dir) + + for mix_file in mix_files: + song_id = os.path.basename(os.path.dirname(mix_file)).replace(" ", "") + out_filepath = os.path.join( + resampled_output_dir, + f"{song_id}.wav", + ) + items.append((mix_file, out_filepath, sr)) + + + num_cpus = multiprocessing.cpu_count() + results = Parallel(n_jobs=num_cpus)( + delayed(resample_file)(i) for i in tqdm(items) + ) + + + +if __name__ == "__main__": + + parser = ArgumentParser(description="Download all models and datasets.") + parser.add_argument( + "--checkpoint", + help="Download pre-trained model checkpoints.", + action="store_true", + ) + parser.add_argument( + "--datasets", + help="Datasets to download.", + nargs="+", + default=[ + "daps", + "vctk", + "jamendo", + "libritts", + "musdb", + ], + ) + parser.add_argument( + "-d", + "--download", + help="Download the dataset.", + action="store_true", + ) + parser.add_argument( + "-p", + "--process", + help="Process the dataset assuming it is already downloaded.", + action="store_true", + ) + parser.add_argument( + "--output", + help="Root directory to download dataset.", + default=None, + ) + parser.add_argument( + "--num_workers", + help="Number of parallel workers", + type=int, + default=16, + ) + args = parser.parse_args() + + if args.output is None: + args.output = "./" + + if not os.path.isdir(args.output): + os.makedirs(args.output) + + for dataset in args.datasets: + if dataset == "daps": + + if args.download: + print("Downloading DAPS...") + download_daps_dataset(args.output) + if args.process: + print(f"Processing DAPS dataset...") + process_daps_dataset(args.output) + elif dataset == "vctk": + if args.download: + print("Downloading VCTK...") + download_vctk_dataset(args.output) + if args.process: + print(f"Processing VCTK dataset...") + process_vctk_dataset(args.output) + elif dataset == "libritts": + + if args.download: + print("Downloading LibriTTS...") + download_libritts_dataset(args.output) + if args.process: + print(f"Processing libriTTS dataset...") + process_libritts_dataset(args.output) + elif dataset == "jamendo": + + if args.download: + print(f"Downloading Jamendo dataset...") + download_jamendo_dataset(args.output) + if args.process: + print(f"Processing Jamendo dataset...") + process_jamendo_dataset(args.output) + elif dataset == "musdb": + + if args.download: + print(f"Downloading MUSDB dataset...") + download_musdb_dataset(args.output) + if args.process: + print(f"Processing MUSDB dataset...") + process_musdb_dataset(args.output) + else: + print("\nInvalid dataset.\n") + parser.print_help() diff --git a/scripts/effect_plotting.py b/scripts/effect_plotting.py new file mode 100644 index 0000000..21558ed --- /dev/null +++ b/scripts/effect_plotting.py @@ -0,0 +1,117 @@ +import numpy as np +import scipy.signal +import matplotlib.pyplot as plt + +from deepafx_st.processors.dsp.peq import biqaud + + +def plot_peq_response( + p_peq_denorm, + sr, + ax=None, + label=None, + color=None, + points=False, + center_line=False, +): + + ls_gain = p_peq_denorm[0] + ls_freq = p_peq_denorm[1] + ls_q = p_peq_denorm[2] + b0, a0 = biqaud(ls_gain, ls_freq, ls_q, sr, filter_type="low_shelf") + sos0 = np.concatenate((b0, a0)) + + f1_gain = p_peq_denorm[3] + f1_freq = p_peq_denorm[4] + f1_q = p_peq_denorm[5] + b1, a1 = biqaud(f1_gain, f1_freq, f1_q, sr, filter_type="peaking") + sos1 = np.concatenate((b1, a1)) + + f2_gain = p_peq_denorm[6] + f2_freq = p_peq_denorm[7] + f2_q = p_peq_denorm[8] + b2, a2 = biqaud(f2_gain, f2_freq, f2_q, sr, filter_type="peaking") + sos2 = np.concatenate((b2, a2)) + + f3_gain = p_peq_denorm[9] + f3_freq = p_peq_denorm[10] + f3_q = p_peq_denorm[11] + b3, a3 = biqaud(f3_gain, f3_freq, f3_q, sr, filter_type="peaking") + sos3 = np.concatenate((b3, a3)) + + f4_gain = p_peq_denorm[12] + f4_freq = p_peq_denorm[13] + f4_q = p_peq_denorm[14] + b4, a4 = biqaud(f4_gain, f4_freq, f4_q, sr, filter_type="peaking") + sos4 = np.concatenate((b4, a4)) + + hs_gain = p_peq_denorm[15] + hs_freq = p_peq_denorm[16] + hs_q = p_peq_denorm[17] + b5, a5 = biqaud(hs_gain, hs_freq, hs_q, sr, filter_type="high_shelf") + sos5 = np.concatenate((b5, a5)) + + sos = [sos0, sos1, sos2, sos3, sos4, sos5] + sos = np.array(sos) + # print(sos.shape) + # print(sos) + + # measure freq response + w, h = scipy.signal.sosfreqz(sos, fs=22050, worN=2048) + + if ax is None: + fig, axs = plt.subplots() + + if center_line: + ax.plot(w, np.zeros(w.shape), color="lightgray") + + ax.plot(w, 20 * np.log10(np.abs(h)), label=label, color=color) + if points: + ax.scatter(ls_freq, ls_gain, color=color) + ax.scatter(f1_freq, f1_gain, color=color) + ax.scatter(f2_freq, f2_gain, color=color) + ax.scatter(f3_freq, f3_gain, color=color) + ax.scatter(f4_freq, f4_gain, color=color) + ax.scatter(hs_freq, hs_gain, color=color) + + +def plot_comp_response( + p_comp_denorm, + sr, + ax=None, + label=None, + color=None, + center_line=False, +): + + # get parameters + threshold = p_comp_denorm[0] + ratio = p_comp_denorm[1] + attack_ms = p_comp_denorm[2] * 1000 + release_ms = p_comp_denorm[3] * 1000 + knee_db = p_comp_denorm[4] + makeup_db = p_comp_denorm[5] + + # print(knee_db) + + x = np.linspace(-80, 0) # input level + y = np.zeros(x.shape) # output level + + idx = np.where((2 * (x - threshold)) < -knee_db) + y[idx] = x[idx] + + idx = np.where((2 * np.abs(x - threshold)) <= knee_db) + y[idx] = x[idx] + ( + (1 / ratio - 1) * (((x[idx] - threshold + (knee_db / 2))) ** 2) + ) / (2 * knee_db) + + idx = np.where((2 * (x - threshold)) > knee_db) + y[idx] = threshold + ((x[idx] - threshold) / (ratio)) + + text_height = threshold + ((0 - threshold) / (ratio)) + + # plot the first part of the line + ax.plot(x, y, label=label, color=color) + if center_line: + ax.plot(x, x, color="lightgray", linestyle="--") + ax.text(0, text_height, f"{threshold:0.1f} dB {ratio:0.1f}:1") diff --git a/scripts/eval_probes.py b/scripts/eval_probes.py new file mode 100644 index 0000000..336c35a --- /dev/null +++ b/scripts/eval_probes.py @@ -0,0 +1,149 @@ +import os +import glob +import torch +import argparse +import torchaudio +import numpy as np +from tqdm import tqdm +import matplotlib.pyplot as plt +from sklearn.metrics import ConfusionMatrixDisplay, accuracy_score, f1_score +from sklearn.metrics import classification_report + +from deepafx_st.data.style import StyleDataset +from deepafx_st.probes.probe_system import ProbeSystem + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--ckpt_dir", + help="Path to top-level directory with probe checkpoints.", + type=str, + ) + parser.add_argument( + "--eval_dataset", + help="Path to directory containing style dataset for evaluation.", + type=str, + ) + parser.add_argument( + "--audio_type", + help="Evaluate only models trained on this type of audio 'speech' or 'music'.", + type=str, + ) + parser.add_argument( + "--output_dir", + help="Save outputs here.", + type=str, + ) + parser.add_argument( + "--subset", + help="One of either train, val, or test.", + type=str, + default="val", + ) + parser.add_argument("--gpu", help="Use gpu.", action="store_true") + args = parser.parse_args() + + # ------------------ load models ------------------ + models = {} # storage for pretrained models + model_dirs = glob.glob(os.path.join(args.ckpt_dir, "*")) + model_dirs = [md for md in model_dirs if os.path.isdir(md)] + + for model_dir in model_dirs: + model_name = os.path.basename(model_dir) + ckpt_paths = glob.glob( + os.path.join( + model_dir, + "lightning_logs", + "version_0", + "checkpoints", + "*.ckpt", + ) + ) + + if len(ckpt_paths) < 1: + print(f"WARNING: No checkpoint found for {model_name} model.") + continue + + ckpt_path = ckpt_paths[0] + + if args.audio_type not in ckpt_path: + print(f"Skipping {ckpt_path}") + continue + + print(os.path.basename(ckpt_path)) + if "speech" in ckpt_path: + deepafx_st_autodiff_ckpt = "checkpoints/style/libritts/autodiff/lightning_logs/version_1/checkpoints/epoch=367-step=1226911-val-libritts-autodiff.ckpt" + deepafx_st_spsa_ckpt = "checkpoints/style/libritts/spsa/lightning_logs/version_2/checkpoints/epoch=367-step=1226911-val-libritts-spsa.ckpt" + deepafx_st_proxy0_ckpt = "checkpoints/style/libritts/proxy0/lightning_logs/version_0/checkpoints/epoch=327-step=1093551-val-libritts-proxy0.ckpt" + elif "music" in ckpt_path: + deepafx_st_autodiff_ckpt = "checkpoints/style/jamendo/autodiff/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-autodiff.ckpt" + deepafx_st_spsa_ckpt = "checkpoints/style/jamendo/spsa/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-spsa.ckpt" + deepafx_st_proxy0_ckpt = "checkpoints/style/jamendo/proxy0/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-proxy0.ckpt" + + model = ProbeSystem.load_from_checkpoint( + ckpt_path, + strict=False, + deepafx_st_autodiff_ckpt=deepafx_st_autodiff_ckpt, + deepafx_st_spsa_ckpt=deepafx_st_spsa_ckpt, + deepafx_st_proxy0_ckpt=deepafx_st_proxy0_ckpt, + ) + model.eval() + if args.gpu: + model.cuda() + models[model_name] = model + + # create evaluation dataset + eval_dataset = StyleDataset( + args.eval_dataset, + subset=args.subset, + ) + + # iterate over dataset and make predictions with all models + preds = {"true": []} + for bidx, batch in enumerate(tqdm(eval_dataset), 0): + x, y = batch + + if args.gpu: + x = x.to("cuda") + + preds["true"].append(y) + + for model_name, model in models.items(): + with torch.no_grad(): + y_hat = model(x.view(1, 1, -1)) + + if model_name not in preds: + preds[model_name] = [] + + preds[model_name].append(y_hat.argmax().cpu()) + + # create confusion matracies + print("-------------------------------------------------------") + print(f"-------------------{args.audio_type}---------------------") + print("-------------------------------------------------------") + if not os.path.isdir(args.output_dir): + os.makedirs(args.output_dir) + for model_name, pred in preds.items(): + y_true = np.array(preds["true"]).reshape(-1) + y_pred = np.array(pred).reshape(-1) + ConfusionMatrixDisplay.from_predictions( + y_true, + y_pred, + display_labels=eval_dataset.class_labels, + cmap="Blues", + ) + acc = accuracy_score(y_true, y_pred) + f1score = f1_score(y_true, y_pred, average="weighted") + plt.title(f"{model_name} acc: {acc*100:0.2f}% f1: {f1score:0.2f}") + print(f"{model_name} acc: {acc*100:0.2f}% f1: {f1score:0.2f}") + print( + classification_report( + y_true, y_pred, target_names=eval_dataset.class_labels + ) + ) + + filename = f"{model_name}-{args.subset}" + filepath = os.path.join(args.output_dir, filename) + plt.savefig(filepath + ".png", dpi=300) + + print("-------------------------------------------------------") diff --git a/scripts/eval_style.py b/scripts/eval_style.py new file mode 100755 index 0000000..81307be --- /dev/null +++ b/scripts/eval_style.py @@ -0,0 +1,513 @@ +import os +import sys +import glob +import json +import torch +import auraloss +import argparse +import torchaudio +import numpy as np +import pytorch_lightning as pl +from multiprocessing import process + +from deepafx_st.utils import DSPMode, seed_worker +from deepafx_st.system import System +from deepafx_st.data.dataset import AudioDataset +from deepafx_st.models.baselines import BaselineEQAndComp +from deepafx_st.metrics import ( + LoudnessError, + RMSEnergyError, + SpectralCentroidError, + CrestFactorError, + PESQ, + MelSpectralDistance, +) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument( + "ckpt_dir", + help="Top level directory containing model checkpoints to evaluate", + ) + parser.add_argument( + "--root_dir", + default="/mnt/session_space", + help="Top level directory containing datasets.", + ) + parser.add_argument( + "--real", + help="Run real world evaluation. Otherwise synthetic.", + action="store_true", + ) + parser.add_argument( + "--dataset", + help="Dataset to evaluate on. (vctk, daps, podcast, libritts, jamendo)", + default="libritts", + ) + parser.add_argument( + "--dataset_dir", + help="Path to root dataset directory", + default="LibriTTS/train_clean_360_24000c", + ) + parser.add_argument( + "--output", + help="Path to root directory to store outputs.", + default="./", + ) + parser.add_argument( + "--length", + help="Audio example length to use in samples.", + type=int, + default=131072, + ) + parser.add_argument( + "--gpu", + help="Run models on GPU.", + action="store_true", + default=False, + ) + parser.add_argument( + "--save", + help="Save audio and plots for each example.", + action="store_true", + default=False, + ) + parser.add_argument( + "--examples", + help="Number of examples to evaluate.", + type=int, + default=1000, + ) + parser.add_argument( + "--subset", + help="Evaluate on the train, val, or test sets.", + default="val", + type=str, + ) + parser.add_argument( + "--seed", + help="Random seed.", + default=42, + ) + parser.add_argument( + "--tcn1_version", + help="Pre-trained TCN 1 model version.", + default=0, + ) + parser.add_argument( + "--tcn2_version", + help="Pre-trained TCN 2 model version.", + default=0, + ) + parser.add_argument( + "--proxy0_version", + help="Pre-trained Proxy 0 model version.", + default=0, + ) + parser.add_argument( + "--proxy2_version", + help="Pre-trained Proxy 2 model version.", + default=0, + ) + parser.add_argument( + "--proxy0m_version", + help="Pre-trained Proxy 0 model version.", + default=0, + ) + parser.add_argument( + "--proxy2m_version", + help="Pre-trained Proxy 2 model version.", + default=0, + ) + parser.add_argument( + "--spsa_version", + help="Pre-trained SPSA model version.", + default=0, + ) + parser.add_argument( + "--autodiff_version", + help="Pre-trained autodiff model version.", + default=0, + ) + parser.add_argument( + "--checkpoint_loss", + type=str, + default="val", + help="Evaluate on best 'train' or 'val' loss", + ) + parser.add_argument( + "--ext", + type=str, + default="wav", + help="Dataset audio extension.", + ) + + args = parser.parse_args() + + torch.backends.cudnn.benchmark = True + pl.seed_everything(args.seed) + + eval_dir = os.path.join(args.output, f"eval_{args.dataset}") + + if not os.path.isdir(eval_dir): + os.makedirs(eval_dir) + + if args.gpu: + device = "cuda" + else: + device = "cpu" + + # ---- setup the dataset ---- + if not args.real: + eval_dataset = AudioDataset( + args.root_dir, + subset=args.subset, + half=False, + train_frac=0.9, + length=args.length, + input_dirs=[args.dataset_dir], + buffer_size_gb=2.0, + buffer_reload_rate=args.examples, + num_examples_per_epoch=args.examples, + augmentations={}, + freq_corrupt=True, + drc_corrupt=True, + ext=args.ext, + ) + sample_rate = eval_dataset.sample_rate + # eval_dataset = torch.utils.data.DataLoader( + # eval_dataset, + # shuffle=False, + # num_workers=1, + # batch_size=1, + # worker_init_fn=seed_worker, + # pin_memory=True, + # persistent_workers=True, + # timeout=60, + # ) + else: + eval_dataset = None + print(f"Dataset fs={sample_rate}") + + models = {} + # --------------- setup pre-trained models --------------- + for processor_model_id in [ + "tcn1", + "tcn2", + "spsa", + "proxy0", + "proxy1", + "proxy2", + "proxy0m", + "proxy1m", + "proxy2m", + "autodiff", + ]: + + if processor_model_id == "proxy1": + processor_model_id_dir = "proxy0" + elif processor_model_id == "proxy1m": + processor_model_id_dir = "proxy0m" + else: + processor_model_id_dir = processor_model_id + + log_dir = os.path.join( + args.ckpt_dir, + processor_model_id_dir, + "lightning_logs", + f"""version_{getattr(args, f"{processor_model_id_dir}_version")}""", + ) + ckpt_dir = os.path.join(log_dir, "checkpoints") + pckpt_dir = os.path.join(log_dir, "pretrained_checkpoints") + ckpts = glob.glob(os.path.join(ckpt_dir, "*.ckpt")) + pckpts = glob.glob(os.path.join(pckpt_dir, "*.ckpt")) + + if len(ckpts) < 1: + print( + f"No {processor_model_id} checkpoint found in {ckpt_dir}. Skipping..." + ) + continue + else: + ckpt = [c for c in ckpts if args.checkpoint_loss in c][0] + + print(f"Loading {processor_model_id} {ckpt} on {device}...") + + dsp_mode = DSPMode.NONE + + # search for pre-trained models + if "m" in processor_model_id: + peq_ckpt = "checkpoints/proxies/jamendo/peq/lightning_logs/version_0/checkpoints/epoch=326-step=204374-val-jamendo-peq.ckpt" + comp_ckpt = "checkpoints/proxies/jamendo/comp/lightning_logs/version_0/checkpoints/epoch=274-step=171874-val-jamendo-comp.ckpt" + else: + peq_ckpt = "checkpoints/proxies/libritts/peq/lightning_logs/version_1/checkpoints/epoch=111-step=139999-val-libritts-peq.ckpt" + comp_ckpt = "checkpoints/proxies/libritts/comp/lightning_logs/version_1/checkpoints/epoch=255-step=319999-val-libritts-comp.ckpt" + + if processor_model_id == "proxy0" or processor_model_id == "proxy0m": + # peq_ckpt = [pc for pc in pckpts if "peq" in pc][0] + # comp_ckpt = [pc for pc in pckpts if "comp" in pc][0] + proxy_ckpts = [peq_ckpt, comp_ckpt] + print(f"Found {len(proxy_ckpts)}: {proxy_ckpts}") + model = ( + System.load_from_checkpoint( + ckpt, + dsp_mode=DSPMode.NONE, + proxy_ckpts=proxy_ckpts, + dsp_sample_rate=sample_rate, + strict=False, + ) + .eval() + .to(device) + ) + dsp_mode = DSPMode.NONE + processor_model_name = processor_model_id + elif processor_model_id == "proxy1" or processor_model_id == "proxy1m": + # peq_ckpt = [pc for pc in pckpts if "peq" in pc][0] + # comp_ckpt = [pc for pc in pckpts if "comp" in pc][0] + proxy_ckpts = [peq_ckpt, comp_ckpt] + print(f"Found {len(proxy_ckpts)}: {proxy_ckpts}") + model = ( + System.load_from_checkpoint( + ckpt, + use_dsp=DSPMode.INFER, + proxy_ckpts=proxy_ckpts, + dsp_sample_rate=sample_rate, + strict=False, + ) + .eval() + .to(device) + ) + processor_model_name = processor_model_id + model.hparams.dsp_mode = DSPMode.INFER + elif processor_model_id == "proxy2" or processor_model_id == "proxy2m": + # peq_ckpt = [pc for pc in pckpts if "peq" in pc][0] + # comp_ckpt = [pc for pc in pckpts if "comp" in pc][0] + proxy_ckpts = [peq_ckpt, comp_ckpt] + print(f"Found {len(proxy_ckpts)}: {proxy_ckpts}") + model = ( + System.load_from_checkpoint( + ckpt, + dsp_mode=DSPMode.INFER, + proxy_ckpts=proxy_ckpts, + dsp_sample_rate=sample_rate, + strict=False, + ) + .eval() + .to(device) + ) + processor_model_name = processor_model_id + model.hparams.dsp_mode = DSPMode.INFER + elif processor_model_id == "tcn1": + processor_model_name = "TCN1" + model = ( + System.load_from_checkpoint( + ckpt, + dsp_mode=DSPMode.INFER, + dsp_sample_rate=sample_rate, + strict=False, + ) + .eval() + .to(device) + ) + elif processor_model_id == "tcn2": + processor_model_name = "TCN2" + model = ( + System.load_from_checkpoint( + ckpt, + dsp_mode=DSPMode.INFER, + dsp_sample_rate=sample_rate, + strict=False, + ) + .eval() + .to(device) + ) + elif processor_model_id == "autodiff": + processor_model_name = "Autodiff" + model = ( + System.load_from_checkpoint( + ckpt, + dsp_mode=DSPMode.NONE, + dsp_sample_rate=sample_rate, + strict=False, + ) + .eval() + .to(device) + ) + elif processor_model_id == "spsa": + processor_model_name = "SPSA" + model = ( + System.load_from_checkpoint( + ckpt, + dsp_mode=DSPMode.NONE, + dsp_sample_rate=sample_rate, + strict=False, + spsa_parallel=False, + ) + .eval() + .to(device) + ) + else: + raise RuntimeError(f"Unexpected processor_model_id: {processor_model_id}") + + models[f"{processor_model_name} ({args.dataset})"] = model + + if len(list(models.keys())) < 1: + raise ValueError("No checkpoints found for evaluation. Exiting...") + + # create the baseline model + baseline_model = BaselineEQAndComp(sample_rate=sample_rate) + models[f"Baseline ({args.dataset})"] = baseline_model + + # ---- setup the metrics ---- + metrics = { + "PESQ": PESQ(sample_rate), + "MRSTFT": auraloss.freq.MultiResolutionSTFTLoss( + fft_sizes=[32, 128, 512, 2048, 8192, 32768], + hop_sizes=[16, 64, 256, 1024, 4096, 16384], + win_lengths=[32, 128, 512, 2048, 8192, 32768], + w_sc=0.0, + w_phs=0.0, + w_lin_mag=1.0, + w_log_mag=1.0, + ), + "MSD": MelSpectralDistance(sample_rate), + "SCE": SpectralCentroidError(sample_rate), + "CFE": CrestFactorError(), + "RMS": RMSEnergyError(), + "LUFS": LoudnessError(sample_rate), + } + metrics_dict = {"Corrupt": {}} + + # ---- start the evaluation ---- + for bidx, batch in enumerate(eval_dataset, 0): + x, y = batch + + x = x.to(device) + y = y.to(device) + + # sum to mono + x = x.mean(0, keepdim=True) + y = y.mean(0, keepdim=True) + # print(x.shape, y.shape) + + # split inputs in half for style transfer + length = x.shape[-1] + x_A = x[..., : length // 2] + x_B = x[..., length // 2 :] + + y_A = y[..., : length // 2] + y_B = y[..., length // 2 :] + + if torch.rand(1).sum() > 0.5: + y_ref = y_B + y = y_A + x = x_A + else: + y_ref = y_A + y = y_B + x = x_B + + # corrupted input peak normalized to -3 dBFS + # x_norm = x / x.abs().max() + # x_norm *= 10 ** (-12.0 / 20) + + # compute metrics with the corrupt input + for metric_name, metric in metrics.items(): + if metric_name not in metrics_dict["Corrupt"]: + metrics_dict["Corrupt"][metric_name] = [] + + try: + val = metric(x.cpu().view(1, 1, -1), y.cpu().view(1, 1, -1)) + except: + val = -1 + metrics_dict["Corrupt"][metric_name].append(val) + + outputs = {} + # now iterate over models and compute metrics + for model_name, model in models.items(): + if model_name not in metrics_dict: + metrics_dict[model_name] = {} + + # forward pass through model + with torch.no_grad(): + if "Baseline" in model_name: + y_hat = model( + x.cpu().view(1, 1, -1).clone(), + y_ref.cpu().view(1, 1, -1).clone(), + ) + else: + y_hat, p, e = model( + x.view(1, 1, -1).clone(), + y=y_ref.view(1, 1, -1).clone(), + dsp_mode=model.hparams.dsp_mode, + sample_rate=sample_rate, + analysis_length=131072, + ) + + y_hat = y_hat.cpu() + y = y.cpu() + outputs[model_name] = y_hat # store + + # compute all metrics + for metric_name, metric in metrics.items(): + if metric_name not in metrics_dict[model_name]: + metrics_dict[model_name][metric_name] = [] + + try: + val = metric(y_hat.view(1, 1, -1), y.view(1, 1, -1)) + except: + val = -1 + metrics_dict[model_name][metric_name].append(val) + + if args.save: + y_hat_filepath = os.path.join( + eval_dir, f"{bidx:04d}_{model_name}_y_hat.wav" + ) + torchaudio.save(y_hat_filepath, y_hat.view(1, -1), sample_rate) + + if args.save: + x_filepath = os.path.join(eval_dir, f"{bidx:04d}_x.wav") + y_filepath = os.path.join(eval_dir, f"{bidx:04d}_y.wav") + torchaudio.save(x_filepath, x.view(1, -1).cpu(), sample_rate) + torchaudio.save(y_filepath, y.view(1, -1).cpu(), sample_rate) + + print(bidx + 1) + for model_name, model_metrics in metrics_dict.items(): + sys.stdout.write(f"\n {model_name:22} ") + for metric_name, metric_list in model_metrics.items(): + sys.stdout.write(f"{metric_name}: {np.mean(metric_list):0.3f} ") + sys.stdout.flush() + print() + print("-" * 32) + + if bidx + 1 == args.examples: + print("Evaluation complete.") + json_metrics_dict = {} + for model_name, model_metrics in metrics_dict.items(): + if model_name not in json_metrics_dict: + json_metrics_dict[model_name] = {} + for metric_name, metric_list in model_metrics.items(): + if metric_name not in json_metrics_dict: + sanitized_metric_list = [] + for elm in metric_list: + if isinstance(elm, torch.Tensor): + sanitized_metric_list.append(elm.numpy().tolist()) + else: + sanitized_metric_list.append(elm) + + json_metrics_dict[model_name][ + metric_name + ] = sanitized_metric_list + with open( + os.path.join( + eval_dir, + f"{args.dataset.lower()}_{args.checkpoint_loss}_results.json", + ), + "w", + ) as fp: + json.dump(json_metrics_dict, fp, indent=True) + + for model_name, model in models.items(): + del model + + break diff --git a/scripts/export_ckpt.py b/scripts/export_ckpt.py new file mode 100644 index 0000000..f62e28a --- /dev/null +++ b/scripts/export_ckpt.py @@ -0,0 +1,98 @@ +import os +import sys +import glob +import torch +import pickle +import pytorch_lightning as pl +import deepafx_st + +sys.modules["deepafx_st"] = deepafx_st # patch for name change + +if __name__ == "__main__": + + checkpoint_dir = "checkpoints_fixed" + if not os.path.isdir(checkpoint_dir): + os.makedirs(checkpoint_dir) + + for experiment in ["probes", "style", "proxies"]: + + for v in [0, 1, 2]: + ckpt_paths = glob.glob( + os.path.join( + "checkpoints", + experiment, + "**", + "**", + "lightning_logs", + f"version_{v}", + "checkpoints", + "*.ckpt", + ) + ) + + for ckpt_path in ckpt_paths: + print(ckpt_path) + + processor_model_id = ckpt_path.split("/")[-5] + print(processor_model_id) + + if "m" in processor_model_id: + peq_ckpt = "checkpoints/proxies/jamendo/peq/lightning_logs/version_0/checkpoints/epoch=326-step=204374-val-jamendo-peq.ckpt" + comp_ckpt = "checkpoints/proxies/jamendo/comp/lightning_logs/version_0/checkpoints/epoch=274-step=171874-val-jamendo-comp.ckpt" + else: + peq_ckpt = "checkpoints/proxies/libritts/peq/lightning_logs/version_1/checkpoints/epoch=111-step=139999-val-libritts-peq.ckpt" + comp_ckpt = "checkpoints/proxies/libritts/comp/lightning_logs/version_1/checkpoints/epoch=255-step=319999-val-libritts-comp.ckpt" + + proxy_ckpts = [peq_ckpt, comp_ckpt] + + if experiment == "style": + model = deepafx_st.system.System.load_from_checkpoint( + ckpt_path, + proxy_ckpts=proxy_ckpts, + strict=False, + ) + elif experiment == "probes": + if "speech" in ckpt_path: + deepafx_st_autodiff_ckpt = "checkpoints/style/libritts/autodiff/lightning_logs/version_1/checkpoints/epoch=367-step=1226911-val-libritts-autodiff.ckpt" + deepafx_st_spsa_ckpt = "checkpoints/style/libritts/spsa/lightning_logs/version_2/checkpoints/epoch=367-step=1226911-val-libritts-spsa.ckpt" + deepafx_st_proxy0_ckpt = "checkpoints/style/libritts/proxy0/lightning_logs/version_0/checkpoints/epoch=327-step=1093551-val-libritts-proxy0.ckpt" + elif "music" in ckpt_path: + deepafx_st_autodiff_ckpt = "checkpoints/style/jamendo/autodiff/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-autodiff.ckpt" + deepafx_st_spsa_ckpt = "checkpoints/style/jamendo/spsa/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-spsa.ckpt" + deepafx_st_proxy0_ckpt = "checkpoints/style/jamendo/proxy0/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-proxy0.ckpt" + + model = ( + deepafx_st.probes.probe_system.ProbeSystem.load_from_checkpoint( + ckpt_path, + strict=False, + deepafx_st_autodiff_ckpt=deepafx_st_autodiff_ckpt, + deepafx_st_spsa_ckpt=deepafx_st_spsa_ckpt, + deepafx_st_proxy0_ckpt=deepafx_st_proxy0_ckpt, + ) + ) + elif experiment == "proxies": + model = deepafx_st.processors.proxy.proxy_system.ProxySystem.load_from_checkpoint( + ckpt_path, + strict=False, + ) + else: + raise RuntimeError(f"Invalid experiment: {experiment}") + + ckpt_path_dirname = os.path.dirname(ckpt_path) + ckpt_path_basename = os.path.basename(ckpt_path) + ckpt_path_fixed = ckpt_path_dirname.replace( + "checkpoints", checkpoint_dir, 1 + ) + + if not os.path.isdir(ckpt_path_fixed): + os.makedirs(ckpt_path_fixed) + + ckpt_path_fixed = os.path.join(ckpt_path_fixed, ckpt_path_basename) + print(ckpt_path_fixed) + + trainer = pl.Trainer() + trainer.model = model + trainer.save_checkpoint(ckpt_path_fixed) + + del model + del trainer diff --git a/scripts/generate_styles.py b/scripts/generate_styles.py new file mode 100644 index 0000000..6b431e6 --- /dev/null +++ b/scripts/generate_styles.py @@ -0,0 +1,285 @@ +import os +import sys +import glob +import torch +import random +import argparse +import torchaudio +import numpy as np +from tqdm import tqdm +from itertools import repeat +import pytorch_lightning as pl + +from deepafx_st.processors.dsp.peq import parametric_eq +from deepafx_st.processors.dsp.compressor import compressor + + +def get_random_patch(x, sample_rate, length_samples): + length = int(length_samples) + silent = True + while silent: + start_idx = np.random.randint(0, x.shape[-1] - length - 1) + stop_idx = start_idx + length + x_crop = x[:, start_idx:stop_idx] + + # check for silence + frames = length // sample_rate + silent_frames = [] + for n in range(frames): + start_idx = int(n * sample_rate) + stop_idx = start_idx + sample_rate + x_frame = x_crop[:, start_idx:stop_idx] + if (x_frame ** 2).mean() > 3e-4: + silent_frames.append(False) + else: + silent_frames.append(True) + silent = True if any(silent_frames) else False + + x_crop /= x_crop.abs().max() + + return x_crop + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "audio_dir", + help="Path to directory containing source audio.", + type=str, + ) + parser.add_argument( + "--output_dir", + help="Path to directory to save output audio.", + type=str, + ) + parser.add_argument( + "--length_samples", + help="Length of the output audio examples in samples.", + type=float, + default=131072, + ) + parser.add_argument( + "--lookahead_samples", + help="Length of the processing lookahead.", + type=float, + default=16384, + ) + parser.add_argument( + "--num", + help="Number of examples to generate from each style.", + type=int, + default=1000, + ) + parser.add_argument( + "--ext", + help="Expected file extension for audio files.", + type=str, + default="wav", + ) + + args = parser.parse_args() + pl.seed_everything(42) + + # find all audio files in directory + dataset_name = os.path.basename(args.audio_dir) + audio_filepaths = glob.glob(os.path.join(args.audio_dir, f"*.{args.ext}")) + print(f"Found {len(audio_filepaths)} audio files.") + if len(audio_filepaths) < 1: + raise RuntimeError(f"No audio files found in {args.audio_dir}.") + + # split files into three subsets (train, val, test) + random.shuffle(audio_filepaths) + train_idx = int(len(audio_filepaths) * 0.6) + val_idx = int(len(audio_filepaths) * 0.2) + train_subset = audio_filepaths[:train_idx] + val_subset = audio_filepaths[train_idx : train_idx + val_idx] + test_subset = audio_filepaths[train_idx + val_idx :] + print( + f"Train ({len(train_subset)}) Val ({len(val_subset)}) Test ({len(test_subset)})" + ) + + subsets = { + "train": train_subset, + "val": val_subset, + "test": test_subset, + } + + # There are five different pre-defined styles + styles = [ + "neutral", # 1. Neutral - Presecnce + Normal compression + "broadcast", # 2. Broadcast - More Presence + Aggressive compression + "telephone", # 3. Telephone - Bandpass effect + compressor + "bright", # 4. Bright - Strong high-shelf filter + "warm", # 5. Warm - Bass boost with high-shelf decrease + ] + + for style_idx, style in enumerate(styles): + # reset the seed for each style + # pl.seed_everything(42) + print(f"Generating {style} ({style_idx+1}/{len(styles)}) style examples...") + # generate examples + subset_index_offset = 0 + + for subset_name, subset_filepaths in subsets.items(): + # create output directory if needed + style_dir = os.path.join(args.output_dir, subset_name, style) + if not os.path.isdir(style_dir): + os.makedirs(style_dir) + + # futher split the tracks for each style + tracks_per_style = len(subset_filepaths) // len(styles) + start_idx = style_idx * tracks_per_style + stop_idx = start_idx + tracks_per_style + print(start_idx, stop_idx) + style_subset_filepaths = subset_filepaths[start_idx:stop_idx] + + if subset_name == "train": + num_examples = int(args.num * 0.6) + else: + num_examples = int(args.num * 0.2) + + style_subset_filepaths = style_subset_filepaths * len(styles) + + # copy style subset filepaths to create desired number of examples + if num_examples > len(style_subset_filepaths): + style_subset_filepaths *= int(num_examples // len(style_subset_filepaths)) + else: + style_subset_filepaths = style_subset_filepaths[:num_examples] + + for n, input_filepath in enumerate(tqdm(style_subset_filepaths, ncols=120)): + x, sr = torchaudio.load(input_filepath) # load file + chs, samp = x.size() + + # get random audio patch + x = get_random_patch( + x, + sr, + args.length_samples + args.lookahead_samples, + ) + + # add some randomized headroom + headroom_db = (torch.rand(1) * 6) + 3 + x = x / x.abs().max() + x *= 10 ** (-headroom_db / 20.0) + + # apply selected style + if style == "neutral": + # ----------- compressor ------------- + threshold = -((torch.rand(1) * 10.0).numpy().squeeze() + 20.0) + attack_sec = (torch.rand(1) * 0.020).numpy().squeeze() + 0.050 + release_sec = (torch.rand(1) * 0.200).numpy().squeeze() + 0.100 + ratio = (torch.rand(1) * 0.5).numpy().squeeze() + 1.5 + # ----------- parametric eq ----------- + low_shelf_gain_db = (torch.rand(1) * 2.0).numpy().squeeze() + 1.0 + low_shelf_cutoff_freq = (torch.rand(1) * 120).numpy().squeeze() + 80 + first_band_gain_db = 0.0 + first_band_cutoff_freq = 1000.0 + high_shelf_gain_db = (torch.rand(1) * 2.0).numpy().squeeze() + 1.0 + high_shelf_cutoff_freq = ( + torch.rand(1) * 2000 + ).numpy().squeeze() + 6000 + elif style == "broadcast": + # ----------- compressor ------------- + threshold = -((torch.rand(1) * 10).numpy().squeeze() + 40) + attack_sec = (torch.rand(1) * 0.025).numpy().squeeze() + 0.005 + release_sec = (torch.rand(1) * 0.100).numpy().squeeze() + 0.050 + ratio = (torch.rand(1) * 2.0).numpy().squeeze() + 3.0 + # ----------- parametric eq ----------- + low_shelf_gain_db = (torch.rand(1) * 4.0).numpy().squeeze() + 2.0 + low_shelf_cutoff_freq = (torch.rand(1) * 120).numpy().squeeze() + 80 + first_band_gain_db = 0.0 + first_band_cutoff_freq = 1000.0 + high_shelf_gain_db = (torch.rand(1) * 4.0).numpy().squeeze() + 2.0 + high_shelf_cutoff_freq = ( + torch.rand(1) * 2000 + ).numpy().squeeze() + 6000 + elif style == "telephone": + # ----------- compressor ------------- + threshold = -((torch.rand(1) * 20.0).numpy().squeeze() + 20) + attack_sec = (torch.rand(1) * 0.005).numpy().squeeze() + 0.001 + release_sec = (torch.rand(1) * 0.050).numpy().squeeze() + 0.010 + ratio = (torch.rand(1) * 1.5).numpy().squeeze() + 1.5 + # ----------- parametric eq ----------- + low_shelf_gain_db = -((torch.rand(1) * 6).numpy().squeeze() + 20) + low_shelf_cutoff_freq = ( + torch.rand(1) * 200 + ).numpy().squeeze() + 200 + first_band_gain_db = (torch.rand(1) * 4).numpy().squeeze() + 12 + first_band_cutoff_freq = ( + torch.rand(1) * 1000 + ).numpy().squeeze() + 1000 + high_shelf_gain_db = -((torch.rand(1) * 6).numpy().squeeze() + 20) + high_shelf_cutoff_freq = ( + torch.rand(1) * 2000 + ).numpy().squeeze() + 4000 + elif style == "bright": + # ----------- compressor ------------- + ratio = 1.0 + threshold = 0.0 + attack_sec = 0.050 + release_sec = 0.250 + # ----------- parametric eq ----------- + low_shelf_gain_db = -((torch.rand(1) * 6).numpy().squeeze() + 20) + low_shelf_cutoff_freq = ( + torch.rand(1) * 200 + ).numpy().squeeze() + 200 + first_band_gain_db = 0.0 + first_band_cutoff_freq = 1000.0 + high_shelf_gain_db = (torch.rand(1) * 6).numpy().squeeze() + 20 + high_shelf_cutoff_freq = ( + torch.rand(1) * 2000 + ).numpy().squeeze() + 8000 + elif style == "warm": + # ----------- compressor ------------- + ratio = 1.0 + threshold = 0.0 + attack_sec = 0.050 + release_sec = 0.250 + # ----------- parametric eq ----------- + low_shelf_gain_db = (torch.rand(1) * 6).numpy().squeeze() + 20 + low_shelf_cutoff_freq = ( + torch.rand(1) * 200 + ).numpy().squeeze() + 200 + first_band_gain_db = 0.0 + first_band_cutoff_freq = 1000.0 + high_shelf_gain_db = -(torch.rand(1) * 6).numpy().squeeze() + 20 + high_shelf_cutoff_freq = ( + torch.rand(1) * 2000 + ).numpy().squeeze() + 8000 + else: + raise RuntimeError(f"Invalid style: {style}.") + + # apply effects with parameters + x_out = torch.zeros(x.shape).type_as(x) + for ch_idx in range(chs): + x_peq_ch = parametric_eq( + x[ch_idx, :].view(-1).numpy(), + sr, + low_shelf_gain_dB=low_shelf_gain_db, + low_shelf_cutoff_freq=low_shelf_cutoff_freq, + first_band_gain_dB=first_band_gain_db, + first_band_cutoff_freq=first_band_cutoff_freq, + high_shelf_gain_dB=high_shelf_gain_db, + high_shelf_cutoff_freq=high_shelf_cutoff_freq, + ) + x_comp_ch = compressor( + x_peq_ch, + sr, + threshold=threshold, + ratio=ratio, + attack_time=attack_sec, + release_time=release_sec, + ) + x_out[ch_idx, :] = torch.tensor(x_comp_ch) + + # crop out lookahead + x_out = x_out.view(chs, -1) + x_out = x_out[:, args.lookahead_samples :] + + # peak normalize + x_out /= x_out.abs().max() + + output_filename = f"{n+subset_index_offset:03d}_{style}_{dataset_name}_{subset_name}.wav" + output_filepath = os.path.join(style_dir, output_filename) + torchaudio.save(output_filepath, x_out, sr) + subset_index_offset += n + 1 diff --git a/scripts/print_results.py b/scripts/print_results.py new file mode 100644 index 0000000..1303ddf --- /dev/null +++ b/scripts/print_results.py @@ -0,0 +1,19 @@ +import os +import sys +import json +import argparse +import numpy as np + +parser = argparse.ArgumentParser() +parser.add_argument("input", help="Path to results .json file.", type=str) +args = parser.parse_args() + +with open(args.input) as fp: + results = json.load(fp) + +for method, metrics in results.items(): + sys.stdout.write(f"{method:20s} ") + for metric_name, values in metrics.items(): + sys.stdout.write(f"{np.mean(values):0.3f} & ") + print() + sys.stdout.flush() diff --git a/scripts/process.py b/scripts/process.py new file mode 100755 index 0000000..c56db39 --- /dev/null +++ b/scripts/process.py @@ -0,0 +1,170 @@ +import os +import glob +import torch +import resampy +import argparse +import torchaudio +import numpy as np + +from deepafx_st.utils import DSPMode +from deepafx_st.utils import count_parameters +from deepafx_st.system import System + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument( + "-i", + "--input", + help="Path to audio file to process.", + type=str, + ) + parser.add_argument( + "-r", + "--reference", + help="Path to reference audio file.", + type=str, + ) + parser.add_argument( + "-c", + "--ckpt", + help="Path to pre-trained checkpoint.", + type=str, + ) + parser.add_argument( + "--gpu", + help="Run inference on GPU. (Otherwise CPU).", + action="store_true", + ) + parser.add_argument( + "--time", + help="Execute inference 100x in a loop and time the model.", + action="store_true", + ) + parser.add_argument( + "--no_dsp", + help="Only use neural networks for proxy.", + action="store_true", + ) + + args = parser.parse_args() + + # load the model + if "proxy" in args.ckpt: + logdir = os.path.dirname(os.path.dirname(args.ckpt)) + # Assumes speech proxies in specific location + pckpts = 'checkpoints' + if 'proxy0m' in logdir or 'proxy2m' in logdir: + peq_ckpt = os.path.join(pckpts, "proxies/jamendo/peq/lightning_logs/version_0/checkpoints/epoch=326-step=204374-val-jamendo-peq.ckpt" ) + comp_ckpt = os.path.join(pckpts, "proxies/jamendo/comp/lightning_logs/version_0/checkpoints/epoch=274-step=171874-val-jamendo-comp.ckpt" ) + else: + peq_ckpt = os.path.join(pckpts, "proxies/libritts/peq/lightning_logs/version_1/checkpoints/epoch=111-step=139999-val-libritts-peq.ckpt" ) + comp_ckpt = os.path.join(pckpts, "proxies/libritts/comp/lightning_logs/version_1/checkpoints/epoch=255-step=319999-val-libritts-comp.ckpt" ) + + proxy_ckpts = [peq_ckpt, comp_ckpt] + print(f"Found {len(proxy_ckpts)}: {proxy_ckpts}") + dsp_mode = DSPMode.INFER + if args.no_dsp: + dsp_mode = DSPMode.NONE + + system = System.load_from_checkpoint( + args.ckpt, dsp_mode=dsp_mode, proxy_ckpts=proxy_ckpts + ).eval() + else: + use_dsp = False + system = System.load_from_checkpoint( + args.ckpt, dsp_mode=DSPMode.NONE, batch_size=1 + ).eval() + + if args.gpu: + system.to("cuda") + + # load audio data + x, x_sr = torchaudio.load(args.input) + r, r_sr = torchaudio.load(args.reference) + + # resample if needed + if x_sr != 24000: + print("Resampling to 24000 Hz...") + x_24000 = torch.tensor(resampy.resample(x.view(-1).numpy(), x_sr, 24000)) + x_24000 = x_24000.view(1, -1) + else: + x_24000 = x + + if r_sr != 24000: + print("Resampling to 24000 Hz...") + r_24000 = torch.tensor(resampy.resample(r.view(-1).numpy(), r_sr, 24000)) + r_24000 = r_24000.view(1, -1) + else: + r_24000 = r + + # peak normalize to -12 dBFS + x_24000 = x_24000[0:1, : 24000 * 5] + x_24000 /= x_24000.abs().max() + x_24000 *= 10 ** (-12 / 20.0) + x_24000 = x_24000.view(1, 1, -1) + + # peak normalize to -12 dBFS + r_24000 = r_24000[0:1, : 24000 * 5] + r_24000 /= r_24000.abs().max() + r_24000 *= 10 ** (-12 / 20.0) + r_24000 = r_24000.view(1, 1, -1) + + if args.gpu: + x_24000 = x_24000.to("cuda") + r_24000 = r_24000.to("cuda") + + if args.time: + torch.set_num_threads(1) + # Warm up + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dummy_input = torch.randn(1, 3, 224, 224, dtype=torch.float).to(device) + + # pass audio through model + times = [] + num_times = 13 + warm_up = 3 + with torch.no_grad(): + for i in range(num_times): + print("iteration", i) + y_hat, p, e, encoder_time_sec, dsp_time_sec = system( + x_24000, dsp_mode=DSPMode.INFER, time_it=args.time + ) + if i >= warm_up: + times.append((encoder_time_sec, dsp_time_sec)) + + audio_len_sec = x_24000.shape[-1] / 24000.0 + ave_times = np.mean(np.array(times), axis=0) + + print("**********Config**********") + print("gpu", args.gpu) + print("dsp", system.hparams.processor_model) + print("ave_times", ave_times, "length", audio_len_sec) + print( + "rtf (encoder, dsp)", + ave_times[0] / audio_len_sec, + ave_times[1] / audio_len_sec, + ) + print("#parameters", count_parameters(system.processor, trainable_only=False)) + + else: + # pass audio through model + with torch.no_grad(): + y_hat, p, e = system(x_24000, r_24000) + + y_hat = y_hat.view(1, -1) + y_hat /= y_hat.abs().max() + x_24000 /= x_24000.abs().max() + + # save to disk + dirname = os.path.dirname(args.input) + filename = os.path.basename(args.input).replace(".wav", "") + reference = os.path.basename(args.reference).replace(".wav", "") + out_filepath = os.path.join(dirname, f"{filename}_out_ref={reference}.wav") + in_filepath = os.path.join(dirname, f"{filename}_in.wav") + print(f"Saved output to {out_filepath}") + torchaudio.save(out_filepath, y_hat.cpu().view(1, -1), 24000) + torchaudio.save(in_filepath, x_24000.cpu().view(1, -1), 24000) + + system.shutdown() + diff --git a/scripts/run_generate_styles.sh b/scripts/run_generate_styles.sh new file mode 100755 index 0000000..71b4b16 --- /dev/null +++ b/scripts/run_generate_styles.sh @@ -0,0 +1,13 @@ +root_dir="/path/to/data" # path to audio datasets + +python scripts/generate_styles.py \ +"$root_dir/daps_24000/cleanraw" \ +--output_dir "$root_dir/daps_24000_styles_100" \ +--length_samples 131072 \ +--num 100 \ + +python scripts/generate_styles.py \ +"$root_dir/musdb18_44100" \ +--output_dir "$root_dir/musdb18_44100_styles_100" \ +--length_samples 262144 \ +--num 100 \ diff --git a/scripts/run_style_case_study.sh b/scripts/run_style_case_study.sh new file mode 100755 index 0000000..07be85e --- /dev/null +++ b/scripts/run_style_case_study.sh @@ -0,0 +1,23 @@ +CUDA_VISIBLE_DEVICES=0 python scripts/style_case_study.py \ +--ckpt_paths \ +"/import/c4dm-datasets/deepafx_st/logs/style/libritts/spsa/lightning_logs/version_2/checkpoints/epoch=367-step=1226911-val-libritts-spsa.ckpt" \ +"/import/c4dm-datasets/deepafx_st/logs/style/libritts/autodiff/lightning_logs/version_1/checkpoints/epoch=367-step=1226911-val-libritts-autodiff.ckpt" \ +"/import/c4dm-datasets/deepafx_st/logs/style/libritts/proxy0/lightning_logs/version_0/checkpoints/epoch=327-step=1093551-val-libritts-proxy0.ckpt" \ +--style_audio "/import/c4dm-datasets/deepafx_st/daps_24000_styles_100/train" \ +--output_dir "/import/c4dm-datasets/deepafx_st/style_case_study_daps" \ +--num_examples 1000 \ +--gpu \ +--save \ +--plot \ + +#CUDA_VISIBLE_DEVICES=1 python scripts/style_case_study.py \ +#--ckpt_paths \ +#"/import/c4dm-datasets/deepafx_st/logs_jamendo/style/jamendo/autodiff/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-autodiff.ckpt" \ +#"/import/c4dm-datasets/deepafx_st/logs_jamendo/style/jamendo/spsa/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-spsa.ckpt" \ +#"/import/c4dm-datasets/deepafx_st/logs_jamendo/style/jamendo/proxy0/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-proxy0.ckpt" \ +#--style_audio "/import/c4dm-datasets/deepafx_st/musdb18_44100_styles_100/train" \ +#--output_dir "/import/c4dm-datasets/deepafx_st/style_case_study_musdb18" \ +#--sample_rate 44100 \ +#--gpu \ +#--save \ +#--plot \ \ No newline at end of file diff --git a/scripts/run_style_interpolation.sh b/scripts/run_style_interpolation.sh new file mode 100755 index 0000000..b82625f --- /dev/null +++ b/scripts/run_style_interpolation.sh @@ -0,0 +1,11 @@ +python scripts/style_interpolation.py \ +--ckpt_path "/import/c4dm-datasets/deepafx_st/logs/style/libritts/spsa/lightning_logs/version_2/checkpoints/epoch=367-step=1226911-val-libritts-spsa.ckpt" \ +--input_audio "/import/c4dm-datasets/deepafx_st/daps_24000_styles/val/telephone/799_telephone_cleanraw_val.wav" \ +--input_length 10 \ +--style_a "/import/c4dm-datasets/deepafx_st/daps_24000_styles/val/warm/822_warm_cleanraw_val.wav" \ +--style_a_name "warm" \ +--style_b "/import/c4dm-datasets/deepafx_st/daps_24000_styles/val/radio/806_radio_cleanraw_val.wav" \ +--style_b_name "radio" \ +--save \ + +# "/import/c4dm-datasets/deepafx_st/vctk_24000/p314.wav" \ No newline at end of file diff --git a/scripts/run_style_transfer.sh b/scripts/run_style_transfer.sh new file mode 100755 index 0000000..d4a1a9c --- /dev/null +++ b/scripts/run_style_transfer.sh @@ -0,0 +1,24 @@ +python scripts/style_transfer.py \ +--ckpt_path "/import/c4dm-datasets/deepafx_st/logs/style/libritts/spsa/lightning_logs/version_2/checkpoints/epoch=367-step=1226911-val-libritts-spsa.ckpt" \ +--input_filepaths \ +"/import/c4dm-datasets/deepafx_st/vctk_24000/p314.wav" \ +--style_filepaths \ +"/import/c4dm-datasets/deepafx_st/daps_24000_styles_100/val/broadcast/062_broadcast_cleanraw_val.wav" \ +--save \ +--modify_input \ +--output_dir style_transfer_modify \ + +#"examples/obama.wav" \ +#"examples/60min_presenter.wav" \ + +#python scripts/style_transfer.py \ +#--ckpt_path "/import/c4dm-datasets/deepafx_st/logs/style/libritts/spsa/lightning_logs/version_2/checkpoints/epoch=367-step=1226911-val-libritts-spsa.ckpt" \ +#--input_filepaths \ +#"/import/c4dm-datasets/deepafx_st/vctk_24000/p314.wav" \ +#--style_filepaths \ +#"/import/c4dm-datasets/deepafx_st/daps_24000_styles_100/val/telephone/066_telephone_cleanraw_val.wav" \ +#"/import/c4dm-datasets/deepafx_st/daps_24000_styles_100/val/bright/061_bright_cleanraw_val.wav" \ +#"/import/c4dm-datasets/deepafx_st/daps_24000_styles_100/val/warm/067_warm_cleanraw_val.wav" \ +#"/import/c4dm-datasets/deepafx_st/daps_24000_styles_100/val/broadcast/060_broadcast_cleanraw_val.wav" \ +#"/import/c4dm-datasets/deepafx_st/daps_24000_styles_100/val/neutral/069_neutral_cleanraw_val.wav" \ +#--save \ \ No newline at end of file diff --git a/scripts/run_style_transfer_bulk.sh b/scripts/run_style_transfer_bulk.sh new file mode 100755 index 0000000..8b359fb --- /dev/null +++ b/scripts/run_style_transfer_bulk.sh @@ -0,0 +1,9 @@ +python scripts/style_transfer_bulk.py \ +--ckpt_path "/import/c4dm-datasets/deepafx_st/logs/style/libritts/spsa/lightning_logs/version_2/checkpoints/epoch=367-step=1226911-val-libritts-spsa.ckpt" \ +--input_filepaths \ +"/import/c4dm-datasets/deepafx_st/daps_24000_styles/val/telephone/810_telephone_cleanraw_val.wav" \ +"/import/c4dm-datasets/deepafx_st/daps_24000_styles/val/bright/801_bright_cleanraw_val.wav" \ +"/import/c4dm-datasets/deepafx_st/daps_24000_styles/val/warm/807_warm_cleanraw_val.wav" \ +"/import/c4dm-datasets/deepafx_st/daps_24000_styles/val/radio/803_radio_cleanraw_val.wav" \ +"/import/c4dm-datasets/deepafx_st/daps_24000_styles/val/podcast/804_podcast_cleanraw_val.wav" \ +--save \ \ No newline at end of file diff --git a/scripts/style_case_study.py b/scripts/style_case_study.py new file mode 100644 index 0000000..25c495e --- /dev/null +++ b/scripts/style_case_study.py @@ -0,0 +1,543 @@ +import os +import sys +import glob +import torch +import auraloss +import itertools +import argparse +import torchaudio +import numpy as np +import scipy.signal +import pytorch_lightning as pl +import matplotlib.pyplot as plt + +from tqdm import tqdm + +from deepafx_st.utils import DSPMode +from deepafx_st.utils import get_random_patch +from deepafx_st.system import System +from deepafx_st.models.baselines import BaselineEQAndComp +from deepafx_st.processors.dsp.peq import biqaud +from deepafx_st.metrics import ( + LoudnessError, + RMSEnergyError, + SpectralCentroidError, + CrestFactorError, + PESQ, + MelSpectralDistance, +) + +colors = { + "neutral": (70 / 255, 181 / 255, 211 / 255), # neutral + "broadcast": (52 / 255, 57 / 255, 60 / 255), # broadcast + "telephone": (219 / 255, 73 / 255, 76 / 255), # telephone + "warm": (235 / 255, 164 / 255, 50 / 255), # warm + "bright": (134 / 255, 170 / 255, 109 / 255), # bright +} + + +def plot_peq_response( + p_peq_denorm, + sr, + ax=None, + label=None, + color=None, + points=False, + center_line=False, +): + + ls_gain = p_peq_denorm[0] + ls_freq = p_peq_denorm[1] + ls_q = p_peq_denorm[2] + b0, a0 = biqaud(ls_gain, ls_freq, ls_q, sr, filter_type="low_shelf") + sos0 = np.concatenate((b0, a0)) + + f1_gain = p_peq_denorm[3] + f1_freq = p_peq_denorm[4] + f1_q = p_peq_denorm[5] + b1, a1 = biqaud(f1_gain, f1_freq, f1_q, sr, filter_type="peaking") + sos1 = np.concatenate((b1, a1)) + + f2_gain = p_peq_denorm[6] + f2_freq = p_peq_denorm[7] + f2_q = p_peq_denorm[8] + b2, a2 = biqaud(f2_gain, f2_freq, f2_q, sr, filter_type="peaking") + sos2 = np.concatenate((b2, a2)) + + f3_gain = p_peq_denorm[9] + f3_freq = p_peq_denorm[10] + f3_q = p_peq_denorm[11] + b3, a3 = biqaud(f3_gain, f3_freq, f3_q, sr, filter_type="peaking") + sos3 = np.concatenate((b3, a3)) + + f4_gain = p_peq_denorm[12] + f4_freq = p_peq_denorm[13] + f4_q = p_peq_denorm[14] + b4, a4 = biqaud(f4_gain, f4_freq, f4_q, sr, filter_type="peaking") + sos4 = np.concatenate((b4, a4)) + + hs_gain = p_peq_denorm[15] + hs_freq = p_peq_denorm[16] + hs_q = p_peq_denorm[17] + b5, a5 = biqaud(hs_gain, hs_freq, hs_q, sr, filter_type="high_shelf") + sos5 = np.concatenate((b5, a5)) + + sos = [sos0, sos1, sos2, sos3, sos4, sos5] + sos = np.array(sos) + # print(sos.shape) + # print(sos) + + # measure freq response + w, h = scipy.signal.sosfreqz(sos, fs=sr, worN=2048) + + if ax is None: + fig, axs = plt.subplots() + + if center_line: + ax.plot(w, np.zeros(w.shape), color="lightgray") + + ax.plot(w, 20 * np.log10(np.abs(h)), label=label, color=color) + if points: + ax.scatter(ls_freq, ls_gain, color=color) + ax.scatter(f1_freq, f1_gain, color=color) + ax.scatter(f2_freq, f2_gain, color=color) + ax.scatter(f3_freq, f3_gain, color=color) + ax.scatter(f4_freq, f4_gain, color=color) + ax.scatter(hs_freq, hs_gain, color=color) + + +def plot_comp_response( + p_comp_denorm, + sr, + ax=None, + label=None, + color=None, + center_line=False, +): + + # get parameters + threshold = p_comp_denorm[0] + ratio = p_comp_denorm[1] + attack_ms = p_comp_denorm[2] * 1000 + release_ms = p_comp_denorm[3] * 1000 + knee_db = p_comp_denorm[4] + makeup_db = p_comp_denorm[5] + + # print(knee_db) + + x = np.linspace(-80, 0) # input level + y = np.zeros(x.shape) # output level + + idx = np.where((2 * (x - threshold)) < -knee_db) + y[idx] = x[idx] + + idx = np.where((2 * np.abs(x - threshold)) <= knee_db) + y[idx] = x[idx] + ( + (1 / ratio - 1) * (((x[idx] - threshold + (knee_db / 2))) ** 2) + ) / (2 * knee_db) + + idx = np.where((2 * (x - threshold)) > knee_db) + y[idx] = threshold + ((x[idx] - threshold) / (ratio)) + + text_height = threshold + ((0 - threshold) / (ratio)) + + # plot the first part of the line + ax.plot(x, y, label=label, color=color) + if center_line: + ax.plot(x, x, color="lightgray", linestyle="--") + ax.text(0, text_height, f"{threshold:0.1f} dB {ratio:0.1f}:1") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--ckpt_paths", + type=str, + help="Path to pre-trained system checkpoints.", + nargs="+", + ) + parser.add_argument( + "--num_examples", + type=int, + default=10, + help="Number of style transfer to perform for each style transfer example.", + ) + parser.add_argument( + "--style_audio", + help="List of style audio filepaths.", + type=str, + ) + parser.add_argument( + "--gpu", + help="Run System on GPU.", + action="store_true", + ) + parser.add_argument( + "--save", + help="Save audio examples.", + action="store_true", + ) + parser.add_argument( + "--plot", + help="Save parameter prediction plots.", + action="store_true", + ) + parser.add_argument( + "--output_dir", + type=str, + help="Path to save audio outputs.", + default="style_case_study", + ) + parser.add_argument( + "--sample_rate", + type=int, + help="Input audio sample rate.", + default=24000, + ) + parser.add_argument( + "--seed", + type=int, + help="Random seed.", + default=16, + ) + + args = parser.parse_args() + pl.seed_everything(args.seed) + device = "cuda" if args.gpu else "cpu" + metrics_dict = {"Corrupt": {}, "Baseline": {}} + + # --------------- setup pre-trained model --------------- + models = {} + peq_ckpt = "/import/c4dm-datasets/deepafx_st/logs/proxies/libritts/peq/lightning_logs/version_1/checkpoints/epoch=111-step=139999-val-libritts-peq.ckpt" + comp_ckpt = "/import/c4dm-datasets/deepafx_st/logs/proxies/libritts/comp/lightning_logs/version_1/checkpoints/epoch=255-step=319999-val-libritts-comp.ckpt" + + for ckpt_path in args.ckpt_paths: + model_name = os.path.basename(ckpt_path).replace(".ckpt", "") + + if "proxy" in model_name: + use_dsp = DSPMode.INFER + else: + use_dsp = DSPMode.NONE + + system = System.load_from_checkpoint( + ckpt_path, + use_dsp=use_dsp, + batch_size=1, + spsa_parallel=False, + proxy_ckpts=[peq_ckpt, comp_ckpt], + strict=False, + ) + system.eval() + if args.gpu: + system.to("cuda") + models[model_name] = system + + metrics_dict[model_name] = {} + + # create the baseline model + baseline_model = BaselineEQAndComp(sample_rate=args.sample_rate) + + # ---- setup the metrics ---- + metrics = { + # "PESQ": PESQ(sample_rate), + # "MRSTFT": auraloss.freq.MultiResolutionSTFTLoss( + # fft_sizes=[32, 128, 512, 2048, 8192, 32768], + # hop_sizes=[16, 64, 256, 1024, 4096, 16384], + # win_lengths=[32, 128, 512, 2048, 8192, 32768], + # w_sc=0.0, + # w_phs=0.0, + # w_lin_mag=1.0, + # w_log_mag=1.0, + # ), + "MSD": MelSpectralDistance(args.sample_rate), + "SCE": SpectralCentroidError(args.sample_rate), + # "CFE": CrestFactorError(), + "RMS": RMSEnergyError(), + "LUFS": LoudnessError(args.sample_rate), + } + + # ----------- load and pre-process audio ------------- + style_dirs = glob.glob(os.path.join(args.style_audio, "*")) + style_dirs = [sd for sd in style_dirs if os.path.isdir(sd)] + + transfers = itertools.product(style_dirs, style_dirs) + + metrics_overall = {} + for transfer in transfers: + input_style_dir, target_style_dir = transfer + input_style = os.path.basename(input_style_dir) + target_style = os.path.basename(target_style_dir) + style_transfer_name = f"{input_style}-->{target_style}" + transfer_output_dir = os.path.join(args.output_dir, style_transfer_name) + if not os.path.isdir(transfer_output_dir): + os.makedirs(transfer_output_dir) + print(style_transfer_name) + metrics_dict[style_transfer_name] = {} + + # get all examples from the input style + input_filepaths = glob.glob(os.path.join(input_style_dir, "*.wav")) + + # get all examples from the target style + target_filepaths = glob.glob(os.path.join(target_style_dir, "*.wav")) + + for n in tqdm(range(args.num_examples), ncols=80): + input_filepath = np.random.choice(input_filepaths) + target_filepath = np.random.choice(target_filepaths) + input_name = os.path.basename(input_filepath).replace("*.wav", "") + target_name = os.path.basename(target_filepath).replace("*.wav", "") + x, x_sr = torchaudio.load(input_filepath) + y, y_sr = torchaudio.load(target_filepath) + + chs, samp = x.size() + + # normalize + x_norm = x / x.abs().max() + x_norm *= 10 ** (-12.0 / 20.0) + y_norm = y / y.abs().max() + y_norm *= 10 ** (-12.0 / 20.0) + + if args.gpu: + x_norm = x_norm.to("cuda") + y_norm = y_norm.to("cuda") + + # ------------------ compute model metrics ------------------ + # run our models + model_outputs = {} + model_params = {} + for model_name, system in models.items(): + with torch.no_grad(): + y_hat_system, p, e_system = system( + x_norm.view(1, 1, -1), + y=y_norm.view(1, 1, -1), + dsp_mode=system.hparams.use_dsp, + analysis_length=131072, + sample_rate=x_sr, + ) + + short_model_name = model_name.split("-")[-1] + + # normalize + # y_hat_system = y_hat_system / y_hat_system.abs().max() + # y_hat_system *= 10 ** (-12.0 / 20.0) + + # ----------- store predicted audio and parameters ----------- + autodiff_key = [key for key in models.keys() if "autodiff" in key][0] + tmp_system = models[autodiff_key] + model_outputs[short_model_name] = y_hat_system # store audio + + p_peq = p[:, : tmp_system.processor.peq.num_control_params].cpu() + p_comp = p[:, tmp_system.processor.peq.num_control_params :].cpu() + + p_peq_denorm = tmp_system.processor.peq.denormalize_params( + p_peq.view(-1) + ) + p_peq_denorm = [p.numpy() for p in p_peq_denorm] + + p_comp_denorm = tmp_system.processor.comp.denormalize_params( + p_comp.view(-1) + ) + p_comp_denorm = [p.numpy() for p in p_comp_denorm] + + model_params[short_model_name] = {} # store parameters + model_params[short_model_name]["p_peq_denorm"] = p_peq_denorm + model_params[short_model_name]["p_comp_denorm"] = p_comp_denorm + + # ----------- compute metrics ----------- + if short_model_name not in metrics_dict[style_transfer_name]: + metrics_dict[style_transfer_name][short_model_name] = {} + + for metric_name, metric_fn in metrics.items(): + system_val = metric_fn( + y_hat_system[..., 16384:].cpu(), + y_norm.view(1, 1, -1)[..., 16384:].cpu(), + ) + + if ( + metric_name + not in metrics_dict[style_transfer_name][short_model_name] + ): + metrics_dict[style_transfer_name][short_model_name][ + metric_name + ] = [] + metrics_dict[style_transfer_name][short_model_name][ + metric_name + ].append(system_val) + + # ----------------- compute baseline metrics ------------------ + # run the baseline model + y_hat_baseline = baseline_model( + x_norm.view(1, 1, -1).cpu(), + y_norm.view(1, 1, -1).cpu(), + ) + + if "Baseline" not in metrics_dict[style_transfer_name]: + metrics_dict[style_transfer_name]["Baseline"] = {} + + for metric_name, metric_fn in metrics.items(): + baseline_val = metric_fn( + y_hat_baseline.cpu(), y_norm.view(1, 1, -1).cpu() + ) + + if metric_name not in metrics_dict[style_transfer_name]["Baseline"]: + metrics_dict[style_transfer_name]["Baseline"][metric_name] = [] + metrics_dict[style_transfer_name]["Baseline"][metric_name].append( + baseline_val + ) + + if "Corrupt" not in metrics_dict[style_transfer_name]: + metrics_dict[style_transfer_name]["Corrupt"] = {} + + for metric_name, metric_fn in metrics.items(): + baseline_val = metric_fn( + x_norm.view(1, 1, -1).cpu(), y_norm.view(1, 1, -1).cpu() + ) + + if metric_name not in metrics_dict[style_transfer_name]["Corrupt"]: + metrics_dict[style_transfer_name]["Corrupt"][metric_name] = [] + metrics_dict[style_transfer_name]["Corrupt"][metric_name].append( + baseline_val + ) + + if args.save: + input_filepath = os.path.join( + transfer_output_dir, + f"{n}_{input_style}-->{target_style}_input.wav", + ) + target_filepath = os.path.join( + transfer_output_dir, + f"{n}_{input_style}-->{target_style}_target.wav", + ) + baseline_filepath = os.path.join( + transfer_output_dir, + f"{n}_{input_style}-->{target_style}_baseline.wav", + ) + + for short_model_name, model_output in model_outputs.items(): + model_filepath = os.path.join( + transfer_output_dir, + f"{n}_{input_style}-->{target_style}_{short_model_name}.wav", + ) + torchaudio.save( + model_filepath, + model_output.view(chs, -1).cpu(), + y_sr, + ) + + torchaudio.save(input_filepath, x_norm.cpu(), y_sr) + torchaudio.save(target_filepath, y_norm.cpu(), y_sr) + torchaudio.save( + baseline_filepath, + y_hat_baseline.view(chs, -1).cpu(), + y_sr, + ) + + if args.plot: + # create main figure + fig, axs = plt.subplots(figsize=(8, 3), nrows=1, ncols=2) + + for model_idx, (short_model_name, p) in enumerate(model_params.items()): + + p_peq_denorm = p["p_peq_denorm"] + p_comp_denorm = p["p_comp_denorm"] + + # -------- Create Frequency response plot -------- + plot_peq_response( + p_peq_denorm, + args.sample_rate, + ax=axs[0], + label=short_model_name, + color=list(colors.values())[model_idx], + center_line=True if model_idx == 0 else False, + ) + + # -------- Create Compressor response plot -------- + plot_comp_response( + p_comp_denorm, + args.sample_rate, + ax=axs[1], + label=short_model_name, + color=list(colors.values())[model_idx], + center_line=True if model_idx == 0 else False, + ) + + plot_filepath = os.path.join( + transfer_output_dir, + f"{n}_{input_style}-->{target_style}", + ) + + # --------- formating for Parametric EQ --------- + axs[0].set_ylim([-24, 24]) + axs[0].set_xlim([10, 10000]) + axs[0].set_xscale("log") + axs[0].grid(c="lightgray", which="major") + axs[0].grid(c="lightgray", which="minor") + axs[0].set_ylabel("Magnitude (dB)") + axs[0].set_xlabel("Frequency (Hz)") + axs[0].spines["right"].set_visible(False) + axs[0].spines["left"].set_visible(False) + axs[0].spines["top"].set_visible(False) + axs[0].spines["bottom"].set_visible(False) + axs[0].tick_params( + axis="x", which="minor", colors="lightgray", labelcolor="k" + ) + axs[0].tick_params( + axis="x", which="major", colors="lightgray", labelcolor="k" + ) + axs[0].tick_params( + axis="y", which="major", colors="lightgray", labelcolor="k" + ) + axs[0].legend( + ncol=6, + loc="lower center", + columnspacing=0.8, + framealpha=1.0, + bbox_to_anchor=(0.5, 1.05), + ) + # --------- formating for compressor curve --------- + axs[1].set_ylim([-80, 0]) + axs[1].set_xlim([-80, 0]) + axs[1].grid(c="lightgray", which="major") + axs[1].spines["right"].set_visible(False) + axs[1].spines["left"].set_visible(False) + axs[1].spines["top"].set_visible(False) + axs[1].spines["bottom"].set_visible(False) + axs[1].set_ylabel("Output (dB)") + axs[1].set_xlabel("Input (dB)") + axs[1].set_title(f"{input_style} --> {target_style}") + axs[1].tick_params( + axis="x", which="major", colors="lightgray", labelcolor="k" + ) + axs[1].tick_params( + axis="y", which="major", colors="lightgray", labelcolor="k" + ) + axs[1].set(adjustable="box", aspect="equal") + + plt.tight_layout() + plt.savefig(plot_filepath + ".png", dpi=300) + plt.savefig(plot_filepath + ".svg") + plt.savefig(plot_filepath + ".pdf") + plt.close("all") + + for model_name, model_metrics in metrics_dict[style_transfer_name].items(): + if model_name not in metrics_overall: + metrics_overall[model_name] = {} + sys.stdout.write(f"{model_name.ljust(10)} ") + for metric_name, metric_values in model_metrics.items(): + mean_val = np.mean(metric_values) + + if metric_name not in metrics_overall[model_name]: + metrics_overall[model_name][metric_name] = [] + metrics_overall[model_name][metric_name].append(mean_val) + + sys.stdout.write(f"{metric_name}: {mean_val:0.3f} ") + print() + print() + + print("----- Averages ----") + for model_name, model_metrics in metrics_overall.items(): + sys.stdout.write(f"{model_name.ljust(10)} ") + for metric_name, metric_values in model_metrics.items(): + mean_val = np.mean(metric_values) + sys.stdout.write(f"{metric_name}: {mean_val:0.3f} ") + + sys.exit(0) diff --git a/scripts/style_interpolation.py b/scripts/style_interpolation.py new file mode 100644 index 0000000..3b7a686 --- /dev/null +++ b/scripts/style_interpolation.py @@ -0,0 +1,418 @@ +import os +import sys +import glob +import torch +import auraloss +import argparse +import torchaudio +import numpy as np +import scipy.signal +import matplotlib +import matplotlib.pyplot as plt + +from deepafx_st.utils import DSPMode +from deepafx_st.utils import get_random_patch +from deepafx_st.processors.dsp.peq import biqaud +from deepafx_st.system import System + + +def plot_peq_response( + p_peq_denorm, + sr, + ax=None, + label=None, + color=None, + points=False, + center_line=False, +): + + ls_gain = p_peq_denorm[0] + ls_freq = p_peq_denorm[1] + ls_q = p_peq_denorm[2] + b0, a0 = biqaud(ls_gain, ls_freq, ls_q, sr, filter_type="low_shelf") + sos0 = np.concatenate((b0, a0)) + + f1_gain = p_peq_denorm[3] + f1_freq = p_peq_denorm[4] + f1_q = p_peq_denorm[5] + b1, a1 = biqaud(f1_gain, f1_freq, f1_q, sr, filter_type="peaking") + sos1 = np.concatenate((b1, a1)) + + f2_gain = p_peq_denorm[6] + f2_freq = p_peq_denorm[7] + f2_q = p_peq_denorm[8] + b2, a2 = biqaud(f2_gain, f2_freq, f2_q, sr, filter_type="peaking") + sos2 = np.concatenate((b2, a2)) + + f3_gain = p_peq_denorm[9] + f3_freq = p_peq_denorm[10] + f3_q = p_peq_denorm[11] + b3, a3 = biqaud(f3_gain, f3_freq, f3_q, sr, filter_type="peaking") + sos3 = np.concatenate((b3, a3)) + + f4_gain = p_peq_denorm[12] + f4_freq = p_peq_denorm[13] + f4_q = p_peq_denorm[14] + b4, a4 = biqaud(f4_gain, f4_freq, f4_q, sr, filter_type="peaking") + sos4 = np.concatenate((b4, a4)) + + hs_gain = p_peq_denorm[15] + hs_freq = p_peq_denorm[16] + hs_q = p_peq_denorm[17] + b5, a5 = biqaud(hs_gain, hs_freq, hs_q, sr, filter_type="high_shelf") + sos5 = np.concatenate((b5, a5)) + + sos = [sos0, sos1, sos2, sos3, sos4, sos5] + sos = np.array(sos) + # print(sos.shape) + # print(sos) + + # measure freq response + w, h = scipy.signal.sosfreqz(sos, fs=22050, worN=2048) + + if ax is None: + fig, axs = plt.subplots() + + if center_line: + ax.plot(w, np.zeros(w.shape), color="lightgray") + + ax.plot(w, 20 * np.log10(np.abs(h)), label=label, color=color) + if points: + ax.scatter(ls_freq, ls_gain, color=color) + ax.scatter(f1_freq, f1_gain, color=color) + ax.scatter(f2_freq, f2_gain, color=color) + ax.scatter(f3_freq, f3_gain, color=color) + ax.scatter(f4_freq, f4_gain, color=color) + ax.scatter(hs_freq, hs_gain, color=color) + + +def plot_comp_response( + p_comp_denorm, + sr, + ax=None, + label=None, + color=None, + center_line=False, + param_text=True, +): + + # get parameters + threshold = p_comp_denorm[0] + ratio = p_comp_denorm[1] + attack_ms = p_comp_denorm[2] * 1000 + release_ms = p_comp_denorm[3] * 1000 + knee_db = p_comp_denorm[4] + makeup_db = p_comp_denorm[5] + + # print(knee_db) + + x = np.linspace(-80, 0) # input level + y = np.zeros(x.shape) # output level + + idx = np.where((2 * (x - threshold)) < -knee_db) + y[idx] = x[idx] + + idx = np.where((2 * np.abs(x - threshold)) <= knee_db) + y[idx] = x[idx] + ( + (1 / ratio - 1) * (((x[idx] - threshold + (knee_db / 2))) ** 2) + ) / (2 * knee_db) + + idx = np.where((2 * (x - threshold)) > knee_db) + y[idx] = threshold + ((x[idx] - threshold) / (ratio)) + + text_height = threshold + ((0 - threshold) / (ratio)) + + # plot the first part of the line + ax.plot(x, y, label=label, color=color) + if center_line: + ax.plot(x, x, color="lightgray", linestyle="--") + + if param_text: + ax.text( + 0, + text_height, + f"{threshold:0.1f} dB {ratio:0.1f}:1", + fontsize="small", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--ckpt_path", + type=str, + help="Path to pre-trained system checkpoint.", + ) + parser.add_argument( + "--input_audio", + type=str, + help="Path to input audio file.", + ) + parser.add_argument( + "--input_length", + type=int, + help="Number of seconds to process from input file.", + ) + parser.add_argument( + "--num_steps", + type=int, + default=10, + help="Number of interpolation steps.", + ) + parser.add_argument( + "--style_a", + help="Starting style.", + type=str, + ) + parser.add_argument( + "--style_a_name", + help="Starting style name.", + type=str, + default="a", + ) + parser.add_argument( + "--style_b", + help="Ending style.", + type=str, + ) + parser.add_argument( + "--style_b_name", + help="Ending style name.", + type=str, + default="b", + ) + parser.add_argument( + "--gpu", + help="Run System on GPU.", + action="store_true", + ) + parser.add_argument( + "--save", + help="Save audio examples.", + action="store_true", + ) + parser.add_argument( + "--output_dir", + type=str, + help="Path to save audio outputs.", + default="style_interpolation", + ) + + args = parser.parse_args() + torch.manual_seed(42) + + device = "cuda" if args.gpu else "cpu" + + if not os.path.isdir(args.output_dir): + os.makedirs(args.output_dir) + + # --------------- setup pre-trained model --------------- + use_dsp = DSPMode.NONE + system = System.load_from_checkpoint( + args.ckpt_path, + use_dsp=use_dsp, + batch_size=1, + spsa_parallel=False, + ) + system.eval() + if args.gpu: + system.to("cuda") + + sample_rate = system.hparams.sample_rate + + # ----------- load and pre-process audio ------------- + # normalize input + input_name = os.path.basename(args.input_audio).replace(".wav", "") + x, x_sr = torchaudio.load(args.input_audio) + input_length_samp = int(x_sr * args.input_length) + # x = x[:, input_length_samp : input_length_samp * 2] + # if x.shape[-1] > 131072: + # x = get_random_patch(x, x_sr, 131072) + + x_norm = x / x.abs().max() + x_norm *= 10 ** (-12.0 / 20.0) + + fig, axs = plt.subplots(figsize=(10, 4), nrows=1, ncols=2) + cmap = matplotlib.cm.get_cmap("viridis") + + colors = { + "podcast": (70 / 255, 181 / 255, 211 / 255), # podcast + "radio": (52 / 255, 57 / 255, 60 / 255), # radio + "telephone": (219 / 255, 73 / 255, 76 / 255), # telephone + "warm": (235 / 255, 164 / 255, 50 / 255), # warm + "bright": (134 / 255, 170 / 255, 109 / 255), # bright + } + + style_a_color = colors[args.style_a_name] + style_b_color = colors[args.style_b_name] + + # compute start and ending style embeddings + style_a_audio, sr = torchaudio.load(args.style_a) + style_b_audio, sr = torchaudio.load(args.style_b) + + style_a_audio = style_a_audio / style_a_audio.abs().max() + style_a_audio *= 10 ** (-12.0 / 20.0) + style_b_audio = style_b_audio / style_b_audio.abs().max() + style_b_audio *= 10 ** (-12.0 / 20.0) + + with torch.no_grad(): + style_a_embed = system.encoder(style_a_audio.view(1, 1, -1)) + style_b_embed = system.encoder(style_b_audio.view(1, 1, -1)) + + # linear interpolation between style embeddings + outputs = [] + for w_idx, w in enumerate(np.linspace(0, 1, args.num_steps)): + + style_embed = (w * style_b_embed) + ((1 - w) * style_a_embed) + print(w_idx, style_embed) + + # run our model + with torch.no_grad(): + y_hat_system, p, e_system = system( + x_norm.view(1, 1, -1), + e_y=style_embed, + analysis_length=131072, + ) + outputs.append(y_hat_system.view(-1)) + + # -------- split params between EQ and Comp. -------- + p_peq = p[:, : system.processor.peq.num_control_params] + p_comp = p[:, system.processor.peq.num_control_params :] + + p_peq_denorm = system.processor.peq.denormalize_params(p_peq.view(-1)) + p_peq_denorm = [p.numpy() for p in p_peq_denorm] + + p_comp_denorm = system.processor.comp.denormalize_params(p_comp.view(-1)) + p_comp_denorm = [p.numpy() for p in p_comp_denorm] + + comp_params = {} + + # -------- Create Frequency response plot -------- + if w_idx == 0: + label = args.style_a_name + elif w_idx == (args.num_steps - 1): + label = args.style_b_name + else: + label = None + + # linear interpolkate RGB color + style_color_R = (w * style_b_color[0]) + ((1 - w) * style_a_color[0]) + style_color_G = (w * style_b_color[1]) + ((1 - w) * style_a_color[1]) + style_color_B = (w * style_b_color[2]) + ((1 - w) * style_a_color[2]) + style_color = (style_color_R, style_color_G, style_color_B) + + plot_peq_response( + p_peq_denorm, + sample_rate, + ax=axs[0], + label=label, + color=style_color, + center_line=True if w_idx == 0 else False, + ) + + plot_comp_response( + p_comp_denorm, + sample_rate, + ax=axs[1], + label=label, + color=style_color, + center_line=True if w_idx == 0 else False, + param_text=True if w_idx == 0 or (w_idx + 1) == args.num_steps else False, + ) + + if False: + if not os.path.isdir(args.output_dir): + os.makedirs(args.output_dir) + + input_filepath = os.path.join(args.output_dir, f"{n}_{input_name}.wav") + style_filepath = os.path.join(args.output_dir, f"{n}_{style_name}.wav") + system_filepath = os.path.join( + args.output_dir, + f"{n}_{input_name}_{style_name}_system.wav", + ) + baseline_filepath = os.path.join( + args.output_dir, + f"{n}_{input_name}_{style_name}_baseline.wav", + ) + + torchaudio.save(input_filepath, x_norm, y_sr) + torchaudio.save(style_filepath, y_norm, y_sr) + torchaudio.save(system_filepath, y_hat_system.view(1, -1), y_sr) + # torchaudio.save(baseline_filepath, y_hat_baseline.view(1, -1), y_sr) + + # --------- Morphing audio style --------- + frame_size = x.shape[-1] // args.num_steps + output = torch.zeros(x.shape[-1]) + + for n in range(args.num_steps): + start_idx = n * frame_size + stop_idx = start_idx + frame_size + output[start_idx:stop_idx] = outputs[n][start_idx:stop_idx] + + audio_filepath = os.path.join( + args.output_dir, + f"style_interpolation_{args.style_a_name}_to_{args.style_b_name}_interp.wav", + ) + torchaudio.save(audio_filepath, output.view(1, -1), sample_rate) + + # --------- Save input/output/style audio --------- + input_filepath = os.path.join( + args.output_dir, + f"style_interpolation_{args.style_a_name}_to_{args.style_b_name}_input.wav", + ) + style_a_filepath = os.path.join( + args.output_dir, + f"style_interpolation_{args.style_a_name}_to_{args.style_b_name}_style={args.style_a_name}.wav", + ) + style_b_filepath = os.path.join( + args.output_dir, + f"style_interpolation_{args.style_a_name}_to_{args.style_b_name}_style={args.style_b_name}.wav", + ) + torchaudio.save(input_filepath, x.view(1, -1), sample_rate) + torchaudio.save(style_a_filepath, style_a_audio.view(1, -1), sample_rate) + torchaudio.save(style_b_filepath, style_b_audio.view(1, -1), sample_rate) + + # --------- formating for Parametric EQ ---------= + plot_filepath = os.path.join( + args.output_dir, + f"style_interpolation_{args.style_a_name}_to_{args.style_b_name}", + ) + axs[0].set_ylim([-24, 24]) + axs[0].set_xlim([10, 10000]) + axs[0].set_xscale("log") + axs[0].grid(c="lightgray", which="major") + axs[0].grid(c="lightgray", which="minor") + axs[0].set_ylabel("Magnitude (dB)") + axs[0].set_xlabel("Frequency (Hz)") + axs[0].spines["right"].set_visible(False) + axs[0].spines["left"].set_visible(False) + axs[0].spines["top"].set_visible(False) + axs[0].spines["bottom"].set_visible(False) + axs[0].tick_params(axis="x", which="minor", colors="lightgray", labelcolor="k") + axs[0].tick_params(axis="x", which="major", colors="lightgray", labelcolor="k") + axs[0].tick_params(axis="y", which="major", colors="lightgray", labelcolor="k") + axs[0].legend( + ncol=2, + loc="lower center", + columnspacing=0.8, + framealpha=1.0, + bbox_to_anchor=(0.5, 1.05), + ) + # --------- formating for compressor curve --------- + axs[1].set_ylim([-80, 0]) + axs[1].set_xlim([-80, 0]) + axs[1].grid(c="lightgray", which="major") + axs[1].spines["right"].set_visible(False) + axs[1].spines["left"].set_visible(False) + axs[1].spines["top"].set_visible(False) + axs[1].spines["bottom"].set_visible(False) + axs[1].set_ylabel("Output (dB)") + axs[1].set_xlabel("Input (dB)") + axs[1].tick_params(axis="x", which="major", colors="lightgray", labelcolor="k") + axs[1].tick_params(axis="y", which="major", colors="lightgray", labelcolor="k") + axs[1].set(adjustable="box", aspect="equal") + + plt.tight_layout() + plt.savefig(plot_filepath + ".png", dpi=300) + plt.savefig(plot_filepath + ".svg") + plt.savefig(plot_filepath + ".pdf") diff --git a/scripts/style_transfer.py b/scripts/style_transfer.py new file mode 100644 index 0000000..154da00 --- /dev/null +++ b/scripts/style_transfer.py @@ -0,0 +1,471 @@ +import os +import sys +import glob +from types import resolve_bases +import torch +import auraloss +import argparse +import torchaudio +import numpy as np +import scipy.signal +import matplotlib +import matplotlib.pyplot as plt + +from deepafx_st.utils import DSPMode +from deepafx_st.utils import loudness_normalize +from deepafx_st.processors.dsp.peq import biqaud, parametric_eq +from deepafx_st.processors.dsp.compressor import compressor +from deepafx_st.system import System +from deepafx_st.models.baselines import BaselineEQAndComp +from deepafx_st.metrics import ( + LoudnessError, + RMSEnergyError, + SpectralCentroidError, + CrestFactorError, + PESQ, + MelSpectralDistance, +) + + +def plot_peq_response( + p_peq_denorm, + sr, + ax=None, + label=None, + color=None, + points=False, + center_line=False, +): + + ls_gain = p_peq_denorm[0] + ls_freq = p_peq_denorm[1] + ls_q = p_peq_denorm[2] + b0, a0 = biqaud(ls_gain, ls_freq, ls_q, sr, filter_type="low_shelf") + sos0 = np.concatenate((b0, a0)) + + f1_gain = p_peq_denorm[3] + f1_freq = p_peq_denorm[4] + f1_q = p_peq_denorm[5] + b1, a1 = biqaud(f1_gain, f1_freq, f1_q, sr, filter_type="peaking") + sos1 = np.concatenate((b1, a1)) + + f2_gain = p_peq_denorm[6] + f2_freq = p_peq_denorm[7] + f2_q = p_peq_denorm[8] + b2, a2 = biqaud(f2_gain, f2_freq, f2_q, sr, filter_type="peaking") + sos2 = np.concatenate((b2, a2)) + + f3_gain = p_peq_denorm[9] + f3_freq = p_peq_denorm[10] + f3_q = p_peq_denorm[11] + b3, a3 = biqaud(f3_gain, f3_freq, f3_q, sr, filter_type="peaking") + sos3 = np.concatenate((b3, a3)) + + f4_gain = p_peq_denorm[12] + f4_freq = p_peq_denorm[13] + f4_q = p_peq_denorm[14] + b4, a4 = biqaud(f4_gain, f4_freq, f4_q, sr, filter_type="peaking") + sos4 = np.concatenate((b4, a4)) + + hs_gain = p_peq_denorm[15] + hs_freq = p_peq_denorm[16] + hs_q = p_peq_denorm[17] + b5, a5 = biqaud(hs_gain, hs_freq, hs_q, sr, filter_type="high_shelf") + sos5 = np.concatenate((b5, a5)) + + sos = [sos0, sos1, sos2, sos3, sos4, sos5] + sos = np.array(sos) + # print(sos.shape) + # print(sos) + + if label == "telephone": + label = "Tele" + + # measure freq response + w, h = scipy.signal.sosfreqz(sos, fs=22050, worN=2048) + + if ax is None: + fig, axs = plt.subplots() + + if center_line: + ax.plot(w, np.zeros(w.shape), color="lightgray") + + (handle,) = ax.plot( + w, 20 * np.log10(np.abs(h)), label=label.capitalize(), color=color + ) + if points: + ax.scatter(ls_freq, ls_gain, color=color) + ax.scatter(f1_freq, f1_gain, color=color) + ax.scatter(f2_freq, f2_gain, color=color) + ax.scatter(f3_freq, f3_gain, color=color) + ax.scatter(f4_freq, f4_gain, color=color) + ax.scatter(hs_freq, hs_gain, color=color) + + return handle + + +def plot_comp_response( + p_comp_denorm, + sr, + ax=None, + label=None, + color=None, + center_line=False, + prev_height=None, + plot_idx=0, +): + + # get parameters + threshold = p_comp_denorm[0] + ratio = p_comp_denorm[1] + attack_ms = p_comp_denorm[2] * 1000 + release_ms = p_comp_denorm[3] * 1000 + knee_db = p_comp_denorm[4] + makeup_db = p_comp_denorm[5] + + # print(knee_db) + + x = np.linspace(-80, 0) # input level + y = np.zeros(x.shape) # output level + + idx = np.where((2 * (x - threshold)) < -knee_db) + y[idx] = x[idx] + + idx = np.where((2 * np.abs(x - threshold)) <= knee_db) + y[idx] = x[idx] + ( + (1 / ratio - 1) * (((x[idx] - threshold + (knee_db / 2))) ** 2) + ) / (2 * knee_db) + + idx = np.where((2 * (x - threshold)) > knee_db) + y[idx] = threshold + ((x[idx] - threshold) / (ratio)) + + text_height = threshold + ((-20 - threshold) / (ratio)) - 1 + + # text_height += style_idx * 2 + # if ((style_idx + 1) % 2) == 0: + text_width = -18 + # else: + # text_width = -14 + + # plot the first part of the line + ax.plot(x, y, color=color) + if center_line: + ax.plot(x, x, color="lightgray", linestyle="--") + # ax.text(-19, -22, f"Thres. Ratio") + ax.text(-19, -22, f"Ratio") + + # ax.text(-18, text_height, f"{threshold:0.1f} {ratio:0.1f}") + ax.text(text_width, text_height, f"{ratio:0.1f}") + + return text_height + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--ckpt_path", + type=str, + help="Path to pre-trained system checkpoint.", + ) + parser.add_argument( + "--input_filepaths", + type=str, + help="List of input audio filepaths.", + nargs="+", + ) + parser.add_argument( + "--style_filepaths", + help="List of style audio filepaths.", + type=str, + nargs="+", + ) + parser.add_argument( + "--gpu", + help="Run System on GPU.", + action="store_true", + ) + parser.add_argument( + "--modify_input", + help="Apply increasing strong effects to input.", + action="store_true", + ) + parser.add_argument( + "--save", + help="Save audio examples.", + action="store_true", + ) + parser.add_argument( + "--output_dir", + type=str, + help="Path to save audio outputs.", + default="style_transfer", + ) + parser.add_argument( + "--target_loudness", + type=float, + help="Target audio output loudness in dB LUFS", + default=-23.0, + ) + + args = parser.parse_args() + torch.manual_seed(42) + + device = "cuda" if args.gpu else "cpu" + + fontlist = matplotlib.font_manager.findSystemFonts(fontpaths=None, fontext="ttf") + fontlist = [f.name for f in matplotlib.font_manager.fontManager.ttflist] + # print(fontlist) + # set font + plt.rcParams["font.family"] = "Nimbus Roman" + + # --------------- setup pre-trained model --------------- + use_dsp = DSPMode.NONE + system = System.load_from_checkpoint( + args.ckpt_path, + use_dsp=use_dsp, + batch_size=1, + spsa_parallel=False, + ) + system.eval() + if args.gpu: + system.to("cuda") + + sample_rate = system.hparams.sample_rate + + # create the baseline model + baseline_model = BaselineEQAndComp(sample_rate=sample_rate) + + colors = { + "neutral": (70 / 255, 181 / 255, 211 / 255), # neutral + "broadcast": (52 / 255, 57 / 255, 60 / 255), # broadcast + "telephone": (219 / 255, 73 / 255, 76 / 255), # telephone + "warm": (235 / 255, 164 / 255, 50 / 255), # warm + "bright": (134 / 255, 170 / 255, 109 / 255), # bright + } + + for input_filepath in args.input_filepaths: + # normalize input + input_name = os.path.basename(input_filepath).replace(".wav", "") + x, x_sr = torchaudio.load(input_filepath) + + if x_sr != sample_rate: + x = torchaudio.transforms.Resample(x_sr, sample_rate)(x) + + x = x[:, : 262144 * 2] + x /= x.abs().max() + print(x.shape) + + fig, axs = plt.subplots(figsize=(5, 2), nrows=1, ncols=2) + + # fig, axs = plt.subplots(figsize=(4, 5), nrows=2, ncols=1) + cmap = matplotlib.cm.get_cmap("viridis") + handles = [] + prev_height = None + + if args.modify_input: + parameters = { + "high_shelf_gain_dB": [0.0, 1.0, 3.0, 6.0, 12.0], # , 18.0], + "threshold_dB": [-3.0, -12.0, -24.0, -40.0, -62.0], # , -70], + "ratio": [1.0, 2.0, 3.0, 3.0, 4.0], # , 8.0], + } + args.style_filepaths *= len(parameters["high_shelf_gain_dB"]) + + for style_idx, style_filepath in enumerate(args.style_filepaths): + y, y_sr = torchaudio.load(style_filepath) + style_name = os.path.basename(os.path.dirname(style_filepath)) + + # apply increasing effects + if args.modify_input: + hsg = parameters["high_shelf_gain_dB"][style_idx] + thr = parameters["threshold_dB"][style_idx] + rto = parameters["ratio"][style_idx] + x_proc = parametric_eq( + x.view(-1).numpy(), + float(sample_rate), + high_shelf_gain_dB=hsg, + high_shelf_cutoff_freq=4000.0, + ) + x_proc = compressor( + x_proc, + float(sample_rate), + threshold=thr, + ratio=rto, + attack_time=0.005, + release_time=0.050, + knee_dB=0.0, + ) + x_proc = torch.tensor(x_proc).view(1, -1) + else: + x_proc = x.clone() + + x_norm = x_proc / x_proc.abs().max() + x_norm *= 10 ** (-6.0 / 20.0) + + x_proc = x_proc[:, -131072:] + + # normalize reference + y_norm = y / y.abs().max() + y_norm *= 10 ** (-12.0 / 20.0) + + # run our model + with torch.no_grad(): + y_hat_system, p, e_system = system( + x_norm.view(1, 1, -1), + y=y_norm.view(1, 1, -1), + analysis_length=131072, + ) + + # -------- split params between EQ and Comp. -------- + p_peq = p[:, : system.processor.peq.num_control_params] + p_comp = p[:, system.processor.peq.num_control_params :] + + p_peq_denorm = system.processor.peq.denormalize_params(p_peq.view(-1)) + p_peq_denorm = [p.numpy() for p in p_peq_denorm] + + p_comp_denorm = system.processor.comp.denormalize_params(p_comp.view(-1)) + p_comp_denorm = [p.numpy() for p in p_comp_denorm] + + # comp_params = {} + + # -------- Create Frequency response plot -------- + if args.modify_input: + label = f"{style_idx}" + color = cmap(style_idx / len(parameters["high_shelf_gain_dB"])) + + else: + label = style_name + color = colors[style_name] + + handle = plot_peq_response( + p_peq_denorm, + sample_rate, + ax=axs[0], + label=label, + color=color, + center_line=True if style_idx == 0 else False, + ) + handles.append(handle) + + prev_height = plot_comp_response( + p_comp_denorm, + sample_rate, + ax=axs[1], + label=label, + color=color, + center_line=True if style_idx == 0 else False, + prev_height=prev_height, + plot_idx=style_idx, + ) + + if args.save: + if not os.path.isdir(args.output_dir): + os.makedirs(args.output_dir) + + input_filepath = os.path.join( + args.output_dir, + f"{style_idx}_{input_name}.wav", + ) + style_filepath = os.path.join( + args.output_dir, + f"{style_idx}_{style_name}.wav", + ) + system_filepath = os.path.join( + args.output_dir, + f"{style_idx}_{input_name}_to_{style_name}_system.wav", + ) + + torchaudio.save( + input_filepath, + loudness_normalize( + x_norm, + sample_rate, + args.target_loudness, + ), + sample_rate, + ) + torchaudio.save( + style_filepath, + loudness_normalize( + y_norm, + sample_rate, + args.target_loudness, + ), + sample_rate, + ) + torchaudio.save( + system_filepath, + loudness_normalize( + y_hat_system.view(1, -1), + sample_rate, + args.target_loudness, + ), + sample_rate, + ) + + plot_filepath = os.path.join( + args.output_dir, f"style_transfer_{input_name}" + ) + # --------- formating for Parametric EQ --------- + axs[0].set_ylim([-24, 24]) + axs[0].set_xlim([10, 10000]) + axs[0].set_xscale("log") + axs[0].grid(c="lightgray", which="major") + axs[0].grid(c="lightgray", which="minor") + axs[0].set_ylabel("Magnitude (dB)") + axs[0].set_xlabel("Frequency (Hz)") + axs[0].spines["right"].set_visible(False) + axs[0].spines["left"].set_visible(False) + axs[0].spines["top"].set_visible(False) + axs[0].spines["bottom"].set_visible(False) + axs[0].tick_params( + axis="x", + which="minor", + colors="lightgray", + labelcolor="k", + ) + axs[0].tick_params( + axis="x", + which="major", + colors="lightgray", + labelcolor="k", + ) + axs[0].tick_params( + axis="y", + which="major", + colors="lightgray", + labelcolor="k", + ) + if args.modify_input: + ncol = 5 + else: + ncol = 5 + + axs[0].legend( + handles=handles, + ncol=ncol, + loc="upper center", + columnspacing=0.8, + framealpha=0.0, + bbox_to_anchor=(1.05, 1.3), + # bbox_to_anchor=(0.5, -0.025), + ) + # axs[0].set(adjustable="box", aspect="auto") + # --------- formating for compressor curve --------- + axs[1].set_ylim([-80, -20]) + axs[1].set_xlim([-80, -20]) + axs[1].grid(c="lightgray", which="major") + axs[1].spines["right"].set_visible(False) + axs[1].spines["left"].set_visible(False) + axs[1].spines["top"].set_visible(False) + axs[1].spines["bottom"].set_visible(False) + axs[1].set_ylabel("Output (dB)") + axs[1].set_xlabel("Input (dB)") + axs[1].tick_params(axis="x", which="major", colors="lightgray", labelcolor="k") + axs[1].tick_params(axis="y", which="major", colors="lightgray", labelcolor="k") + axs[1].set(adjustable="box", aspect="equal") + + # fig.tight_layout() + fig.subplots_adjust(top=0.86, bottom=0.22, wspace=0.25, hspace=0.4, right=0.90) + plt.savefig(plot_filepath + ".png", dpi=300) + plt.savefig(plot_filepath + ".svg") + plt.savefig(plot_filepath + ".pdf") diff --git a/scripts/style_transfer_bulk.py b/scripts/style_transfer_bulk.py new file mode 100644 index 0000000..613d10d --- /dev/null +++ b/scripts/style_transfer_bulk.py @@ -0,0 +1,454 @@ +import os +import sys +import glob +import torch +import auraloss +import argparse +import torchaudio +import numpy as np +import scipy.signal +import matplotlib +import pyloudnorm as pyln +import matplotlib.pyplot as plt + +from deepafx_st.utils import DSPMode +from deepafx_st.utils import get_random_patch +from deepafx_st.processors.dsp.peq import biqaud +from deepafx_st.system import System + + +def plot_peq_response( + p_peq_denorm, + sr, + ax=None, + label=None, + color=None, + points=False, + center_line=False, +): + + ls_gain = p_peq_denorm[0] + ls_freq = p_peq_denorm[1] + ls_q = p_peq_denorm[2] + b0, a0 = biqaud(ls_gain, ls_freq, ls_q, sr, filter_type="low_shelf") + sos0 = np.concatenate((b0, a0)) + + f1_gain = p_peq_denorm[3] + f1_freq = p_peq_denorm[4] + f1_q = p_peq_denorm[5] + b1, a1 = biqaud(f1_gain, f1_freq, f1_q, sr, filter_type="peaking") + sos1 = np.concatenate((b1, a1)) + + f2_gain = p_peq_denorm[6] + f2_freq = p_peq_denorm[7] + f2_q = p_peq_denorm[8] + b2, a2 = biqaud(f2_gain, f2_freq, f2_q, sr, filter_type="peaking") + sos2 = np.concatenate((b2, a2)) + + f3_gain = p_peq_denorm[9] + f3_freq = p_peq_denorm[10] + f3_q = p_peq_denorm[11] + b3, a3 = biqaud(f3_gain, f3_freq, f3_q, sr, filter_type="peaking") + sos3 = np.concatenate((b3, a3)) + + f4_gain = p_peq_denorm[12] + f4_freq = p_peq_denorm[13] + f4_q = p_peq_denorm[14] + b4, a4 = biqaud(f4_gain, f4_freq, f4_q, sr, filter_type="peaking") + sos4 = np.concatenate((b4, a4)) + + hs_gain = p_peq_denorm[15] + hs_freq = p_peq_denorm[16] + hs_q = p_peq_denorm[17] + b5, a5 = biqaud(hs_gain, hs_freq, hs_q, sr, filter_type="high_shelf") + sos5 = np.concatenate((b5, a5)) + + sos = [sos0, sos1, sos2, sos3, sos4, sos5] + sos = np.array(sos) + # print(sos.shape) + # print(sos) + + # measure freq response + w, h = scipy.signal.sosfreqz(sos, fs=22050, worN=2048) + + if ax is None: + fig, axs = plt.subplots() + + if center_line: + ax.plot(w, np.zeros(w.shape), color="lightgray") + + ax.plot(w, 20 * np.log10(np.abs(h)), label=label, color=color) + if points: + ax.scatter(ls_freq, ls_gain, color=color) + ax.scatter(f1_freq, f1_gain, color=color) + ax.scatter(f2_freq, f2_gain, color=color) + ax.scatter(f3_freq, f3_gain, color=color) + ax.scatter(f4_freq, f4_gain, color=color) + ax.scatter(hs_freq, hs_gain, color=color) + + +def plot_comp_response( + p_comp_denorm, + sr, + ax=None, + label=None, + color=None, + center_line=False, + param_text=True, +): + + # get parameters + threshold = p_comp_denorm[0] + ratio = p_comp_denorm[1] + attack_ms = p_comp_denorm[2] * 1000 + release_ms = p_comp_denorm[3] * 1000 + knee_db = p_comp_denorm[4] + makeup_db = p_comp_denorm[5] + + # print(knee_db) + + x = np.linspace(-80, 0) # input level + y = np.zeros(x.shape) # output level + + idx = np.where((2 * (x - threshold)) < -knee_db) + y[idx] = x[idx] + + idx = np.where((2 * np.abs(x - threshold)) <= knee_db) + y[idx] = x[idx] + ( + (1 / ratio - 1) * (((x[idx] - threshold + (knee_db / 2))) ** 2) + ) / (2 * knee_db) + + idx = np.where((2 * (x - threshold)) > knee_db) + y[idx] = threshold + ((x[idx] - threshold) / (ratio)) + + text_height = threshold + ((0 - threshold) / (ratio)) + + # plot the first part of the line + ax.plot(x, y, label=label, color=color) + if center_line: + ax.plot(x, x, color="lightgray", linestyle="--") + + if param_text: + ax.text( + 0, + text_height, + f"{threshold:0.1f} dB {ratio:0.1f}:1", + fontsize="small", + ) + + +def loudness_normalize(x, target_loudness=-24.0): + x = x.view(1, -1) + stereo_audio = x.repeat(2, 1).permute(1, 0).numpy() + loudness = meter.integrated_loudness(stereo_audio) + norm_x = pyln.normalize.loudness( + stereo_audio, + loudness, + target_loudness, + ) + x = torch.tensor(norm_x).permute(1, 0) + x = x[0, :].view(1, -1) + + return x + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--ckpt_path", + type=str, + help="Path to pre-trained system checkpoint.", + ) + parser.add_argument( + "--input_filepaths", + help="List of audio filepaths for style transfer.", + type=str, + nargs="+", + ) + parser.add_argument( + "--gpu", + help="Run System on GPU.", + action="store_true", + ) + parser.add_argument( + "--save", + help="Save audio examples.", + action="store_true", + ) + parser.add_argument( + "--output_dir", + type=str, + help="Path to save audio outputs.", + default="style_transfer_bulk", + ) + parser.add_argument( + "--target_loudness", + type=float, + help="Target audio output loudness in dB LUFS", + default=-23.0, + ) + parser.add_argument( + "--num_interp_steps", + type=int, + help="Number of steps between each interpolated style.", + default=4, + ) + + args = parser.parse_args() + torch.manual_seed(42) + + device = "cuda" if args.gpu else "cpu" + + if not os.path.isdir(args.output_dir): + os.makedirs(args.output_dir) + + # --------------- setup pre-trained model --------------- + use_dsp = DSPMode.NONE + system = System.load_from_checkpoint( + args.ckpt_path, + use_dsp=use_dsp, + batch_size=1, + spsa_parallel=False, + ) + system.eval() + if args.gpu: + system.to("cuda") + + sample_rate = system.hparams.sample_rate + meter = pyln.Meter(sample_rate) # Loudness meter + + # ----------- Plotting setup ------------- + + colors = { + "neutral": (70 / 255, 181 / 255, 211 / 255), # neutral + "broadcast": (52 / 255, 57 / 255, 60 / 255), # broadcast + "telephone": (219 / 255, 73 / 255, 76 / 255), # telephone + "warm": (235 / 255, 164 / 255, 50 / 255), # warm + "bright": (134 / 255, 170 / 255, 109 / 255), # bright + } + + # ----------- Locate audio files ------------- + for input_filepath in args.input_filepaths: + outputs = {} # store the transformed style outputs + interp_outputs = [] # style interpolations + + # create one plot for each input style + fig, axs = plt.subplots(figsize=(10, 4), nrows=1, ncols=2) + + input_style_name = os.path.basename(os.path.dirname(input_filepath)) + input_style_color = colors[input_style_name] + x, x_sr = torchaudio.load(input_filepath) + x = x / x.abs().max() + + output_filepath = os.path.join( + args.output_dir, + f"style_transfer_{input_style_name}.wav", + ) + torchaudio.save( + output_filepath, + loudness_normalize(x, args.target_loudness), + x_sr, + ) + + x *= 10 ** (-12.0 / 20.0) + + # use all other styles are targets + style_filepaths = list(args.input_filepaths) + style_filepaths.remove(input_filepath) + + # ----------- interpolate between all target styles ----------- + target_style_names = [] + for sidx, style_filepath in enumerate(style_filepaths): + target_style_a_name = os.path.basename(os.path.dirname(style_filepath)) + target_style_names.append(target_style_a_name) + target_style_a_color = colors[target_style_a_name] + y_a, y_a_sr = torchaudio.load(style_filepath) + y_a = y_a / y_a.abs().max() + y_a *= 10 ** (-12.0 / 20.0) + + # get the next style in list + next_sidx = sidx + 1 + if next_sidx > len(style_filepaths) - 1: + next_sidx = 0 + style_filepath = style_filepaths[next_sidx] + target_style_b_name = os.path.basename(os.path.dirname(style_filepath)) + target_style_b_color = colors[target_style_b_name] + y_b, y_b_sr = torchaudio.load(style_filepath) + y_b = y_b / y_b.abs().max() + y_b *= 10 ** (-12.0 / 20.0) + + # compute style embeddings + with torch.no_grad(): + style_a_embed = system.encoder(y_a.view(1, 1, -1)) + style_b_embed = system.encoder(y_b.view(1, 1, -1)) + + # repeat the input audio for more length + x_long = x.repeat(1, 4) + + # linear interpolation between style embeddings + for w_idx, w in enumerate(np.linspace(0, 1, args.num_interp_steps)): + style_embed = (w * style_b_embed) + ((1 - w) * style_a_embed) + print(w_idx, style_embed) + + # run our model + with torch.no_grad(): + y_hat_system, p, e_system = system( + x_long.view(1, 1, -1), + e_y=style_embed, + analysis_length=131072, + ) + + interp_outputs.append( + loudness_normalize( + y_hat_system.view(1, -1), + args.target_loudness, + ), + ) + + # chop outputs into an interpolation + num_frames = args.num_interp_steps * len(style_filepaths) + frame_size = x_long.shape[-1] // num_frames + tmp_output = torch.zeros(x_long.shape) + + fade_size = 4096 + + for n in range(num_frames): + start_idx = (n * frame_size) - fade_size + stop_idx = (start_idx + frame_size) + fade_size + if start_idx < 0: + start_idx = 0 + if stop_idx > tmp_output.shape[-1]: + stop_idx = tmp_output.shape[-1] - 1 + frame_audio = interp_outputs[n][:, start_idx:stop_idx] + # apply linear fade in and out + ramp_up = np.linspace(0, 1, num=fade_size) + ramp_down = np.linspace(1, 0, num=fade_size) + frame_audio[:, :fade_size] *= ramp_up + frame_audio[:, -fade_size:] *= ramp_down + tmp_output[:, start_idx:stop_idx] = frame_audio + + filename = ( + f"style_transfer_{input_style_name}_to_" + + "_".join(target_style_names) + + ".wav" + ) + output_filepath = os.path.join(args.output_dir, filename) + + # normalize to target loudness + torchaudio.save( + output_filepath, + loudness_normalize(tmp_output, args.target_loudness), + x_sr, + ) + + # ----------- single transfer to each style (with plotting) ----------- + for sidx, style_filepath in enumerate(style_filepaths): + target_style_name = os.path.basename(os.path.dirname(style_filepath)) + target_style_color = colors[target_style_name] + y, y_sr = torchaudio.load(style_filepath) + y = y / y.abs().max() + y *= 10 ** (-12.0 / 20.0) + + # run our model + with torch.no_grad(): + y_hat_system, p, e_system = system( + x.view(1, 1, -1), + y=y.view(1, 1, -1), + ) + + # normalize and store + y_hat_system /= y_hat_system.abs().max() + transfer_name = f"{input_style_name}_to_{target_style_name}" + outputs[transfer_name] = y_hat_system.view(1, -1) + + # -------- split params between EQ and Comp. -------- + p_peq = p[:, : system.processor.peq.num_control_params] + p_comp = p[:, system.processor.peq.num_control_params :] + + p_peq_denorm = system.processor.peq.denormalize_params(p_peq.view(-1)) + p_peq_denorm = [p.numpy() for p in p_peq_denorm] + + p_comp_denorm = system.processor.comp.denormalize_params(p_comp.view(-1)) + p_comp_denorm = [p.numpy() for p in p_comp_denorm] + + comp_params = {} + + # -------- Create Frequency response plot -------- + plot_peq_response( + p_peq_denorm, + sample_rate, + ax=axs[0], + label=target_style_name, + color=colors[target_style_name], + center_line=True if sidx == 0 else False, + ) + + plot_comp_response( + p_comp_denorm, + sample_rate, + ax=axs[1], + label=target_style_name, + color=target_style_color, + center_line=True if sidx == 0 else False, + param_text=True, + ) + + if args.save: + for output_name, output_audio in outputs.items(): + output_filepath = os.path.join( + args.output_dir, + f"style_transfer_{output_name}.wav", + ) + + # normalize to target loudness + torchaudio.save( + output_filepath, + loudness_normalize(output_audio, args.target_loudness), + y_sr, + ) + + # --------- formating for Parametric EQ ---------= + plot_filepath = os.path.join( + args.output_dir, + f"style_transfer_{input_style_name}", + ) + plt.title(f"{input_style_name} as input") + axs[0].set_ylim([-24, 24]) + axs[0].set_xlim([10, 10000]) + axs[0].set_xscale("log") + axs[0].grid(c="lightgray", which="major") + axs[0].grid(c="lightgray", which="minor") + axs[0].set_ylabel("Magnitude (dB)") + axs[0].set_xlabel("Frequency (Hz)") + axs[0].spines["right"].set_visible(False) + axs[0].spines["left"].set_visible(False) + axs[0].spines["top"].set_visible(False) + axs[0].spines["bottom"].set_visible(False) + axs[0].tick_params(axis="x", which="minor", colors="lightgray", labelcolor="k") + axs[0].tick_params(axis="x", which="major", colors="lightgray", labelcolor="k") + axs[0].tick_params(axis="y", which="major", colors="lightgray", labelcolor="k") + axs[0].legend( + ncol=4, + loc="lower center", + columnspacing=0.8, + framealpha=1.0, + bbox_to_anchor=(0.5, 1.05), + ) + # --------- formating for compressor curve --------- + axs[1].set_ylim([-80, 0]) + axs[1].set_xlim([-80, 0]) + axs[1].grid(c="lightgray", which="major") + axs[1].spines["right"].set_visible(False) + axs[1].spines["left"].set_visible(False) + axs[1].spines["top"].set_visible(False) + axs[1].spines["bottom"].set_visible(False) + axs[1].set_ylabel("Output (dB)") + axs[1].set_xlabel("Input (dB)") + axs[1].tick_params(axis="x", which="major", colors="lightgray", labelcolor="k") + axs[1].tick_params(axis="y", which="major", colors="lightgray", labelcolor="k") + axs[1].set(adjustable="box", aspect="equal") + + plt.tight_layout() + plt.savefig(plot_filepath + ".png", dpi=300) + plt.savefig(plot_filepath + ".svg") + plt.savefig(plot_filepath + ".pdf") diff --git a/scripts/test_ckpt.py b/scripts/test_ckpt.py new file mode 100644 index 0000000..460f7c0 --- /dev/null +++ b/scripts/test_ckpt.py @@ -0,0 +1,77 @@ +import os +import sys +import glob +import torch +import pickle +import pytorch_lightning as pl +import deepafx_st + +# sys.modules["deepafx_st"] = deepafx_st # patch for name change + +if __name__ == "__main__": + + checkpoint_dir = "checkpoints" + + for experiment in ["probes", "style", "proxies"]: + + for v in [0, 1, 2]: + ckpt_paths = glob.glob( + os.path.join( + checkpoint_dir, + experiment, + "**", + "**", + "lightning_logs", + f"version_{v}", + "checkpoints", + "*.ckpt", + ) + ) + + for ckpt_path in ckpt_paths: + print(ckpt_path) + + processor_model_id = ckpt_path.split("/")[-5] + print(processor_model_id) + + if "m" in processor_model_id: + peq_ckpt = "checkpoints/proxies/jamendo/peq/lightning_logs/version_0/checkpoints/epoch=326-step=204374-val-jamendo-peq.ckpt" + comp_ckpt = "checkpoints/proxies/jamendo/comp/lightning_logs/version_0/checkpoints/epoch=274-step=171874-val-jamendo-comp.ckpt" + else: + peq_ckpt = "checkpoints/proxies/libritts/peq/lightning_logs/version_1/checkpoints/epoch=111-step=139999-val-libritts-peq.ckpt" + comp_ckpt = "checkpoints/proxies/libritts/comp/lightning_logs/version_1/checkpoints/epoch=255-step=319999-val-libritts-comp.ckpt" + + proxy_ckpts = [peq_ckpt, comp_ckpt] + + if experiment == "style": + model = deepafx_st.system.System.load_from_checkpoint( + ckpt_path, + proxy_ckpts=proxy_ckpts, + strict=False, + ) + elif experiment == "probes": + if "speech" in ckpt_path: + deepafx_st_autodiff_ckpt = "checkpoints/style/libritts/autodiff/lightning_logs/version_1/checkpoints/epoch=367-step=1226911-val-libritts-autodiff.ckpt" + deepafx_st_spsa_ckpt = "checkpoints/style/libritts/spsa/lightning_logs/version_2/checkpoints/epoch=367-step=1226911-val-libritts-spsa.ckpt" + deepafx_st_proxy0_ckpt = "checkpoints/style/libritts/proxy0/lightning_logs/version_0/checkpoints/epoch=327-step=1093551-val-libritts-proxy0.ckpt" + elif "music" in ckpt_path: + deepafx_st_autodiff_ckpt = "checkpoints/style/jamendo/autodiff/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-autodiff.ckpt" + deepafx_st_spsa_ckpt = "checkpoints/style/jamendo/spsa/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-spsa.ckpt" + deepafx_st_proxy0_ckpt = "checkpoints/style/jamendo/proxy0/lightning_logs/version_0/checkpoints/epoch=362-step=1210241-val-jamendo-proxy0.ckpt" + + model = ( + deepafx_st.probes.probe_system.ProbeSystem.load_from_checkpoint( + ckpt_path, + strict=False, + deepafx_st_autodiff_ckpt=deepafx_st_autodiff_ckpt, + deepafx_st_spsa_ckpt=deepafx_st_spsa_ckpt, + deepafx_st_proxy0_ckpt=deepafx_st_proxy0_ckpt, + ) + ) + elif experiment == "proxies": + model = deepafx_st.processors.proxy.proxy_system.ProxySystem.load_from_checkpoint( + ckpt_path, + strict=False, + ) + else: + raise RuntimeError(f"Invalid experiment: {experiment}") diff --git a/scripts/timing.py b/scripts/timing.py new file mode 100644 index 0000000..2e3fd84 --- /dev/null +++ b/scripts/timing.py @@ -0,0 +1,424 @@ +import torch +import auraloss +import torchaudio +import numpy as np +import scipy.signal +from tqdm import tqdm +from itertools import chain +from time import perf_counter + +from deepafx_st.models.encoder import SpectralEncoder +from deepafx_st.models.controller import StyleTransferController +from deepafx_st.processors.autodiff.channel import AutodiffChannel +from deepafx_st.processors.proxy.channel import ProxyChannel +from deepafx_st.processors.dsp.compressor import Compressor +from deepafx_st.processors.dsp.peq import ParametricEQ +from deepafx_st.processors.spsa.channel import SPSAChannel +from deepafx_st.utils import DSPMode, count_parameters +from deepafx_st.processors.dsp.compressor import compressor + + +def run_dsp(x, peq_p, comp_p, peq, comp): + + x = peq(x, peq_p) + x = comp(x, comp_p) + + return x + + +if __name__ == "__main__": + + sample_rate = 24000 + n_iters = 1000 + length_sec = 5 + bs = 4 + length_samp = sample_rate * length_sec + + # loss + mrstft_loss = auraloss.freq.MultiResolutionSTFTLoss( + fft_sizes=[32, 128, 512, 2048, 8192, 32768], + hop_sizes=[16, 64, 256, 1024, 4096, 16384], + win_lengths=[32, 128, 512, 2048, 8192, 32768], + w_sc=0.0, + w_phs=0.0, + w_lin_mag=1.0, + w_log_mag=1.0, + ) + + # dsp effects + peq_dsp = ParametricEQ(sample_rate) + comp_dsp = Compressor(sample_rate) + + # autodiff effects + channel_ad = AutodiffChannel(sample_rate) + + spsa = SPSAChannel(sample_rate, True, bs) + + # proxy channel tcn 1 + np_norm = ProxyChannel( + [], + freeze_proxies=True, + dsp_mode=DSPMode.NONE, + tcn_nblocks=4, + tcn_dilation_growth=8, + tcn_channel_width=64, + tcn_kernel_size=13, + num_tcns=1, + sample_rate=sample_rate, + ) + + # proxy channel tcn 1 + np_hh = ProxyChannel( + [], + freeze_proxies=True, + dsp_mode=DSPMode.INFER, + tcn_nblocks=4, + tcn_dilation_growth=8, + tcn_channel_width=64, + tcn_kernel_size=13, + num_tcns=1, + sample_rate=sample_rate, + ) + + # proxy channel tcn 1 + np_fh = ProxyChannel( + [], + freeze_proxies=True, + dsp_mode=DSPMode.TRAIN_INFER, + tcn_nblocks=4, + tcn_dilation_growth=8, + tcn_channel_width=64, + tcn_kernel_size=13, + num_tcns=1, + sample_rate=sample_rate, + ) + + # proxy channel tcn 1 + tcn1 = ProxyChannel( + [], + freeze_proxies=False, + dsp_mode=DSPMode.NONE, + tcn_nblocks=4, + tcn_dilation_growth=8, + tcn_channel_width=64, + tcn_kernel_size=13, + num_tcns=1, + sample_rate=sample_rate, + ) + + # proxy channel tcn 2 + tcn2 = ProxyChannel( + [], + freeze_proxies=False, + dsp_mode=DSPMode.NONE, + tcn_nblocks=4, + tcn_dilation_growth=8, + tcn_channel_width=64, + tcn_kernel_size=13, + num_tcns=2, + sample_rate=sample_rate, + ) + + # predictor models + encoder = SpectralEncoder( + channel_ad.num_control_params, + sample_rate, + encoder_model="efficient_net", + embed_dim=1024, + width_mult=1, + ) + controller = StyleTransferController( + channel_ad.num_control_params, + 1024, + # bottleneck_dim=-1, + ) + + print() + + # iterate + for model in [ + "rb_infer", + "dsp_infer", + "autodiff_cpu_infer", + "autodiff_gpu_infer", + "tcn1_cpu_infer", + "tcn2_cpu_infer", + "tcn1_gpu_infer", + "tcn2_gpu_infer", + "autodiff_gpu_grad", + "np_norm_gpu_grad", + "np_hh_gpu_grad", + "np_fh_gpu_grad", + "tcn1_gpu_grad", + "tcn2_gpu_grad", + "spsa_gpu_grad", + ]: + timings = [] + for n in tqdm(range(n_iters), ncols=80): + if "grad" in model: + eff_bs = bs + else: + eff_bs = 1 + if model == "rb_infer": + if n == 0: + p = torch.rand( + eff_bs, channel_ad.num_control_params, requires_grad=True + ) + x = torch.randn(eff_bs, 1, length_samp) + y = torch.randn(eff_bs, 1, length_samp) + + n_fft = 65536 + freqs = np.linspace(0, 1.0, num=(n_fft // 2) + 1) + response = np.random.rand(n_fft // 2 + 1) + response[-1] = 0.0 # zero gain at nyquist + b = scipy.signal.firwin2( + 63, + freqs * (sample_rate / 2), + response, + fs=sample_rate, + ) + + t1_start = perf_counter() + + x_filt = scipy.signal.lfilter(b, [1.0], x.numpy()) + x_filt = torch.tensor(x_filt.astype("float32")) + + with torch.inference_mode(): + x_comp_new = compressor( + x_filt.view(-1).numpy(), + sample_rate, + threshold=-12, + ratio=3, + attack_time=0.001, + release_time=0.05, + knee_dB=6.0, + makeup_gain_dB=0.0, + ) + + t1_stop = perf_counter() + + if model == "dsp_infer": + if n == 0: + params = 0 + x = np.random.rand(length_samp) + peq_p = np.random.rand(peq_dsp.num_control_params) + comp_p = np.random.rand(comp_dsp.num_control_params) + t1_start = perf_counter() + y = run_dsp(x, peq_p, comp_p, peq_dsp, comp_dsp) + t1_stop = perf_counter() + elif "autodiff" in model: + if n == 0: + params = 0 + p = torch.rand( + eff_bs, channel_ad.num_control_params, requires_grad=True + ) + x = torch.randn(eff_bs, 1, length_samp) + y = torch.randn(eff_bs, 1, length_samp) + optimizer = torch.optim.Adam( + chain( + encoder.parameters(), + controller.parameters(), + ), + lr=1e-3, + ) + + if "gpu" in model: + p = p.to("cuda") + x = x.to("cuda") + y = y.to("cuda") + if "grad" in model: + encoder.to("cuda") + controller.to("cuda") + + if "grad" in model: + t1_start = perf_counter() + e_x = encoder(x) + e_y = encoder(y) + p = controller(e_x, e_y) + y_hat = channel_ad(x, p) + loss = mrstft_loss(y_hat, x) + loss.backward() + optimizer.step() + t1_stop = perf_counter() + else: + with torch.inference_mode(): + t1_start = perf_counter() + y = channel_ad(x, p) + t1_stop = perf_counter() + + elif "tcn1" in model: + if n == 0: + params = count_parameters(tcn1) + p = torch.rand( + eff_bs, + channel_ad.num_control_params, + requires_grad=False, + ) + x = torch.randn(eff_bs, 1, length_samp) + y = torch.randn(eff_bs, 1, length_samp) + optimizer = torch.optim.Adam( + chain( + encoder.parameters(), + controller.parameters(), + tcn1.parameters(), + ), + lr=1e-3, + ) + + if "gpu" in model: + p = p.to("cuda") + x = x.to("cuda") + y = y.to("cuda") + tcn1.to("cuda") + + if "grad" in model: + encoder.to("cuda") + controller.to("cuda") + else: + tcn1.to("cpu") + + if "grad" in model: + t1_start = perf_counter() + e_x = encoder(x) + e_y = encoder(y) + p = controller(e_x, e_y) + y_hat = tcn1(x, p) + loss = mrstft_loss(y_hat, x) + loss.backward() + optimizer.step() + t1_stop = perf_counter() + else: + with torch.inference_mode(): + t1_start = perf_counter() + y = tcn1(x, p) + t1_stop = perf_counter() + + elif "tcn2" in model: + if n == 0: + params = count_parameters(tcn2) + p = torch.rand( + eff_bs, channel_ad.num_control_params, requires_grad=True + ) + x = torch.randn(eff_bs, 1, length_samp) + y = torch.randn(eff_bs, 1, length_samp) + optimizer = torch.optim.Adam( + chain( + encoder.parameters(), + controller.parameters(), + tcn2.parameters(), + ), + lr=1e-3, + ) + + if "gpu" in model: + p = p.to("cuda") + x = x.to("cuda") + y = y.to("cuda") + tcn2.to("cuda") + + if "grad" in model: + encoder.to("cuda") + controller.to("cuda") + + if "grad" in model: + t1_start = perf_counter() + e_x = encoder(x) + e_y = encoder(y) + p = controller(e_x, e_y) + y_hat = tcn2(x, p) + loss = mrstft_loss(y_hat, x) + loss.backward() + optimizer.step() + t1_stop = perf_counter() + else: + with torch.inference_mode(): + t1_start = perf_counter() + y = tcn2(x, p) + t1_stop = perf_counter() + elif "np" in model: + if n == 0: + p = torch.rand( + eff_bs, channel_ad.num_control_params, requires_grad=True + ) + x = torch.randn(eff_bs, 1, length_samp) + y = torch.randn(eff_bs, 1, length_samp) + optimizer = torch.optim.Adam( + chain( + encoder.parameters(), + controller.parameters(), + ), + lr=1e-3, + ) + + if "gpu" in model: + p = p.to("cuda") + x = x.to("cuda") + y = y.to("cuda") + if "grad" in model: + encoder.to("cuda") + controller.to("cuda") + np_norm.to("cuda") + np_fh.to("cuda") + np_hh.to("cuda") + + if "grad" in model: + t1_start = perf_counter() + e_x = encoder(x) + e_y = encoder(y) + p = controller(e_x, e_y) + + if "fh" in model: + y_hat = np_fh(x, p) + elif "hh" in model: + y_hat = np_hh(x, p) + else: + y_hat = np_norm(x, p) + + loss = mrstft_loss(y_hat, x) + loss.backward() + optimizer.step() + t1_stop = perf_counter() + + elif "spsa" in model: + if n == 0: + p = torch.rand( + eff_bs, channel_ad.num_control_params, requires_grad=True + ) + x = torch.randn(eff_bs, 1, length_samp) + y = torch.randn(eff_bs, 1, length_samp) + optimizer = torch.optim.Adam( + chain( + encoder.parameters(), + controller.parameters(), + ), + lr=1e-3, + ) + if "gpu" in model: + p = p.to("cuda") + x = x.to("cuda") + y = y.to("cuda") + if "grad" in model: + encoder.to("cuda") + controller.to("cuda") + spsa.to("cuda") + + if "grad" in model: + t1_start = perf_counter() + e_x = encoder(x) + e_y = encoder(y) + p = controller(e_x, e_y) + y_hat = spsa(x, p) + loss = mrstft_loss(y_hat, x) + loss.backward() + optimizer.step() + t1_stop = perf_counter() + + elapsed = t1_stop - t1_start + timings.append(elapsed) + + # remove the first time + timings = timings[10:] + + rtf = np.mean(timings) / length_sec + sec_per_step = np.mean(timings) + print(f"{model} : sec/step {sec_per_step:0.4f} {rtf:0.4f} RTF") diff --git a/scripts/train_probe.py b/scripts/train_probe.py new file mode 100755 index 0000000..b93b63b --- /dev/null +++ b/scripts/train_probe.py @@ -0,0 +1,35 @@ +import os +import torch +import pytorch_lightning as pl +from argparse import ArgumentParser + +from deepafx_st.probes.probe_system import ProbeSystem + +torch.backends.cudnn.benchmark = True +pl.seed_everything(42) + +# some arg parse for configuration +parser = ArgumentParser() + +# add all the available trainer and system options to argparse +parser = pl.Trainer.add_argparse_args(parser) +parser = ProbeSystem.add_model_specific_args(parser) + +# parse them args +args = parser.parse_args() + +# setup callbacks +callbacks = [ + pl.callbacks.ModelCheckpoint( + monitor="val_f1_epoch", + mode="max", + filename="{epoch}-{step}-val-" + f"{args.encoder_type}-{args.probe_type}", + ), +] + +# create PyTorch Lightning trainer +trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks) +system = ProbeSystem(**vars(args)) + +# train! +trainer.fit(system) diff --git a/scripts/train_proxy.py b/scripts/train_proxy.py new file mode 100755 index 0000000..6e1befe --- /dev/null +++ b/scripts/train_proxy.py @@ -0,0 +1,38 @@ +import os +import torch +import pytorch_lightning as pl +from argparse import ArgumentParser + +from deepafx_st.processors.proxy.proxy_system import ProxySystem +from deepafx_st.callbacks.audio import LogAudioCallback + +torch.backends.cudnn.benchmark = True +pl.seed_everything(42) + +# some arg parse for configuration +parser = ArgumentParser() + +# add all the available trainer and system options to argparse +parser = pl.Trainer.add_argparse_args(parser) +parser = ProxySystem.add_model_specific_args(parser) + +# parse them args +args = parser.parse_args() + +dataset_name = args.default_root_dir.split(os.sep)[-2] + +# setup callbacks +callbacks = [ + LogAudioCallback(), + pl.callbacks.ModelCheckpoint( + monitor="val_loss", + filename="{epoch}-{step}-val-" + f"{dataset_name}-{args.processor}", + ), +] + +# create PyTorch Lightning trainer +trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks) +system = ProxySystem(**vars(args)) + +# train! +trainer.fit(system) diff --git a/scripts/train_style.py b/scripts/train_style.py new file mode 100755 index 0000000..4193d03 --- /dev/null +++ b/scripts/train_style.py @@ -0,0 +1,71 @@ +import os +import torch +import pytorch_lightning as pl +from argparse import ArgumentParser +from pytorch_lightning.plugins import DDPPlugin + +from deepafx_st.system import System +from deepafx_st.utils import system_summary +from deepafx_st.callbacks.audio import LogAudioCallback +from deepafx_st.callbacks.params import LogParametersCallback +from deepafx_st.callbacks.ckpt import CopyPretrainedCheckpoints + +if __name__ == "__main__": + + torch.multiprocessing.set_start_method("spawn") + + torch.backends.cudnn.benchmark = True + pl.seed_everything(42) + + # some arg parse for configuration + parser = ArgumentParser() + + # add all the available trainer and system options to argparse + parser = pl.Trainer.add_argparse_args(parser) + parser = System.add_model_specific_args(parser) + + # parse them args + args = parser.parse_args() + + # Checkpoint on the first reconstruction loss + args.train_monitor = f"train_loss/{args.recon_losses[-1]}" + args.val_monitor = f"val_loss/{args.recon_losses[-1]}" + + dataset_name = args.default_root_dir.split(os.sep)[-2] + + # setup callbacks + callbacks = [ + LogAudioCallback(), + pl.callbacks.LearningRateMonitor(logging_interval="step"), + pl.callbacks.ModelCheckpoint( + monitor=args.train_monitor, + filename="{epoch}-{step}-train-" + f"{dataset_name}-{args.processor_model}", + ), + pl.callbacks.ModelCheckpoint( + monitor=args.val_monitor, + filename="{epoch}-{step}-val-" + f"{dataset_name}-{args.processor_model}", + ), + CopyPretrainedCheckpoints(), + ] + + if args.processor_model != "tcn": + callbacks.append(LogParametersCallback()) + + # create PyTorch Lightning trainer + trainer = pl.Trainer.from_argparse_args( + args, + callbacks=callbacks, + plugins=DDPPlugin(find_unused_parameters=False), + ) + + # create the System + system = System(**vars(args)) + + # print details about the model + system_summary(system) + + # train! + trainer.fit(system) + + # close threads + del system.processor diff --git a/setup.py b/setup.py new file mode 100755 index 0000000..644163a --- /dev/null +++ b/setup.py @@ -0,0 +1,45 @@ +from setuptools import setup +from importlib.machinery import SourceFileLoader + +with open("README.md") as file: + long_description = file.read() + +version = SourceFileLoader("deepafx_st.version", "deepafx_st/version.py").load_module() + +setup( + name="deepafx-st", + version=version.version, + description="DeepAFx-ST", + author="See paper", + author_email="See paper", + url="https://github.com/adobe-research/DeepAFx-ST", + packages=["deepafx_st"], + long_description=long_description, + long_description_content_type="text/markdown", + license="Copyright Adobe Inc.", + install_requires=[ + "torch==1.9.0", + "torchaudio==0.9.0", + "torchmetrics>=0.4.1", + "torchvision==0.10.0", + "audioread>=2.1.9", + "auraloss>=0.2.1", + "librosa>=0.8.1", + "matplotlib", + "numpy", + "pytorch-lightning>=1.4.0", + "SoundFile>=0.10.3.post1", + "sox>=1.4.1", + "tensorboard>=2.4.1", + "scikit-learn>=0.24.2", + "scipy", + "pyloudnorm>=0.1.0", + "julius>=0.2.6", + "torchopenl3", + "cdpam", + "wget", + "pesq", + "umap-learn", + "setuptools==58.2.0" + ], +)