使用的数据集为THUCNews,为了加速实验,使用了别人提供的一个子集, 文本涉及10个类别:categories = ['体育', '财经', '房产', '家居', '教育', '科技', '时尚', '时政', '游戏', '娱乐']
cnews.train.txt: 训练集(500010)
cnews.val.txt: 验证集(50010)
cnews.test.txt: 测试集(1000*10)
将cnews.train.txt, cnews.val.txt, cnews.test.txt放到data目录下
下载链接: https://pan.baidu.com/s/1DOgxlY42roBpOKAMKPPKWA,密码: up9d
会在data目录下生成vocab.txt文件
python dataset.py --dataset_dir data
训练结束后模型会保存在models目录下
python main.py --attack_mode none --dataset_dir data
python main.py --attack_mode pgd --dataset_dir data
python main.py --attack_mode free --dataset_dir data
python main.py --attack_mode fgsm --dataset_dir data
python main.py --test --attack_mode none --dataset_dir data
python main.py --test --attack_mode pgd --dataset_dir data
python main.py --test --attack_mode free --dataset_dir data
python main.py --test --attack_mode fgsm --dataset_dir data
绘制不同算法的训练过程,在验证集上的准确率
python plot.py
precision | recall | f1 | |
---|---|---|---|
baseline | 0.9531 | 0.9520 | 0.9517 |
PGD | 0.9554 | 0.9551 | 0.9548 |
Free | 0.9504 | 0.9501 | 0.9496 |
FGSM | 0.9567 | 0.9563 | 0.9560 |
在单卡GPU-1070上训练20个epoch
时长(分) | |
---|---|
baseline | 5.53 |
PGD | 30.40 |
Free | 5.53 |
FGSM | 10.57 |
https://github.com/cjymz886/fast_adversarial_for_text_classification
https://github.com/locuslab/fast_adversarial