YAIxPOZAlabs/Improving-TrXL-for-ComMU

Improving Transformer-XL for Music Generation

This project was carried out by YAI 11th, in cooperation with POZAlabs.





Improving Transformer-XL for Music Generation 🎼

YAI x POZAlabs Industry-Academia Collaboration Team 1
NLP model for music generation

Members 👋

조정빈 : YAI 9th / [email protected]
김민서 : YAI 8th / [email protected]
김산 : YAI 9th / [email protected]
김성준 : YAI 10th / [email protected]
박민수 : YAI 9th / [email protected]
박수빈 : YAI 9th / [email protected]



Getting Started 🔥

Because the models and metrics differ, we recommend using a separate virtual environment for each. Each directory contains its own "Getting Started", so for clear instructions please follow the links shown in each section.

Improving-TrXL-for-ComMU/
├─ CAS/
├─ Group_Encoding/
├─ Soft_Labeling/
├─ TransformerCVAE/

For the baseline, a Transformer-XL trained on the ComMU dataset, refer to ComMU-code by POZAlabs.

Building on Transformer-XL 🏗️

0. Baseline (Transformer-XL) - Link

Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context

Evaluation

Classification Accuracy Score

| Label (Meta) | Real Model | Fake Model | Error Rate |
| --- | --- | --- | --- |
| BPM | 0.6291 | 0.6159 | 0.0210 |
| KEY | 0.8781 | 0.8781 | 0 |
| TIMESIGNATURE | 0.9082 | 0.8925 | |
| PITCHRANGE | 0.7483 | 0.7090 | 0.0525 |
| NUMBEROFMEASURE | 1.0 | 1.0 | |
| INSTRUMENT | 0.5858 | 0.5923 | |
| GENRE | 0.8532 | 0.8427 | 0.0123 |
| MINVELOCITY | 0.4718 | 0.4482 | |
| MAXVELOCITY | 0.4718 | 0.4495 | |
| TRACKROLE | 0.6500 | 0.5753 | 0.1149 |
| RHYTHM | 0.9934 | 0.9934 | |

Normalized Mean CADS : 0.0401

1. Group Encoding - Link

A vanilla Transformer-XL takes tokens as a 1-D sequence and adds positional encoding to tell the model where each token sits in the sequence. In this setting, the model has to learn both the semantics of the data and the structure of the MIDI encoding. However, because MIDI data is encoded into token sequences following an explicit pattern, we propose Group Encoding (GE), a method that injects an inductive bias about this explicit structure into the model. This keeps the model from predicting out-of-place tokens at a given position, and it also lets the model generate 4 tokens in a single forward pass, which speeds up both training and inference.

[Figure: GE]

[Figure: GE_1]
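As a rough illustration of the grouping idea, the sketch below reshapes a flat token sequence into fixed-size groups of 4, so that one model step could consume and emit a whole group. The function name `group_tokens`, the `pad_id` value, and the right-padding scheme are assumptions made for this example; the actual implementation lives in `Group_Encoding/` and may differ.

```python
from typing import List


def group_tokens(tokens: List[int], group_size: int = 4, pad_id: int = 0) -> List[List[int]]:
    """Split a flat token sequence into consecutive groups of `group_size`.

    The last group is right-padded with `pad_id` (an assumed padding token)
    so that every group has the same length.
    """
    if group_size <= 0:
        raise ValueError("group_size must be positive")
    groups = []
    for start in range(0, len(tokens), group_size):
        group = tokens[start:start + group_size]  # slicing copies, input is untouched
        group += [pad_id] * (group_size - len(group))
        groups.append(group)
    return groups


flat = [11, 12, 13, 14, 21, 22]  # made-up token ids
print(group_tokens(flat))  # [[11, 12, 13, 14], [21, 22, 0, 0]]
```

With a sequence of length L, a model that emits one group per step runs roughly L/4 steps instead of L, which is the intuition behind the speed-up described above.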

Evaluation

Controllability and Diversity

| Model | CP | CV (Midi) | CV (Note) | CH | Diversity |
| --- | --- | --- | --- | --- | --- |
| Transformer XL w/o GE | 0.8585 | 0.8060 | 0.9847 | 0.9891 | 0.4100 |
| Transformer XL w GE | 0.8493 | 0.7391 | 0.9821 | 0.9839 | 0.4113 |

Classification Accuracy Score

| Label (Meta) | Real Model | Fake Model | Error Rate |
| --- | --- | --- | --- |
| BPM | 0.6291 | 0.5910 | 0.0606 |
| KEY | 0.8781 | 0.8532 | 0.0284 |
| TIMESIGNATURE | 0.9082 | 0.8951 | |
| PITCHRANGE | 0.7483 | 0.7195 | 0.0385 |
| NUMBEROFMEASURE | 1.0 | 1.0 | |
| INSTRUMENT | 0.5858 | 0.5884 | |
| GENRE | 0.8532 | 0.8532 | 0 |
| MINVELOCITY | 0.4718 | 0.4364 | |
| MAXVELOCITY | 0.4718 | 0.4560 | |
| TRACKROLE | 0.6500 | 0.5360 | 0.1754 |
| RHYTHM | 0.9934 | 0.9934 | |

Normalized Mean CADS : 0.0605

Inference Speed

| Model | Inference Time for Validation Set | Inference Time per Sample | Relative Speed-up |
| --- | --- | --- | --- |
| Transformer XL w/o GE | 1189.4 s | 1.558 s | ×1 |
| Transformer XL w GE | 692.2 s | 0.907 s | ×1.718 |
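The relative speed-up in the table can be checked directly from the two total inference times. The snippet below is only that arithmetic; the variable names are chosen for this example.

```python
# Total validation-set inference times copied from the table above.
time_without_ge = 1189.4  # seconds, Transformer-XL w/o GE
time_with_ge = 692.2      # seconds, Transformer-XL w GE

speedup = time_without_ge / time_with_ge
print(f"relative speed-up: x{speedup:.3f}")  # relative speed-up: x1.718
```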

Sampled Audio

Five note sequences were sampled under the following conditions, with some meta data shared and some differing per instrument, and then mixed together.

  • Shared meta data across the 5 samples

    • audio_key : aminor
    • chord_progressions : [['Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'Dm', 'Dm', 'Dm', 'Dm', 'Dm', 'Dm', 'Dm', 'Dm', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'Am', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'G', 'D', 'D', 'D', 'D', 'D', 'D', 'D', 'D']]
    • time_signature : 4/4
    • genre : cinematic
    • bpm : 120
    • rhythm : standard
  • Different meta data for each instrument

    • riff string_violin mid_high standard
    • main_melody string_ensemble mid_high
    • sub_melody string_cello very_low
    • pad acoustic_piano mid_low
    • sub_melody brass_ensemble mid
sample.mov

2. Soft Labeling - Link

Techniques such as soft labeling are often used to prevent a model from overfitting. We apply soft labeling to the velocity, duration, and position information so that it can be predicted more flexibly. For example, if the target token value is 300, the logit is reconstructed by also referring to the logit values of the neighboring tokens 298, 299, 301, and 302. With soft labeling, we confirm that the tokens are predicted more flexibly than with the baseline.

[Figure: soft labeling]
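One common way to realize this idea is to replace the one-hot target with a small distribution over the neighboring tokens and train with cross-entropy against it. The sketch below follows that interpretation and is not taken from `Soft_Labeling/`: the function names, the default weights (one of the rows in the NLL table below), the vocabulary size, and the renormalization at the vocabulary boundary are all assumptions.

```python
import math
from typing import List, Sequence


def soft_target(target: int, vocab_size: int,
                weights: Sequence[float] = (0.1, 0.1, 0.6, 0.1, 0.1)) -> List[float]:
    """Spread the target probability over tokens target-2 .. target+2.

    `weights` must have odd length. Weight that would fall outside the
    vocabulary is dropped and the rest is renormalized (assumed behavior).
    """
    if not 0 <= target < vocab_size:
        raise ValueError("target must be a valid token index")
    half = len(weights) // 2
    dist = [0.0] * vocab_size
    for offset, w in zip(range(-half, half + 1), weights):
        idx = target + offset
        if 0 <= idx < vocab_size:
            dist[idx] += w
    total = sum(dist)
    return [p / total for p in dist]


def soft_cross_entropy(log_probs: Sequence[float], target_dist: Sequence[float]) -> float:
    """Cross-entropy between a soft target and predicted log-probabilities."""
    return -sum(t * lp for t, lp in zip(target_dist, log_probs) if t > 0.0)


vocab_size = 512  # hypothetical vocabulary size
dist = soft_target(300, vocab_size)
uniform_log_probs = [math.log(1.0 / vocab_size)] * vocab_size
loss = soft_cross_entropy(uniform_log_probs, dist)
print(dist[298:303], loss)
```

Setting `weights=(0, 0, 1, 0, 0)` recovers the usual hard (one-hot) label.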

Evaluation

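Test set NLL

| n-2 | n-1 | n | n+1 | n+2 | Test NLL |
| --- | --- | --- | --- | --- | --- |
| 0 | 0 | 1 | 0 | 0 | 0.96 |
| 0.1 | 0.1 | 0.6 | 0.1 | 0.1 | 1.01 |
| 0 | 0.15 | 0.7 | 0.15 | 0 | 1.05 |
| 0.1 | 0.2 | 0.4 | 0.2 | 0.1 | 1.26 |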

Classification Accuracy Score

| Label (Meta) | Real Model | Fake Model | Error Rate |
| --- | --- | --- | --- |
| BPM | 0.6291 | 0.6133 | 0.0251 |
| KEY | 0.8781 | 0.8741 | 0.0046 |
| TIMESIGNATURE | 0.9082 | 0.8990 | |
| PITCHRANGE | 0.7483 | 0.7195 | 0.0385 |
| NUMBEROFMEASURE | 1.0 | 1.0 | |
| INSTRUMENT | 0.5858 | 0.5740 | |
| GENRE | 0.8532 | 0.8440 | 0.0108 |
| MINVELOCITY | 0.4718 | 0.4429 | |
| MAXVELOCITY | 0.4718 | 0.4429 | |
| TRACKROLE | 0.6500 | 0.5661 | 0.1291 |
| RHYTHM | 0.9934 | 0.9934 | |

Normalized Mean CADS : 0.0416

Controllability and Diversity

| Model | CP | CV (Midi) | CV (Note) | CH | Diversity |
| --- | --- | --- | --- | --- | --- |
| Transformer XL w/o SL | 0.8585 | 0.8060 | 0.9847 | 0.9891 | 0.4100 |
| Transformer XL w SL | 0.8807 | 0.8007 | 0.9861 | 0.9891 | 0.4134 |

3. Gated Transformer-XL - Link


Dataset 🎶

ComMU (POZAlabs)

ComMU-code has clear instructions on how to download and postprocess ComMU-dataset, but we also provide a postprocessed dataset for simplicity. To download preprocessed and postprocessed data, run

cd ./dataset && ./download.sh && cd ..

Metrics 📋

To evaluate generation models, we have to generate data with the trained models, and the generation process differs depending on which metrics we want to use. Please refer to the explanations below to generate the samples needed for each evaluation.

Generating Samples for Evaluation

For CAS, we generate samples based on the training meta data; for Diversity & Controllability, we generate samples based on the validation meta data.

For Transformer-XL with GE, use

Improving-TrXL-for-ComMU/
├─ Group_Encoding/
    ├─ generate_GE.py

For the Transformer-XL baseline and SL, use

Improving-TrXL-for-ComMU/
├─ generate_SL.py

Note that generate_SL.py should be placed inside ComMU-code, since SL does not change the model structure or inference mechanism.

Classification Accuracy Score - Link

Evaluating generative models is an open problem, and for music generation it has not been well defined. Inspired by 'Classification Accuracy Score for Conditional Generative Models', we use CAS as an evaluation metric for our music generation models. Our CAS procedure is as follows:

  1. Train a Music Generation Model with ComMU train set which we call 'Real Dataset'
  2. Generate samples and form a dataset which we call 'Fake Dataset'
  3. Train a classification model with 'Real Dataset' which we call 'Real Model'
  4. Train a classification model with 'Fake Dataset' which we call 'Fake Model'
  5. For each label (meta data), we compare the performance of 'Fake Model' and 'Real Model' on the ComMU validation set
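Steps 3 to 5 can be sketched in a few lines. The example below is a toy illustration only: it uses a 1-D nearest-centroid classifier and made-up data as stand-ins, whereas the repository trains its classifiers with `evaluate_resnet.py`. All function names and values here are hypothetical.

```python
from typing import Callable, Dict, Sequence, Tuple

Predictor = Callable[[float], int]
Dataset = Tuple[Sequence[float], Sequence[int]]  # (features, labels)


def train_nearest_centroid(xs: Sequence[float], ys: Sequence[int]) -> Predictor:
    """Toy classifier: predict the label whose mean feature value is closest."""
    sums: Dict[int, float] = {}
    counts: Dict[int, int] = {}
    for x, y in zip(xs, ys):
        sums[y] = sums.get(y, 0.0) + x
        counts[y] = counts.get(y, 0) + 1
    centroids = {label: sums[label] / counts[label] for label in sums}
    return lambda x: min(centroids, key=lambda label: abs(x - centroids[label]))


def accuracy(predict: Predictor, xs: Sequence[float], ys: Sequence[int]) -> float:
    correct = sum(1 for x, y in zip(xs, ys) if predict(x) == y)
    return correct / len(ys)


def classification_accuracy_scores(real: Dataset, fake: Dataset,
                                   val: Dataset) -> Tuple[float, float]:
    real_model = train_nearest_centroid(*real)  # step 3: train on 'Real Dataset'
    fake_model = train_nearest_centroid(*fake)  # step 4: train on 'Fake Dataset'
    # step 5: evaluate both models on the same validation set
    return accuracy(real_model, *val), accuracy(fake_model, *val)


# Made-up data: the fake set places class 1 too close to class 0.
real = ([0.0, 0.2, 1.0, 1.2], [0, 0, 1, 1])
fake = ([0.0, 0.2, 0.4, 0.6], [0, 0, 1, 1])
val = ([0.1, 0.4, 0.9, 1.1], [0, 0, 1, 1])
print(classification_accuracy_scores(real, fake, val))  # (1.0, 0.75)
```

The gap between the two accuracies is what signals how well the generated data matches the real distribution for that label.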

From the above procedure we can obtain the CAS for any label (meta) we want to evaluate. If the difference between the accuracy of the 'Fake Model' and the 'Real Model' is small, our generation model has captured the data distribution with respect to that label well. For our experiments on vanilla Transformer-XL, Transformer-XL with GE, and Transformer-XL with SL, we calculate CAS on all 11 labels. However, some labels, such as Number of Measure, Time Signature, and Rhythm, are unsuited for evaluation. Therefore, we select BPM, KEY, PITCH RANGE, GENRE, and TRACK-ROLE and calculate the Normalized Mean Classification Accuracy Difference Score, denoted CADS. We obtain CADS as follows.

$$\mathrm{CADS} = \frac{1}{N}\sum_{i=1}^{N}\frac{|R_i - F_i|}{R_i}$$

where N is the number of labels (meta) that we consider relevant, in this case 5, and $R_i$ and $F_i$ denote the Real Model accuracy and the Fake Model accuracy for label $i$, respectively.
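As a check on the reported numbers, the snippet below recomputes the baseline score from the accuracies in the baseline table. It assumes each label contributes |R_i - F_i| / R_i, which reproduces the per-label error rates listed in the tables; the function name `normalized_mean_cads` is chosen for this example.

```python
from typing import Dict, Tuple

# (Real Model accuracy, Fake Model accuracy) for the five selected labels,
# copied from the baseline (Transformer-XL) table in this README.
baseline: Dict[str, Tuple[float, float]] = {
    "BPM": (0.6291, 0.6159),
    "KEY": (0.8781, 0.8781),
    "PITCHRANGE": (0.7483, 0.7090),
    "GENRE": (0.8532, 0.8427),
    "TRACKROLE": (0.6500, 0.5753),
}


def normalized_mean_cads(scores: Dict[str, Tuple[float, float]]) -> float:
    """Mean over labels of |R_i - F_i| / R_i (assumed form of the score)."""
    return sum(abs(r - f) / r for r, f in scores.values()) / len(scores)


print(round(normalized_mean_cads(baseline), 4))  # 0.0401
```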

The following figure shows the overall pipeline of CAS

  • To compute the Classification Accuracy Score of generated music conditioned on certain meta data

to generate samples for SL and baseline, run

$ python generate_SL.py --checkpoint_dir {./model_checkpoint} --meta_data {./train_meta_data.csv} --eval_diversity {False} --out_dir {./train_out}

for GE, run

$ python generate_GE.py --checkpoint_dir {./checkpoint_best.pt} --meta_data {./train_meta_data.csv} --eval_diversity {False} --out_dir {./train_out}

to compute CAS for certain meta data as label, run

$ python evaluate_resnet.py --midi_dir {./data.npy} --meta_dir {./meta.npy} --meta_num {meta_num}

to compute CAS for all meta data as label, run

$ python evaluate_resnet_all.py --midi_dir {./data.npy} --meta_dir {./meta.npy}

Diversity & Controllability

  • To compute the Diversity of Generated Music conditioned with certain meta data

to generate samples for SL and baseline, run

$ python generate_SL.py --checkpoint_dir {./model_checkpoint} --meta_data {./val_meta_data.csv} --eval_diversity {True} --out_dir {./val_out}

for GE, run

$ python generate_GE.py --checkpoint_dir {./checkpoint_best.pt} --meta_data {./val_meta_data.csv} --eval_diversity {True} --out_dir {./val_out}

First, modify eval_config.py. Then,

to compute Diversity, run

$ python ./commu_eval/commu_eval_diversity.py

to compute Controllability, run

$ python ./commu_eval/commu_eval_controllability.py

Citations

@misc{dai2019transformerxl,
      title={Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context}, 
      author={Zihang Dai and Zhilin Yang and Yiming Yang and Jaime Carbonell and Quoc V. Le and Ruslan Salakhutdinov},
      year={2019},
      eprint={1901.02860},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@misc{https://doi.org/10.48550/arxiv.1905.10887,
  doi = {10.48550/ARXIV.1905.10887},
  url = {https://arxiv.org/abs/1905.10887},
  author = {Ravuri, Suman and Vinyals, Oriol},
  keywords = {Machine Learning (cs.LG), Machine Learning (stat.ML), FOS: Computer and information sciences, FOS: Computer and information sciences},
  title = {Classification Accuracy Score for Conditional Generative Models},
  publisher = {arXiv},
  year = {2019},
  copyright = {arXiv.org perpetual, non-exclusive license}
}
@inproceedings{hyun2022commu,
  title={Com{MU}: Dataset for Combinatorial Music Generation},
  author={Lee Hyun and Taehyun Kim and Hyolim Kang and Minjoo Ki and Hyeonchan Hwang and Kwanho Park and Sharang Han and Seon Joo Kim},
  booktitle={Thirty-sixth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
  year={2022},
}
