This repo contains preliminary code for the EMNLP 2020 (Findings) paper "Group-wise Contrastive Learning for Neural Dialogue Generation".
This codebase is built upon the ParlAI project (thanks to its developers for their pioneering contributions to such a great conversational platform!).
Check parlai/agents/contrastive_learning for framework implementations.
Running scripts can be found in projects/contrastive_learning.
- Python 3
- PyTorch 1.2 or newer
Dependencies of the core modules are listed in requirement.txt.
git clone git@github.com:hengyicai/ContrastiveLearning4Dialogue.git ~/ContrastiveLearning4Dialogue
cd ~/ContrastiveLearning4Dialogue; python setup.py develop
echo "export PARLAI_HOME=~/ContrastiveLearning4Dialogue" >> ~/.bashrc; source ~/.bashrc
Download the PersonaChat/OpenSubtitles/Douban datasets and untar them into ${PARLAI_HOME}/data/ as follows:
data
├── DoubanConversaionCorpus
│   ├── douban.embed.vec
│   ├── test.txt
│   ├── train.txt
│   ├── train.txt.lengths
│   └── valid.txt
├── OpenSubExtend
│   ├── opensub_extend.embed.vec
│   ├── test.txt
│   ├── train.txt
│   ├── train.txt.lengths
│   └── valid.txt
└── PersonaChatExtend
    ├── personachat_extend.embed.vec
    ├── test.txt
    ├── train.txt
    ├── train.txt.lengths
    └── valid.txt
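Before training, it can help to confirm that every file in the tree above is in place. The following is a minimal sketch; check_data_layout is a hypothetical helper (not part of the repository) that takes the data directory as its first argument and otherwise falls back to ${PARLAI_HOME}/data. The directory and file names are copied verbatim from the tree, including the spelling of DoubanConversaionCorpus.

```bash
# Hypothetical helper (not part of the repository): check that every file in the
# data tree exists. Takes the data directory as $1, defaulting to
# ${PARLAI_HOME}/data. Prints each missing file to stderr; returns 1 if any is missing.
check_data_layout() {
  local data_dir="${1:-${PARLAI_HOME}/data}"
  local missing=0 entry corpus prefix f

  # "directory:embedding-prefix" pairs, copied from the tree.
  for entry in DoubanConversaionCorpus:douban \
               OpenSubExtend:opensub_extend \
               PersonaChatExtend:personachat_extend; do
    corpus="${entry%%:*}"
    prefix="${entry##*:}"
    for f in "${prefix}.embed.vec" test.txt train.txt train.txt.lengths valid.txt; do
      if [ ! -f "${data_dir}/${corpus}/${f}" ]; then
        echo "missing: ${data_dir}/${corpus}/${f}" >&2
        missing=1
      fi
    done
  done
  return "$missing"
}
```

For example, check_data_layout && echo "data layout OK" checks ${PARLAI_HOME}/data and lists any missing file before a long training run starts.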
cd ~/ContrastiveLearning4Dialogue
bash projects/contrastive_learning/shell/run.sh
The last line of projects/contrastive_learning/shell/run.sh specifies preliminary arguments for training:
# MODEL_NAME TO_MINIMIZE TASK PRETRAIN_STEPS SAMPLE_K CONTRAST_BY NAIVE_NEG_SAMPLING CL_THRESHOLD CL_ANNEAL ANNEAL_SPEED
export CUDA_VISIBLE_DEVICES=0; train_model cl_seq2seq to_minimize personachat_extend 5000 6 both False 0.5 True 1.0
See projects/contrastive_learning/shell/run.sh for details.
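The ten positional arguments of train_model follow the order given in the comment header. As a minimal sketch for checking which value lands in which slot, the hypothetical helper below (not defined in run.sh) pairs each value with its name from that header; it does not interpret or validate the values.

```bash
# Hypothetical helper (not part of run.sh): print each positional argument of a
# train_model call next to its name from the comment header. It only echoes values;
# arguments beyond the ten names are ignored, and absent ones print as <missing>.
describe_train_args() {
  local names="MODEL_NAME TO_MINIMIZE TASK PRETRAIN_STEPS SAMPLE_K CONTRAST_BY NAIVE_NEG_SAMPLING CL_THRESHOLD CL_ANNEAL ANNEAL_SPEED"
  local name
  for name in $names; do  # unquoted on purpose: split the header into words
    if [ "$#" -eq 0 ]; then
      echo "${name}=<missing>"
    else
      echo "${name}=$1"
      shift
    fi
  done
}

describe_train_args cl_seq2seq to_minimize personachat_extend 5000 6 both False 0.5 True 1.0
```

For the call above, the output starts with MODEL_NAME=cl_seq2seq and TO_MINIMIZE=to_minimize and ends with CL_ANNEAL=True and ANNEAL_SPEED=1.0.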
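Since the contrastive learning framework involves an auxiliary model during training, i.e., the reference model, a reference model has to be trained first. For example, to train a seq2seq reference model on personachat_extend, set the last line of run.sh to: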
# MODEL_NAME TO_MINIMIZE TASK PRETRAIN_STEPS SAMPLE_K CONTRAST_BY NAIVE_NEG_SAMPLING CL_THRESHOLD CL_ANNEAL ANNEAL_SPEED
export CUDA_VISIBLE_DEVICES=0; train_model seq2seq ppl personachat_extend 5000 6 both False 0.5 True 1.0
Several arguments must then be declared explicitly in projects/contrastive_learning/shell/run.sh.
Enter the path of the trained reference model here:
declare -A ref_model_files=(
["none"]=None
["REF_MODEL_KEY"]="PATH/TO/THE/REFERENCE/MODEL"
)
and use it by setting the variable ref_model:
ref_model=REF_MODEL_KEY
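In bash, ref_model_files is an associative array, so ${ref_model_files[$ref_model]} yields the path selected by ref_model. The sketch below assumes run.sh performs a lookup of roughly this kind; resolve_ref_model is a hypothetical helper, not code from the repository, and it additionally reports keys that were never declared. It requires bash 4 or newer for declare -A.

```bash
# Minimal sketch (bash 4+). The array and variable mirror run.sh; resolve_ref_model
# is hypothetical and only illustrates how a key resolves to a path.
declare -A ref_model_files=(
  ["none"]=None
  ["REF_MODEL_KEY"]="PATH/TO/THE/REFERENCE/MODEL"
)
ref_model=REF_MODEL_KEY

resolve_ref_model() {
  local key="$1"
  # "+set" expands to "set" only if the key was declared, even with an empty value.
  if [[ -z "${ref_model_files[$key]+set}" ]]; then
    echo "unknown ref_model key: ${key}" >&2
    return 1
  fi
  echo "${ref_model_files[$key]}"
}

resolve_ref_model "$ref_model"   # prints PATH/TO/THE/REFERENCE/MODEL
```

Checking for undeclared keys catches a mistyped ref_model value early, instead of silently passing an empty path to training.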
Apply the contrastive learning framework to seq2seq (or to the transformer, by replacing cl_seq2seq with cl_transformer):
# MODEL_NAME TO_MINIMIZE TASK PRETRAIN_STEPS SAMPLE_K CONTRAST_BY NAIVE_NEG_SAMPLING CL_THRESHOLD CL_ANNEAL ANNEAL_SPEED
export CUDA_VISIBLE_DEVICES=0; train_model cl_seq2seq to_minimize personachat_extend 5000 6 both False 0.5 True 1.0
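The swap from cl_seq2seq to cl_transformer can also be scripted. This is a minimal sketch, assuming the train_model line in run.sh looks like the one above; switch_to_transformer is a hypothetical helper, not part of the repository, and it saves a backup before rewriting the file.

```bash
# Minimal sketch: switch the contrastive model from cl_seq2seq to cl_transformer
# in run.sh. switch_to_transformer is hypothetical, not part of the repository.
# It saves a backup as <script>.bak, then rewrites every "train_model cl_seq2seq "
# occurrence; lines that train a plain seq2seq reference model are left unchanged.
switch_to_transformer() {
  local script="${1:-projects/contrastive_learning/shell/run.sh}"
  cp "$script" "${script}.bak" || return 1
  # Write through the backup instead of `sed -i`, whose syntax differs between GNU and BSD sed.
  sed 's/train_model cl_seq2seq /train_model cl_transformer /' "${script}.bak" > "$script"
}
```

Run it from ~/ContrastiveLearning4Dialogue with no argument to edit run.sh at its default path, or pass another path to edit a copy; restore the original by moving run.sh.bak back.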
Start training by running bash projects/contrastive_learning/shell/run.sh.
Please contact me by email (caihengyi at ict dot ac dot cn) if anything is unclear.