diff --git a/README.md b/README.md index 41d984a..49cb302 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,37 @@ -# BDG(BERT-based Distractor Generation) +# BDG(Distractor Generation) Code for "A BERT-based Distractor Generation Scheme with Multi-tasking and Negative Answer Training Strategies." +[Paper](https://www.aclweb.org/anthology/2020.findings-emnlp.393/) + +## V2 +Updated results using BART. The BART models are uploaded to the HuggingFace model hub. +| model | BLEU1 | BLEU2 | BLEU3 | BLEU4 | ROUGEL | +|---------------|-------|-------|-------|-------|--------| +| Bert DG | 35.30 | 20.65 | 13.66 | 9.53 | 31.11 | +| Bert DG pm | 39.81 | 24.81 | 17.66 | 13.56 | 34.01 | +| Bert DG an+pm | 39.52 | 24.29 | 17.28 | 13.28 | 33.40 | +| Bart DG | 40.76 | 26.40 | 19.14 | 14.65 | 35.53 | +| Bart DG pm | 41.85 | 27.45 | 20.47 | 16.33 | 37.15 | +| Bart DG an+pm | 40.26 | 25.86 | 18.85 | 14.65 | 35.64 | +* higher is better + +| model | Count BLEU1 > 0.95 | +|---------------|--------------------| +| Bert DG | 115 | +| Bert DG pm | 57 | +| Bert DG an+pm | 43 | +| Bart DG | 110 | +| Bart DG pm | 60 | +| Bart DG an+pm | 23 | +| Gold | 12 | +* lower is better ## Trained Model and Code Example +### BART +- Distractor: https://huggingface.co/voidful/bart-distractor-generation +- Distractor PM: https://huggingface.co/voidful/bart-distractor-generation-pm +- Distractor AN+PM: https://huggingface.co/voidful/bart-distractor-generation-both + +### BERT Trained model available on release: https://github.com/voidful/BDG/releases/tag/v1.0 @@ -39,8 +69,18 @@ Download dataset [here](https://github.com/Yifan-Gao/Distractor-Generation-RACE) run `convert_data.py` to do preprocessing. run `dataset_stat.py` for dataset statistics. 
-## Train BERT-based Distractor Generator -run the following in main dir: +## Train Distractor Generator +### Bart +Requires tfkit==0.7.0 and transformers==4.4.2: +```bash +tfkit-train --savedir ./race_cqa_gen_d_bart/ --train ./race_train_updated_cqa_dsep_a_bart.csv --test ./race_test_updated_cqa_dsep_a_bart.csv --model seq2seq --config facebook/bart-base --batch 9 --epoch 10 --grad_accum 2 --no_eval; +tfkit-train --savedir ./race_cqa_gen_d_bart_pm/ --train ./race_train_updated_cqa_dsep_a_bart.csv --test ./race_test_updated_cqa_dsep_a_bart.csv --model seq2seq --config facebook/bart-base --batch 9 --epoch 10 --grad_accum 2 --no_eval --likelihood pos; +tfkit-train --savedir ./race_cqa_gen_d_bart_both/ --train ./race_train_updated_cqa_dsep_a_bart.csv --test ./race_test_updated_cqa_dsep_a_bart.csv --model seq2seq --config facebook/bart-base --batch 9 --epoch 10 --grad_accum 2 --no_eval --likelihood both; +``` + +### Bert +Using the environment from `requirement.txt`, +run the following in the main dir: ### Train BDG Model ```bash tfkit-train --maxlen 512 --savedir ./race_cqa_gen_d/ --train ./data_preprocessing/processed_data/race_train_updated_cqa_dsep_a.csv --test ./data_preprocessing/processed_data/race_test_updated_cqa_dsep_a.csv --model onebyone --tensorboard --config bert-base-cased --batch 30 --epoch 6; diff --git a/convert_tfkit_bart_to_hf_model.ipynb b/convert_tfkit_bart_to_hf_model.ipynb new file mode 100644 index 0000000..fecec62 --- /dev/null +++ b/convert_tfkit_bart_to_hf_model.ipynb @@ -0,0 +1,670 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# !pip install nlp2go transformers tfkit -U" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "source_path = './tfkit-model-path'\n", + "target_path = \"./bart-distractor-generation/\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + 
"scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "===model info===\n", + "model_config : facebook/bart-base\n", + "tags : ['seq2seq_0']\n", + "type : ['seq2seq']\n", + "maxlen : 1024\n", + "epoch : 8\n", + "==========\n", + "loading saved model\n", + "Using device: cuda\n", + "finish loading\n" + ] + }, + { + "data": { + "text/plain": [ + "Model(\n", + " (pretrained): BartModel(\n", + " (shared): Embedding(50265, 768, padding_idx=1)\n", + " (encoder): BartEncoder(\n", + " (embed_tokens): Embedding(50265, 768, padding_idx=1)\n", + " (embed_positions): BartLearnedPositionalEmbedding(1026, 768)\n", + " (layers): ModuleList(\n", + " (0): BartEncoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (1): BartEncoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), 
eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (2): BartEncoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (3): BartEncoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (4): BartEncoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " 
(final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (5): BartEncoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (decoder): BartDecoder(\n", + " (embed_tokens): Embedding(50265, 768, padding_idx=1)\n", + " (embed_positions): BartLearnedPositionalEmbedding(1026, 768)\n", + " (layers): ModuleList(\n", + " (0): BartDecoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, 
bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (1): BartDecoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (2): BartDecoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): 
Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (3): BartDecoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (4): BartDecoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): BartAttention(\n", + " (k_proj): 
Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (5): BartDecoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (model): Linear(in_features=768, out_features=50265, bias=False)\n", + ")" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + 
"source": [ + "import nlp2go\n", + "tfkit_model = nlp2go.Model(source_path)\n", + "tfkit_model.model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "BartForConditionalGeneration(\n", + " (model): BartModel(\n", + " (shared): Embedding(50265, 768, padding_idx=1)\n", + " (encoder): BartEncoder(\n", + " (embed_tokens): Embedding(50265, 768, padding_idx=1)\n", + " (embed_positions): BartLearnedPositionalEmbedding(1026, 768)\n", + " (layers): ModuleList(\n", + " (0): BartEncoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (1): BartEncoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (2): BartEncoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): 
Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (3): BartEncoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (4): BartEncoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (5): BartEncoderLayer(\n", + " 
(self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (decoder): BartDecoder(\n", + " (embed_tokens): Embedding(50265, 768, padding_idx=1)\n", + " (embed_positions): BartLearnedPositionalEmbedding(1026, 768)\n", + " (layers): ModuleList(\n", + " (0): BartDecoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), 
eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (1): BartDecoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (2): BartDecoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, 
elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (3): BartDecoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (4): BartDecoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " 
(q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (5): BartDecoderLayer(\n", + " (self_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): BartAttention(\n", + " (k_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (v_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (q_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=768, out_features=3072, bias=True)\n", + " (fc2): Linear(in_features=3072, out_features=768, bias=True)\n", + " (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (lm_head): Linear(in_features=768, out_features=50265, bias=False)\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from transformers import AutoTokenizer, BartForConditionalGeneration\n", + "from transformers import pipeline\n", + 
"hf_model = BartForConditionalGeneration.from_pretrained('facebook/bart-base').to('cuda')\n", + "hf_model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "odict_keys(['model', 'lm_head'])\n", + "odict_keys(['shared', 'encoder', 'decoder'])\n" + ] + } + ], + "source": [ + "print(hf_model._modules.keys())\n", + "print(hf_model.model._modules.keys())" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "hf_model.lm_head = tfkit_model.model.model\n", + "hf_model.model = tfkit_model.model.pretrained" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "text = \"shyness is the cause of much unhappiness for a great many people . shy people are anxious and self - conscious ; that is , they are over concerned with their own appearance and actions . worrisome thoughts are constantly occurring in their minds : what kind of impression am i making ? do they like me ? do i sound stupid ? am i wearing unattractive clothes ? it is obvious that such uncomfortable feelings must affect people unfavorably . adperson ' s self concept is reflected in the way he or she behaves and the way a person behaves affects other people ' s reactions . in general , the way people think about themselves has a deep effect on all areas of their lives . shy people , have low self - esteem , are likely to be passive and easily influenced by others . they need reassurance ( , ) that they are doing \\\" the right thing \\\" . shy people are very sensitive to criticism . it makes them feel inferior . they also find it difficult to be pleased by praises because they believe they are unworthy of praise . a shy person may respond to a praise with a statement like this one : \\\" you \\' re just saying that to make me feel good . i know it ' s not true . 
\\\" it is clear that , while self - awareness is a healthy quality , overdoing it is harmful . can shyness be completely got rid of , or at least reduced ? fortunately , people can overcome shyness with determination . it is important for people to accept their weaknesses as well as their strengths , for example , not fair for them to label themselves inferior because they have to be realistic . living on the impossible leads to absence of inferiority . each one of us has his or her own characteristics . we are interested in our own personal ways . the better we understand ourselves . the easier it becomes to live up to our chances for a rich and fulfilling life . according to the wirter , self - awareness is a good characteristic\"" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. 
If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n" + ] + } + ], + "source": [ + "inputs = tfkit_model.model.tokenizer(text, max_length=1024, return_tensors='pt').to('cuda')" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "hf_model.config.tie_word_embeddings = False\n", + "hf_model.config.tie_encoder_decoder = False" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "input_ids = inputs['input_ids'].to('cuda')" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'a good quality'" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "generations = hf_model.generate(input_ids)\n", + "tfkit_model.model.tokenizer.decode(generations[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "('./bart-distractor-generation/tokenizer_config.json',\n", + " './bart-distractor-generation/special_tokens_map.json',\n", + " './bart-distractor-generation/vocab.json',\n", + " './bart-distractor-generation/merges.txt',\n", + " './bart-distractor-generation/added_tokens.json')" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hf_model.save_pretrained(target_path)\n", + "tfkit_model.model.tokenizer.save_pretrained(target_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'missing_keys': [], 'unexpected_keys': [], 'error_msgs': []}\n" + ] + }, + { + "data": { + "text/plain": [ + "'a good quality'" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": 
"execute_result" + } + ], + "source": [ + "hf_model,info = BartForConditionalGeneration.from_pretrained(target_path, output_loading_info=True)\n", + "print(info)\n", + "hf_model.to('cuda')\n", + "generations = hf_model.generate(inputs['input_ids'])\n", + "tfkit_model.model.tokenizer.decode(generations[0])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file