diff --git a/deeppavlov/configs/generative_qa/nq_fid.json b/deeppavlov/configs/generative_qa/nq_fid.json new file mode 100644 index 0000000000..1cbbaba7f1 --- /dev/null +++ b/deeppavlov/configs/generative_qa/nq_fid.json @@ -0,0 +1,97 @@ +{ + "dataset_reader": { + "class_name": "json_reader", + "data_path": "{DATASET_PATH}/natural_questions_dataset.json" + }, + "dataset_iterator": { + "class_name": "data_learning_iterator", + "seed": 42, + "shuffle": true + }, + "chainer": { + "in": ["question", "contexts", "titles"], + "in_y": ["target", "gold_answers"], + "pipe": [ + { + "class_name": "fid_input_preprocessor", + "vocab_file": "{TRANSFORMER}", + "max_seq_length": 200, + "in": ["question", "contexts"], + "out": ["input_ids", "attention_mask"] + }, + { + "class_name": "fid_target_preprocessor", + "vocab_file": "{TRANSFORMER}", + "answer_maxlength" : 50, + "in": ["target"], + "out": ["target_ids"] + }, + { + "class_name": "torch_generative_qa_fid", + "pretrained_transformer": "{TRANSFORMER}", + "save_path": "{MODEL_PATH}", + "load_path": "{MODEL_PATH}", + "optimizer": "AdamW", + "optimizer_parameters": { + "lr": 3e-04, + "weight_decay": 0.01, + "betas": [0.9, 0.999], + "eps": 1e-08 + }, + "learning_rate_drop_patience": 24, + "learning_rate_drop_div": 2, + "min_learning_rate": 1e-5, + "generate_max_length" : 50, + "in": ["input_ids", "attention_mask"], + "in_y": ["target_ids"], + "out": ["model_answer"] + } + ], + "out": ["model_answer"] + }, + "train": { + "show_examples": false, + "evaluation_targets": [ + "valid" + ], + "log_every_n_batches": 100, + "val_every_n_batches": 600, + "batch_size": 1, + "validation_patience": 100, + "metrics": [ + { + "name": "squad_v2_em", + "inputs": ["gold_answers", "model_answer"] + }, + { + "name": "squad_v2_f1", + "inputs": ["gold_answers", "model_answer"] + } + ], + "class_name": "torch_trainer" + }, + "metadata": { + "variables": { + "TRANSFORMER": "t5-base", + "ROOT_PATH": "~/.deeppavlov", + "DOWNLOADS_PATH": "{ROOT_PATH}/downloads", + "MODELS_PATH": "{ROOT_PATH}/models", + "MODEL_PATH": "{MODELS_PATH}/generative_qa/fusion_in_decoder/natural_questions", + "DATASET_PATH": "{DOWNLOADS_PATH}/natural_questions" + }, + "download": [ + { + "url": "http://files.deeppavlov.ai/deeppavlov_data/generative_qa/datasets/natural_questions/natural_questions_dataset.json", + "subdir": "{DATASET_PATH}" + }, + { + "url": "http://files.deeppavlov.ai/deeppavlov_data/generative_qa/models/fusion_in_decoder/natural_questions/config.json", + "subdir": "{MODEL_PATH}" + }, + { + "url": "http://files.deeppavlov.ai/deeppavlov_data/generative_qa/models/fusion_in_decoder/natural_questions/pytorch_model.bin", + "subdir": "{MODEL_PATH}" + } + ] + } +} \ No newline at end of file diff --git a/deeppavlov/configs/generative_qa/tqa_fid.json b/deeppavlov/configs/generative_qa/tqa_fid.json new file mode 100644 index 0000000000..c936307cc8 --- /dev/null +++ b/deeppavlov/configs/generative_qa/tqa_fid.json @@ -0,0 +1,97 @@ +{ + "dataset_reader": { + "class_name": "json_reader", + "data_path": "{DATASET_PATH}/trivia_qa_dataset.json" + }, + "dataset_iterator": { + "class_name": "data_learning_iterator", + "seed": 42, + "shuffle": true + }, + "chainer": { + "in": ["question", "contexts", "titles"], + "in_y": ["target", "gold_answers"], + "pipe": [ + { + "class_name": "fid_input_preprocessor", + "vocab_file": "{TRANSFORMER}", + "max_seq_length": 200, + "in": ["question", "contexts"], + "out": ["input_ids", "attention_mask"] + }, + { + "class_name": "fid_target_preprocessor", + "vocab_file": 
"{TRANSFORMER}", + "answer_maxlength" : 50, + "in": ["target"], + "out": ["target_ids"] + }, + { + "class_name": "torch_generative_qa_fid", + "pretrained_transformer": "{TRANSFORMER}", + "save_path": "{MODEL_PATH}", + "load_path": "{MODEL_PATH}", + "optimizer": "AdamW", + "optimizer_parameters": { + "lr": 3e-04, + "weight_decay": 0.01, + "betas": [0.9, 0.999], + "eps": 1e-08 + }, + "learning_rate_drop_patience": 24, + "learning_rate_drop_div": 2, + "min_learning_rate": 1e-5, + "generate_max_length" : 50, + "in": ["input_ids", "attention_mask"], + "in_y": ["target_ids"], + "out": ["model_answer"] + } + ], + "out": ["model_answer"] + }, + "train": { + "show_examples": false, + "evaluation_targets": [ + "valid" + ], + "log_every_n_batches": 100, + "val_every_n_batches": 600, + "batch_size": 1, + "validation_patience": 100, + "metrics": [ + { + "name": "squad_v2_em", + "inputs": ["gold_answers", "model_answer"] + }, + { + "name": "squad_v2_f1", + "inputs": ["gold_answers", "model_answer"] + } + ], + "class_name": "torch_trainer" + }, + "metadata": { + "variables": { + "TRANSFORMER": "t5-base", + "ROOT_PATH": "~/.deeppavlov", + "DOWNLOADS_PATH": "{ROOT_PATH}/downloads", + "MODELS_PATH": "{ROOT_PATH}/models", + "MODEL_PATH": "{MODELS_PATH}/generative_qa/fusion_in_decoder/trivia_qa", + "DATASET_PATH": "{DOWNLOADS_PATH}/trivia_qa" + }, + "download": [ + { + "url": "http://files.deeppavlov.ai/deeppavlov_data/generative_qa/datasets/trivia_qa/trivia_qa_dataset.json", + "subdir": "{DATASET_PATH}" + }, + { + "url": "http://files.deeppavlov.ai/deeppavlov_data/generative_qa/models/fusion_in_decoder/trivia_qa/config.json", + "subdir": "{MODEL_PATH}" + }, + { + "url": "http://files.deeppavlov.ai/deeppavlov_data/generative_qa/models/fusion_in_decoder/trivia_qa/pytorch_model.bin", + "subdir": "{MODEL_PATH}" + } + ] + } + } \ No newline at end of file diff --git a/deeppavlov/core/common/registry.json b/deeppavlov/core/common/registry.json index 42f0df484e..138d056c6f 100644 --- a/deeppavlov/core/common/registry.json +++ b/deeppavlov/core/common/registry.json @@ -16,11 +16,14 @@ "entity_linker": "deeppavlov.models.entity_extraction.entity_linking:EntityLinker", "faq_reader": "deeppavlov.dataset_readers.faq_reader:FaqDatasetReader", "fasttext": "deeppavlov.models.embedders.fasttext_embedder:FasttextEmbedder", + "fid_input_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:FiDInputPreprocessor", + "fid_target_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:FiDTargetPreprocessor", "fit_trainer": "deeppavlov.core.trainers.fit_trainer:FitTrainer", "hashing_tfidf_vectorizer": "deeppavlov.models.vectorizers.hashing_tfidf_vectorizer:HashingTfIdfVectorizer", "huggingface_dataset_iterator": "deeppavlov.dataset_iterators.huggingface_dataset_iterator:HuggingFaceDatasetIterator", "huggingface_dataset_reader": "deeppavlov.dataset_readers.huggingface_dataset_reader:HuggingFaceDatasetReader", "imdb_reader": "deeppavlov.dataset_readers.imdb_reader:ImdbReader", + "json_reader": "deeppavlov.dataset_readers.json_reader:JsonReader", "kenlm_elector": "deeppavlov.models.spelling_correction.electors.kenlm_elector:KenlmElector", "line_reader": "deeppavlov.dataset_readers.line_reader:LineReader", "logit_ranker": "deeppavlov.models.doc_retrieval.logit_ranker:LogitRanker", @@ -81,6 +84,7 @@ "top1_elector": "deeppavlov.models.spelling_correction.electors.top1_elector:TopOneElector", "torch_bert_ranker": "deeppavlov.models.torch_bert.torch_bert_ranker:TorchBertRankerModel", 
"torch_bert_ranker_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:TorchBertRankerPreprocessor", + "torch_generative_qa_fid": "deeppavlov.models.torch_bert.torch_generative_qa:TorchFiD", "torch_record_postprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:TorchRecordPostprocessor", "torch_squad_transformers_preprocessor": "deeppavlov.models.preprocessors.torch_transformers_preprocessor:TorchSquadTransformersPreprocessor", "torch_text_classification_model": "deeppavlov.models.classifiers.torch_classification_model:TorchTextClassificationModel", diff --git a/deeppavlov/core/common/requirements_registry.json b/deeppavlov/core/common/requirements_registry.json index d65eba771e..ac03a5e154 100644 --- a/deeppavlov/core/common/requirements_registry.json +++ b/deeppavlov/core/common/requirements_registry.json @@ -86,6 +86,10 @@ "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", "{DEEPPAVLOV_PATH}/requirements/transformers.txt" ], + "torch_generative_qa_fid": [ + "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", + "{DEEPPAVLOV_PATH}/requirements/transformers_3.0.2.txt" + ], "torch_record_postprocessor": [ "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", "{DEEPPAVLOV_PATH}/requirements/transformers.txt" @@ -113,6 +117,14 @@ "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", "{DEEPPAVLOV_PATH}/requirements/transformers.txt" ], + "fid_input_preprocessor": [ + "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", + "{DEEPPAVLOV_PATH}/requirements/transformers_3.0.2.txt" + ], + "fid_target_preprocessor": [ + "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", + "{DEEPPAVLOV_PATH}/requirements/transformers_3.0.2.txt" + ], "torch_transformers_multiplechoice": [ "{DEEPPAVLOV_PATH}/requirements/pytorch.txt", "{DEEPPAVLOV_PATH}/requirements/transformers.txt" diff --git a/deeppavlov/dataset_readers/json_reader.py b/deeppavlov/dataset_readers/json_reader.py new file mode 100644 index 0000000000..c38418e3f6 --- /dev/null +++ b/deeppavlov/dataset_readers/json_reader.py @@ -0,0 +1,30 @@ +# Copyright 2017 Neural Networks and Deep Learning lab, MIPT +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +from typing import Dict, Optional + +from deeppavlov.core.common.registry import register +from deeppavlov.core.data.dataset_reader import DatasetReader + +@register('json_reader') +class JsonReader(DatasetReader): + + def read(self, data_path: str, valid_size: Optional[int] = None) -> Dict: + with open(data_path, 'r') as f: + dataset = json.load(f) + if valid_size is not None: + dataset["valid"] = dataset["valid"][:valid_size] + + return dataset diff --git a/deeppavlov/metrics/bleu.py b/deeppavlov/metrics/bleu.py index 75bfec2b79..4fcc8fec67 100644 --- a/deeppavlov/metrics/bleu.py +++ b/deeppavlov/metrics/bleu.py @@ -20,6 +20,8 @@ from deeppavlov.core.common.metrics_registry import register_metric from deeppavlov.metrics.google_bleu import compute_bleu +import numpy as np + SMOOTH = SmoothingFunction() diff --git a/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py b/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py index 8bc2daec34..51cdd81edf 100644 --- a/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py +++ b/deeppavlov/models/preprocessors/torch_transformers_preprocessor.py @@ -421,6 +421,69 @@ def __call__(self, questions_batch: List[List[str]], rels_batch: List[List[str]] return input_features +@register('fid_input_preprocessor') +class FiDInputPreprocessor(Component): + def __init__(self, + vocab_file: str, + do_lower_case: bool = True, + max_seq_length: int = 512, + **kwargs) -> None: + self.max_seq_length = max_seq_length + + if Path(vocab_file).is_file(): + vocab_file = str(expand_path(vocab_file)) + self.tokenizer = AutoTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case) + else: + self.tokenizer = AutoTokenizer.from_pretrained(vocab_file, do_lower_case=do_lower_case) + + def __call__(self, questions_batch: List[str], contexts_batch: List[List[str]]): + prepare_data = lambda q, c,: f"question: {q} context: {c}" + passages_batch = [[prepare_data(question, context) for context in contexts] + for (question, contexts) in zip(questions_batch, contexts_batch)] + + passage_ids, passage_masks = [], [] + for text_passages in passages_batch: + passages_encoding = self.tokenizer( + text_passages, + max_length=self.max_seq_length if self.max_seq_length > 0 else None, + pad_to_max_length=True, + return_tensors='pt', + truncation=True if self.max_seq_length > 0 else False, + ) + passage_ids.append(passages_encoding['input_ids'][None]) + passage_masks.append(passages_encoding['attention_mask'][None]) + + passage_ids = torch.cat(passage_ids, dim=0) + passage_masks = torch.cat(passage_masks, dim=0) + + return passage_ids, passage_masks + +@register('fid_target_preprocessor') +class FiDTargetPreprocessor(Component): + def __init__(self, + vocab_file: str, + do_lower_case: bool = True, + answer_maxlength: int = 50, + **kwargs) -> None: + self.answer_maxlength = answer_maxlength + if Path(vocab_file).is_file(): + vocab_file = str(expand_path(vocab_file)) + self.tokenizer = AutoTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case) + else: + self.tokenizer = AutoTokenizer.from_pretrained(vocab_file, do_lower_case=do_lower_case) + + + def __call__(self, targets_batch: List[str]): + target_encoding = self.tokenizer( + targets_batch, + max_length=self.answer_maxlength if self.answer_maxlength > 0 else None, + pad_to_max_length=True, + return_tensors='pt', + truncation=True if self.answer_maxlength > 0 else False, + ) + target_ids = target_encoding["input_ids"] + return target_ids + 
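+# A minimal usage sketch for the two FiD preprocessors above (illustrative values
+# mirroring the generative_qa configs; this snippet is not used anywhere in the pipeline):
+#
+#     input_prep = FiDInputPreprocessor(vocab_file="t5-base", max_seq_length=200)
+#     input_ids, attention_mask = input_prep(
+#         ["What is the capital of UK?"],
+#         [["London is the capital of Great Britain",
+#           "The name Britain is sometimes used to refer to the United Kingdom as a whole"]])
+#     # -> LongTensors of shape (batch, n_passages, max_seq_length)
+#
+#     target_prep = FiDTargetPreprocessor(vocab_file="t5-base", answer_maxlength=50)
+#     target_ids = target_prep(["london"])  # -> (batch, answer_maxlength)
+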
@register('torch_transformers_ner_preprocessor') class TorchTransformersNerPreprocessor(Component): """ diff --git a/deeppavlov/models/torch_bert/fusion_in_decoder.py b/deeppavlov/models/torch_bert/fusion_in_decoder.py new file mode 100644 index 0000000000..7f17dcd6d3 --- /dev/null +++ b/deeppavlov/models/torch_bert/fusion_in_decoder.py @@ -0,0 +1,353 @@ +import types +import torch +import transformers +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss +import numpy as np + + +class FiDT5(transformers.T5ForConditionalGeneration): + def __init__(self, config): + super().__init__(config) + self.wrap_encoder() + + def forward_(self, **kwargs): + if 'input_ids' in kwargs: + kwargs['input_ids'] = kwargs['input_ids'].view(kwargs['input_ids'].size(0), -1) + if 'attention_mask' in kwargs: + kwargs['attention_mask'] = kwargs['attention_mask'].view(kwargs['attention_mask'].size(0), -1) + + return super(FiDT5, self).forward( + **kwargs + ) + + # We need to resize as B x (N * L) instead of (B * N) x L here + # because the T5 forward method uses the input tensors to infer + # dimensions used in the decoder. + # EncoderWrapper resizes the inputs as (B * N) x L. + def forward(self, input_ids=None, attention_mask=None, **kwargs): + if input_ids != None: + # inputs might have already be resized in the generate method + if input_ids.dim() == 3: + self.encoder.n_passages = input_ids.size(1) + input_ids = input_ids.view(input_ids.size(0), -1) + if attention_mask != None: + attention_mask = attention_mask.view(attention_mask.size(0), -1) + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + **kwargs + ) + + # We need to resize the inputs here, as the generate method expect 2D tensors + def generate(self, input_ids, attention_mask, max_length): + self.encoder.n_passages = input_ids.size(1) + return super().generate( + input_ids=input_ids.view(input_ids.size(0), -1), + attention_mask=attention_mask.view(attention_mask.size(0), -1), + max_length=max_length + ) + + def wrap_encoder(self, use_checkpoint=False): + """ + Wrap T5 encoder to obtain a Fusion-in-Decoder model. + """ + self.encoder = EncoderWrapper(self.encoder, use_checkpoint=use_checkpoint) + + def unwrap_encoder(self): + """ + Unwrap Fusion-in-Decoder encoder, useful to load T5 weights. + """ + self.encoder = self.encoder.encoder + block = [] + for mod in self.encoder.block: + block.append(mod.module) + block = nn.ModuleList(block) + self.encoder.block = block + + def load_t5(self, state_dict): + self.unwrap_encoder() + self.load_state_dict(state_dict) + self.wrap_encoder() + + def set_checkpoint(self, use_checkpoint): + """ + Enable or disable checkpointing in the encoder. + See https://pytorch.org/docs/stable/checkpoint.html + """ + for mod in self.encoder.encoder.block: + mod.use_checkpoint = use_checkpoint + + def reset_score_storage(self): + """ + Reset score storage, only used when cross-attention scores are saved + to train a retriever. + """ + for mod in self.decoder.block: + mod.layer[1].EncDecAttention.score_storage = None + + def get_crossattention_scores(self, context_mask): + """ + Cross-attention scores are aggregated to obtain a single scalar per + passage. This scalar can be seen as a similarity score between the + question and the input passage. It is obtained by averaging the + cross-attention scores obtained on the first decoded token over heads, + layers, and tokens of the input passage. 
+ + More details in Distilling Knowledge from Reader to Retriever: + https://arxiv.org/abs/2012.04584. + """ + scores = [] + n_passages = context_mask.size(1) + for mod in self.decoder.block: + scores.append(mod.layer[1].EncDecAttention.score_storage) + scores = torch.cat(scores, dim=2) + bsz, n_heads, n_layers, _ = scores.size() + # batch_size, n_head, n_layers, n_passages, text_maxlength + scores = scores.view(bsz, n_heads, n_layers, n_passages, -1) + scores = scores.masked_fill(~context_mask[:, None, None], 0.) + scores = scores.sum(dim=[1, 2, 4]) + ntokens = context_mask.sum(dim=[2]) * n_layers * n_heads + scores = scores/ntokens + return scores + + def overwrite_forward_crossattention(self): + """ + Replace cross-attention forward function, only used to save + cross-attention scores. + """ + for mod in self.decoder.block: + attn = mod.layer[1].EncDecAttention + attn.forward = types.MethodType(cross_attention_forward, attn) + +class EncoderWrapper(torch.nn.Module): + """ + Encoder Wrapper for T5 Wrapper to obtain a Fusion-in-Decoder model. + """ + def __init__(self, encoder, use_checkpoint=False): + super().__init__() + + self.encoder = encoder + apply_checkpoint_wrapper(self.encoder, use_checkpoint) + + def forward(self, input_ids=None, attention_mask=None, **kwargs,): + # total_length = n_passages * passage_length + bsz, total_length = input_ids.shape + passage_length = total_length // self.n_passages + input_ids = input_ids.view(bsz*self.n_passages, passage_length) + attention_mask = attention_mask.view(bsz*self.n_passages, passage_length) + outputs = self.encoder(input_ids, attention_mask, **kwargs) + outputs = (outputs[0].view(bsz, self.n_passages*passage_length, -1), ) + outputs[1:] + return outputs + +class CheckpointWrapper(torch.nn.Module): + """ + Wrapper replacing None outputs by empty tensors, which allows the use of + checkpointing. + """ + def __init__(self, module, use_checkpoint=False): + super().__init__() + self.module = module + self.use_checkpoint = use_checkpoint + + def forward(self, hidden_states, attention_mask, position_bias, **kwargs): + if self.use_checkpoint and self.training: + kwargs = {k: v for k, v in kwargs.items() if v is not None} + def custom_forward(*inputs): + output = self.module(*inputs, **kwargs) + empty = torch.tensor( + [], + dtype=torch.float, + device=output[0].device, + requires_grad=True) + output = tuple(x if x is not None else empty for x in output) + return output + + output = torch.utils.checkpoint.checkpoint( + custom_forward, + hidden_states, + attention_mask, + position_bias + ) + output = tuple(x if x.size() != 0 else None for x in output) + else: + output = self.module(hidden_states, attention_mask, position_bias, **kwargs) + return output + +def apply_checkpoint_wrapper(t5stack, use_checkpoint): + """ + Wrap each block of the encoder to enable checkpointing. 
+ """ + block = [] + for mod in t5stack.block: + wrapped_mod = CheckpointWrapper(mod, use_checkpoint) + block.append(wrapped_mod) + block = nn.ModuleList(block) + t5stack.block = block + +def cross_attention_forward( + self, + input, + mask=None, + kv=None, + position_bias=None, + past_key_value_state=None, + head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + This only works for computing cross attention over the input + """ + assert(kv != None) + assert(head_mask == None) + assert(position_bias != None or self.has_relative_attention_bias) + + bsz, qlen, dim = input.size() + n_heads, d_heads = self.n_heads, self.d_kv + klen = kv.size(1) + + q = self.q(input).view(bsz, -1, n_heads, d_heads).transpose(1, 2) + if past_key_value_state == None: + k = self.k(kv).view(bsz, -1, n_heads, d_heads).transpose(1, 2) + v = self.v(kv).view(bsz, -1, n_heads, d_heads).transpose(1, 2) + else: + k, v = past_key_value_state + + scores = torch.einsum("bnqd,bnkd->bnqk", q, k) + + if mask is not None: + scores += mask + + if position_bias is None: + position_bias = self.compute_bias(qlen, klen) + scores += position_bias + + if self.score_storage is None: + self.score_storage = scores + + attn = F.softmax(scores.float(), dim=-1).type_as(scores) + attn = F.dropout(attn, p=self.dropout, training=self.training) + + output = torch.matmul(attn, v) + output = output.transpose(1, 2).contiguous().view(bsz, -1, self.inner_dim) + output = self.o(output) + + if use_cache: + output = (output,) + ((k, v),) + else: + output = (output,) + (None,) + + if output_attentions: + output = output + (attn,) + + if self.has_relative_attention_bias: + output = output + (position_bias,) + + return output + +class RetrieverConfig(transformers.BertConfig): + + def __init__(self, + indexing_dimension=768, + apply_question_mask=False, + apply_passage_mask=False, + extract_cls=False, + passage_maxlength=200, + question_maxlength=40, + projection=True, + **kwargs): + super().__init__(**kwargs) + self.indexing_dimension = indexing_dimension + self.apply_question_mask = apply_question_mask + self.apply_passage_mask = apply_passage_mask + self.extract_cls=extract_cls + self.passage_maxlength = passage_maxlength + self.question_maxlength = question_maxlength + self.projection = projection + +class Retriever(transformers.PreTrainedModel): + + config_class = RetrieverConfig + base_model_prefix = "retriever" + + def __init__(self, config, initialize_wBERT=False): + super().__init__(config) + assert config.projection or config.indexing_dimension == 768, \ + 'If no projection then indexing dimension must be equal to 768' + self.config = config + if initialize_wBERT: + self.model = transformers.BertModel.from_pretrained('bert-base-uncased') + else: + self.model = transformers.BertModel(config) + if self.config.projection: + self.proj = nn.Linear( + self.model.config.hidden_size, + self.config.indexing_dimension + ) + self.norm = nn.LayerNorm(self.config.indexing_dimension) + self.loss_fct = torch.nn.KLDivLoss() + + def forward(self, + question_ids, + question_mask, + passage_ids, + passage_mask, + gold_score=None): + question_output = self.embed_text( + text_ids=question_ids, + text_mask=question_mask, + apply_mask=self.config.apply_question_mask, + extract_cls=self.config.extract_cls, + ) + bsz, n_passages, plen = passage_ids.size() + passage_ids = passage_ids.view(bsz * n_passages, plen) + passage_mask = passage_mask.view(bsz * n_passages, plen) + passage_output = self.embed_text( + text_ids=passage_ids, + 
text_mask=passage_mask, + apply_mask=self.config.apply_passage_mask, + extract_cls=self.config.extract_cls, + ) + + score = torch.einsum( + 'bd,bid->bi', + question_output, + passage_output.view(bsz, n_passages, -1) + ) + score = score / np.sqrt(question_output.size(-1)) + if gold_score is not None: + loss = self.kldivloss(score, gold_score) + else: + loss = None + + return question_output, passage_output, score, loss + + def embed_text(self, text_ids, text_mask, apply_mask=False, extract_cls=False): + text_output = self.model( + input_ids=text_ids, + attention_mask=text_mask if apply_mask else None + ) + if type(text_output) is not tuple: + text_output.to_tuple() + text_output = text_output[0] + if self.config.projection: + text_output = self.proj(text_output) + text_output = self.norm(text_output) + + if extract_cls: + text_output = text_output[:, 0] + else: + if apply_mask: + text_output = text_output.masked_fill(~text_mask[:, :, None], 0.) + text_output = torch.sum(text_output, dim=1) / torch.sum(text_mask, dim=1)[:, None] + else: + text_output = torch.mean(text_output, dim=1) + return text_output + + def kldivloss(self, score, gold_score): + gold_score = torch.softmax(gold_score, dim=-1) + score = torch.nn.functional.log_softmax(score, dim=-1) + return self.loss_fct(score, gold_score) diff --git a/deeppavlov/models/torch_bert/torch_generative_qa.py b/deeppavlov/models/torch_bert/torch_generative_qa.py new file mode 100644 index 0000000000..8fa4d593cf --- /dev/null +++ b/deeppavlov/models/torch_bert/torch_generative_qa.py @@ -0,0 +1,203 @@ +# Copyright 2017 Neural Networks and Deep Learning lab, MIPT +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
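+
+# TorchFiD wraps the Fusion-in-Decoder T5 model (``FiDT5`` from fusion_in_decoder.py)
+# as a DeepPavlov ``TorchModel``: ``train_on_batch`` consumes the (input_ids,
+# attention_mask, target_ids) tensors produced by the FiD preprocessors, and
+# ``__call__`` generates free-form answers that are decoded with the T5 tokenizer.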
+ +import os +from logging import getLogger +from pathlib import Path +from typing import List, Optional, Dict + +import torch +from overrides import overrides +from transformers import T5ForConditionalGeneration, T5Tokenizer + +from deeppavlov.core.common.registry import register +from deeppavlov.core.models.torch_model import TorchModel + + +from deeppavlov.models.torch_bert.fusion_in_decoder import FiDT5 + + +logger = getLogger(__name__) + + +@register('torch_generative_qa_fid') +class TorchFiD(TorchModel): + def __init__(self, + pretrained_transformer: str = "t5-base", + attention_probs_keep_prob: Optional[float] = None, + hidden_keep_prob: Optional[float] = None, + optimizer: str = "AdamW", + optimizer_parameters: Optional[dict] = None, + bert_config_file: Optional[str] = None, + learning_rate_drop_patience: int = 20, + learning_rate_drop_div: float = 2.0, + load_before_drop: bool = True, + clip_norm: Optional[float] = None, + min_learning_rate: float = 1e-06, + generate_max_length: int = 50, + **kwargs) -> None: + + if not optimizer_parameters: + optimizer_parameters = {"lr": 0.01, + "weight_decay": 0.01, + "betas": (0.9, 0.999), + "eps": 1e-6} + self.generate_max_length = generate_max_length + + self.attention_probs_keep_prob = attention_probs_keep_prob + self.hidden_keep_prob = hidden_keep_prob + self.clip_norm = clip_norm + + self.pretrained_transformer = pretrained_transformer + self.bert_config_file = bert_config_file + self.tokenizer = T5Tokenizer.from_pretrained(self.pretrained_transformer, return_dict=False) + + super().__init__(optimizer=optimizer, + optimizer_parameters=optimizer_parameters, + learning_rate_drop_patience=learning_rate_drop_patience, + learning_rate_drop_div=learning_rate_drop_div, + load_before_drop=load_before_drop, + min_learning_rate=min_learning_rate, + **kwargs) + + def train_on_batch(self, input_ids_batch, attention_mask_batch, target_ids_batch) -> Dict: + input_ids_batch = torch.LongTensor(input_ids_batch).to(self.device) + attention_mask_batch = torch.LongTensor(attention_mask_batch).to(self.device) + target_ids_batch = torch.LongTensor(target_ids_batch).to(self.device) + + input_ = { + 'input_ids': input_ids_batch, + 'attention_mask': attention_mask_batch, + 'labels': target_ids_batch + } + + self.optimizer.zero_grad() + + loss = self.model(**input_)[0] + if self.is_data_parallel: + loss = loss.mean() + loss.backward() + + # Clip the norm of the gradients to 1.0. + # This is to help prevent the "exploding gradients" problem. 
+ if self.clip_norm: + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.clip_norm) + + self.optimizer.step() + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + return {'loss': loss.item()} + + @property + def is_data_parallel(self) -> bool: + return isinstance(self.model, torch.nn.DataParallel) + + def __call__(self, input_ids_batch, attention_mask_batch): + input_ids_batch = torch.LongTensor(input_ids_batch).to(self.device) + attention_mask_batch = torch.LongTensor(attention_mask_batch).to(self.device) + input_ = { + 'input_ids': input_ids_batch, + 'attention_mask': attention_mask_batch, + } + + model = self.model.module if hasattr(self.model, "module") else self.model + with torch.no_grad(): + answer_ids_batch = model.generate(max_length=self.generate_max_length, **input_) + + answers_batch = self.tokenizer.batch_decode(answer_ids_batch, skip_special_tokens=True) + return answers_batch + + @overrides + def save(self, fname: Optional[str] = None, *args, **kwargs): + if fname is None: + fname = self.save_path + os.makedirs(fname, exist_ok=True) + logger.info(f"Saving checkpoint to {fname}.") + + # Save model + model_dir_path = fname + model_to_save = self.model.module if hasattr(self.model, "module") else self.model + model_to_save.save_pretrained(model_dir_path) + + # Save optimizer and scheduler + optimizer_path = str(Path(fname, "optimizer.pth.tar").resolve()) + optimizer_state = { + "optimizer": self.optimizer.state_dict() + } + torch.save(optimizer_state, optimizer_path) + + + def init_optimizer_from_scratch(self) -> None: + self.optimizer = getattr(torch.optim, self.optimizer_name)( + self.model.parameters(), **self.optimizer_parameters) + + if self.lr_scheduler_name is not None: + self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_name)( + self.optimizer, **self.lr_scheduler_parameters) + + if self.opt.get("criterion", None): + self.criterion = getattr(torch.nn, self.opt.get("criterion", None))() + + def init_model_from_scratch(self) -> None: + logger.info(f"From pretrained {self.pretrained_transformer}.") + self.tokenizer = T5Tokenizer.from_pretrained(self.pretrained_transformer, return_dict=False) + t5 = T5ForConditionalGeneration.from_pretrained(self.pretrained_transformer) + + self.model = FiDT5(t5.config) + self.model.load_t5(t5.state_dict()) + self.model.to(self.device) + + + def load_model_from_checkpoint(self, model_dir_path: str): + logger.info(f"Loading model from {model_dir_path}.") + self.model = FiDT5.from_pretrained(model_dir_path) + self.model = self.model.to(self.device) + + def load_optimizer_from_checkpoint(self, optimizer_path: str): + logger.info(f"Loading optimizer from {optimizer_path}.") + self.init_optimizer_from_scratch() + optimizer_state = torch.load(optimizer_path, map_location=self.device) + self.optimizer.load_state_dict(optimizer_state["optimizer"]) + + @overrides + def load(self, fname: Optional[str] = None, *args, **kwargs) -> None: + if fname is not None: + self.load_path = fname + + if self.load_path is not None: + logger.info(f"Load path {self.load_path} is given.") + model_config_path = Path(self.load_path) / "config.json" + model_weights_path = Path(self.load_path) / "pytorch_model.bin" + optimizer_path = Path(self.load_path) / "optimizer.pth.tar" + + if model_config_path.exists() and model_weights_path.exists(): + self.load_model_from_checkpoint(self.load_path) + else: + self.init_model_from_scratch() + logger.info(f"Init model from scratch: {model_config_path} or {model_weights_path} does not 
exist.") + + if optimizer_path.exists(): + self.load_optimizer_from_checkpoint(str(optimizer_path.resolve())) + else: + self.init_optimizer_from_scratch() + logger.info(f"Init optimizer from scratch: {optimizer_path} does not exist.") + else: + self.init_model_from_scratch() + self.init_optimizer_from_scratch() + logger.info(f"Init model and optimizer from scratch: \"load_path\" and \"fname\" are not given.") + + if self.device.type == "cuda" and torch.cuda.device_count() > 1: + self.model = torch.nn.DataParallel(self.model) \ No newline at end of file diff --git a/deeppavlov/requirements/transformers_3.0.2.txt b/deeppavlov/requirements/transformers_3.0.2.txt new file mode 100644 index 0000000000..70ef216ce9 --- /dev/null +++ b/deeppavlov/requirements/transformers_3.0.2.txt @@ -0,0 +1 @@ +transformers==3.0.2 \ No newline at end of file diff --git a/docs/features/models/generative_qa.rst b/docs/features/models/generative_qa.rst new file mode 100644 index 0000000000..ebf62f6546 --- /dev/null +++ b/docs/features/models/generative_qa.rst @@ -0,0 +1,214 @@ +Generative Question Answering +============================= + +Task definitfion +---------------- +Generative Question Answering is the task of finding an answer on question in a given contexts (e.g, paragraphs from Wikipedia), +where the answer to each question is **not necessary** a segment of the context. + + +**Question**: + + Is it possible to have a rating above 4000 in chess? + +**Contexts**: + + > Right now that can't really happen. Now, the highest-rated chess player is Stockfish 12, with a rating of 3515. A rating difference of 400 points means you'll beat your opponent over 90% of the time. Here we're looking at an even bigger difference than that: about 500 points. + + > It's nearly impossible to measure the rating difference between two players so far apart in skill. For there to be a player with a rating of 4000, there would first have to be other players with ratings that are at least fairly close, like 3800. + +**Answer**: + + not really possible + +Datasets +-------- +We consider the following datasets: + +- `Natural Questions `__ +- `TriviaQA `__ + +Specifically, we validate our model on *Natural Questions* and *TriviaQA* from: https://github.com/facebookresearch/FiD. + + +Datasets format +~~~~~~~~~~~~~~~ +{ + "train": [ + [ + [ "question", [ "contexts" ], [ "titles" ] ], + + [ "target", [ "answers" ] ] + + ], + + ... + + ] + + "valid": [ ... ] + + "test": [ ... ] + +} + +Built-In Models +--------------- +DeepPavlov's model for generative question answering is based on Fusion-in-decoder(FiD) base. +The model generates answer based on the question and k-support contexts. + +Currently, we provide two built-in models for generative question answering in DeepPavlov library, finetuned on 2 datasets: + +- Natural Questions :config:`deeppavlov/configs/generative_qa/nq_fid.json` + +- TriviaQA :config:`deeppavlov/configs/generative_qa/tqa_fid.json` + +Architecture +~~~~~~~~~~~~ +FiD model uses several support passages to gather usefull information from multiple knowledge sources. Firstly, every +passage is concatinated with the question like this *"question: What is the capital of UK? passage: London is the capital of UK"* +and processed independently from other passages by the encoder of pretrained sequence-to-sequence network (e.g. T5). 
+ +Then the decoder performs attention over the concatenation of the resulting representations of all the retrieved passages. + + +Metrics +~~~~~~~ +Natural Questions dataset +^^^^^^^^^^^^^^^^^^^^^^^^^ ++---------------------------------------------------------+---------------------------------+---------------------------------+ +| Dataset | Natural Questions (dev) | Natural Questions (test) | ++---------------------------------------------------------+----------------+----------------+----------------+----------------+ +| Model | EM | F-1 | EM | F-1 | ++=========================================================+================+================+================+================+ +| :config:`DeepPavlov FiD ` | 39.9 | 50.0 | 46.0 | 54.1 | ++---------------------------------------------------------+----------------+----------------+----------------+----------------+ +| `T5`_ | 42.0 | 50.6 | 42.2 | 49.7 | ++---------------------------------------------------------+----------------+----------------+----------------+----------------+ + + +TriviaQA dataset +^^^^^^^^^^^^^^^^ ++---------------------------------------------------------+---------------------------------+---------------------------------+ +| Dataset | TriviaQA (dev) | TriviaQA (test) | ++---------------------------------------------------------+----------------+----------------+----------------+----------------+ +| Model | EM | F-1 | EM | F-1 | ++=========================================================+================+================+================+================+ +| :config:`DeepPavlov FiD ` | 61.8 | 69.6 | 63.1 | 70.0 | ++---------------------------------------------------------+----------------+----------------+----------------+----------------+ +| :config:`DeepPavlov FiD ` | 51.1 | 61.3 | 52.2 | 61.9 | ++---------------------------------------------------------+----------------+----------------+----------------+----------------+ +| `T5`_ | 46.0 | 55.0 | 46.1 | 55.3 | ++---------------------------------------------------------+----------------+----------------+----------------+----------------+ +| `QANet`_ | 51.1 | 56.6 | -- | -- | ++---------------------------------------------------------+----------------+----------------+----------------+----------------+ +| `M-Reader`_ | -- | -- | 46.9 | 52.9 | ++---------------------------------------------------------+----------------+----------------+----------------+----------------+ +| `MEMEN`_ | 43.2 | 46.9 | -- | -- | ++---------------------------------------------------------+----------------+----------------+----------------+----------------+ +| `BiDAF`_ | 40.3 | 45.7 | -- | -- | ++---------------------------------------------------------+----------------+----------------+----------------+----------------+ + + +.. _`M-Reader`: https://arxiv.org/abs/1705.02798 +.. _`MEMEN`: https://arxiv.org/abs/1707.09098 +.. _`QANet`: https://arxiv.org/abs/1804.09541 +.. _`BiDAF`: https://arxiv.org/abs/1611.01603 +.. _`T5`: https://arxiv.org/abs/1910.10683 + + + +Prerequisites +------------- + +Before using the models, make sure that all required packages are installed by running the following commands: + + .. code:: bash + + python -m deeppavlov install nq_fid + python -m deeppavlov install tqa_fid + + +Pretrained models are available and can be downloaded (~0.9 GB): + + .. code:: bash + + python -m deeppavlov download nq_fid + python -m deeppavlov download tqa_fid + + +Model usage from Python +----------------------- + +Interact +~~~~~~~~ + .. code:: python + + from deeppavlov import build_model + + model = build_model('nq_fid', download=True) + + model([ + "What is the capital of UK?", + "Where did the name Atari itself come from?" + ], + [ + [ + "The name Britain is sometimes used to refer to the United Kingdom as a whole", + "London is the capital of Great Britain" + ], + [ + "Bushnell and Dabney were originally going to name their company Syzygy, a term for planetary alignment, but found that it had been registered already.", + "Instead, they chose a word from the Japanese game Go. The Japanese equivalent of chess, in Go Atari means something similar to \'check\'." + ] + ]) + >>> ['london', 'the japanese game go'] + + model([ + "How many points do you need to win in badminton?" + ], + [ + [ + "A rally is lost if the shuttle is hit into the net, or over the net but outside of the opponent's court.", + "A rally is also lost if the shuttle touches the player's clothing or body, or if it is hit before it crosses over the net", + 'The side winning a rally adds a point to its score', 'A match consists of the best of 3 games of 21 points (games cap at 30 points)', + "A rally is won when a shuttle is hit over the net and onto the floor of the opponent's court.", + 'At 29 all, the side scoring the 30th point, wins that game', + 'The side winning a game serves first in the next game', + 'At 20 all, the side which gains a 2 point lead first, wins that game.', + 'Each game starts at 0-0. If the match goes to the third game that third game will be played to 15' + ] + ]) + >>> ['21'] + +Train +~~~~~ + .. code:: python + + from deeppavlov import train_model + + model = train_model('nq_fid', download=True) + + +Model usage from CLI +-------------------- + +Train +~~~~~ + .. code:: bash + + python -m deeppavlov train nq_fid + +Evaluate +~~~~~~~~ + .. code:: bash + + python -m deeppavlov evaluate nq_fid + +Interact +~~~~~~~~ + +Interact mode provides a command-line interface to an already trained model. + + .. code:: bash + + python -m deeppavlov interact nq_fid diff --git a/docs/index.rst b/docs/index.rst index 0fe7640253..0c556edcc3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -38,6 +38,7 @@ Welcome to DeepPavlov's documentation! Knowledge Base Question answering Relation Extraction SuperGLUE Submission + Generative Question Answering .. toctree::