This repository is the official implementation of our paper Compositional Generalization by Learning Analytical Expressions.
If you find our code useful, please consider citing our paper:
@inproceedings{qian2020compositional,
title={Compositional Generalization by Learning Analytical Expressions},
author={Qian Liu and Shengnan An and Jian{-}Guang Lou and Bei Chen and Zeqi Lin and Yan Gao and Bin Zhou and Nanning Zheng and Dongmei Zhang},
booktitle={Advances in Neural Information Processing Systems 33: Annual Conference on Neural Information Processing Systems 2020, NeurIPS 2020, December 6-12, 2020, virtual},
year={2020}
}
- Install Requirements
- Train Model
- Evaluate Model
- Pre-trained Models
- Expected Results
- Frequently Asked Questions
Our code officially supports Python 3.7. The main dependencies are pytorch and tensorboardX.
You can install all requirements with the following command:
❱❱❱ pip install -r requirements.txt
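After installing, a quick check like the one below can confirm that the environment matches what this README describes. It is only a sketch: the helper name missing_dependencies is hypothetical and not part of this repository, and it assumes the two main dependencies are importable under the module names torch and tensorboardX.

```python
import importlib.util
import sys

# Assumed import names for the two main dependencies named in this README.
MAIN_DEPENDENCIES = ["torch", "tensorboardX"]


def missing_dependencies(module_names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # This README states that Python 3.7 is the officially supported version.
    if sys.version_info[:2] != (3, 7):
        print("Warning: running Python %d.%d, but 3.7 is the supported version."
              % sys.version_info[:2])
    missing = missing_dependencies(MAIN_DEPENDENCIES)
    if missing:
        print("Missing dependencies:", ", ".join(missing))
    else:
        print("All main dependencies were found.")
```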
To train our model on a task from the SCAN or SCAN-ext datasets, use this command:
❱❱❱ python main.py --mode train --checkpoint <model_dir> --task <task_name>
📋 Note that <model_dir> specifies the folder where model checkpoints are stored, and <task_name> is the task name.
Available task names are [simple, addjump, around_right, length, mcd1, mcd2, mcd3, extend].
For example, you could train a model on the addjump task with the following command:
❱❱❱ python main.py --mode train --checkpoint addjump_model --task addjump
📋 Since reinforcement learning is known to be hard to train, there is a chance that training will not converge. If that happens, choose another random seed and try again.
📋 Meanwhile, please note that training is sensitive to the coefficient of the simplicity-based reward (i.e. --simplicity-ratio in args). When it is higher (e.g. 0.5 or 1.0), the model is harder to converge, so the training accuracy may not reach 100%. We are still investigating the reason behind this. If you cannot obtain good results after trying several random seeds, you could try to reproduce the other results using a simplicity-ratio of 0 (the current default setting). This is not suitable for around_right and mcd3, as stated in the paper. We will update the code when we find a better training strategy.
Therefore, please use a command like the following for the around_right and mcd3 tasks (shown here for around_right):
❱❱❱ python main.py --mode train --checkpoint addjump_model --task around_right --simplicity-ratio 0.5
The corresponding log and model weights will be stored in checkpoint/logs/addjump_model.log and checkpoint/models/addjump_model/*.mdl, respectively.
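The training instructions above can be collected into a small helper that builds the command line for each task. This is a minimal sketch, not part of the repository: build_train_command is a hypothetical name, the default checkpoint name <task>_model is an assumption modeled on the addjump_model example, and only the flags shown in this README are used.

```python
# Task names listed in this README.
TASKS = ["simple", "addjump", "around_right", "length",
         "mcd1", "mcd2", "mcd3", "extend"]

# Tasks for which this README asks for --simplicity-ratio 0.5.
NEEDS_SIMPLICITY_REWARD = {"around_right", "mcd3"}


def build_train_command(task, model_dir=None):
    """Return the training command for `task` as a list of arguments."""
    if task not in TASKS:
        raise ValueError("unknown task: %r" % task)
    if model_dir is None:
        # Assumption: name the checkpoint folder after the task.
        model_dir = task + "_model"
    command = ["python", "main.py", "--mode", "train",
               "--checkpoint", model_dir, "--task", task]
    if task in NEEDS_SIMPLICITY_REWARD:
        command += ["--simplicity-ratio", "0.5"]
    return command


if __name__ == "__main__":
    # Print one command per task; nothing is executed here.
    for task in TASKS:
        print(" ".join(build_train_command(task)))
```

The commands are returned as argument lists so they can be passed to subprocess.run directly, or joined with spaces and pasted into a shell.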
To evaluate our model on different tasks, run:
❱❱❱ python main.py --mode test --checkpoint <model_weight_file> --task <task_name>
📋 Note that <model_weight_file> specifies a concrete model file with the suffix .mdl, and <task_name> is the task name.
For example, you could evaluate a trained model weight file weight.mdl on the addjump task with the following command:
❱❱❱ python main.py --mode test --checkpoint weight.mdl --task addjump
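Since training stores weights under checkpoint/models/<model_dir>/*.mdl, the evaluation command can be generated for every saved weight file. The sketch below is illustrative only: find_weight_files and build_test_command are hypothetical helpers, and it assumes that --checkpoint accepts the path to the .mdl file as found on disk.

```python
from pathlib import Path


def find_weight_files(model_dir, root="checkpoint/models"):
    """Return the sorted .mdl files stored for `model_dir` (empty if none exist)."""
    return sorted(str(path) for path in Path(root, model_dir).glob("*.mdl"))


def build_test_command(weight_file, task):
    """Return the evaluation command for one weight file as a list of arguments."""
    if not weight_file.endswith(".mdl"):
        raise ValueError("expected a .mdl weight file, got %r" % weight_file)
    return ["python", "main.py", "--mode", "test",
            "--checkpoint", weight_file, "--task", task]


if __name__ == "__main__":
    # Print one evaluation command per saved weight file; nothing is executed.
    for weight_file in find_weight_files("addjump_model"):
        print(" ".join(build_test_command(weight_file, "addjump")))
```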
You can find pretrained model weights for the above tasks under the pretrained_weights folder.
Our model is expected to achieve 100% accuracy on all tasks if the training succeeds.