Skip to content

Commit

Permalink
#26 New models pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
gagan3012 committed May 10, 2021
1 parent 1286c60 commit 46827d8
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions keytotext/newmodels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import torch
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
)


class NMPipeline:
    """Keyword-to-text generation pipeline around a T5 seq2seq model.

    The pipeline turns a collection of keywords into one or more prompt
    strings, runs ``model.generate`` on each prompt, decodes the first
    returned sequence, and returns the concatenated, cleaned-up text.

    Attributes:
        model: The wrapped model, already moved to ``device``.
        tokenizer: Tokenizer used to encode prompts and decode outputs.
        device: ``"cuda"`` when requested and available, otherwise ``"cpu"``.
        model_type: Short tag for the model architecture (``"t5"``).
        default_generate_kwargs: Arguments passed to ``model.generate`` when
            the caller supplies none of their own.
    """

    # Supported model class names mapped to the short tag stored in model_type.
    _SUPPORTED_MODELS = {"T5ForConditionalGeneration": "t5"}

    # Literal markers stripped from the decoded text before it is returned.
    _SPECIAL_TOKENS = ("<pad>", "</s>")

    # One-pass cleanup of a stringified keyword list: commas become spaces,
    # quotes and square brackets are dropped.
    _KEYWORD_CLEANUP = str.maketrans({",": " ", "'": None, "[": None, "]": None})

    def __init__(
        self, model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", use_cuda: bool
    ):
        """Validate the model, move it to the target device, set defaults.

        Args:
            model: Model instance; its class name must be one of the
                supported architectures.
            tokenizer: Tokenizer matching ``model``.
            use_cuda: Use the GPU when one is available.

        Raises:
            ValueError: If the model's class is not supported.
        """
        model_name = type(model).__name__
        # Validate before touching any state; an explicit raise (unlike an
        # ``assert``) still runs under ``python -O``.
        if model_name not in self._SUPPORTED_MODELS:
            raise ValueError(
                "Unsupported model class {!r}; expected one of {}".format(
                    model_name, sorted(self._SUPPORTED_MODELS)
                )
            )

        self.model = model
        self.tokenizer = tokenizer
        self.device = "cuda" if use_cuda and torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model_type = self._SUPPORTED_MODELS[model_name]

        self.default_generate_kwargs = {
            "max_length": 1024,
            "num_beams": 4,
            "length_penalty": 1.5,
            "no_repeat_ngram_size": 3,
            "early_stopping": True,
        }

    def __call__(self, keywords, **kwargs) -> str:
        """Generate text from ``keywords``.

        Args:
            keywords: Keywords to verbalise; converted with ``str()`` and
                split into prompts by ``_prepare_inputs_for_k2t``.
            **kwargs: Arguments forwarded to ``model.generate``. When any are
                given they replace ``default_generate_kwargs`` entirely
                (they are not merged with the defaults).

        Returns:
            The decoded outputs of all prompts concatenated, with ``<pad>``
            and ``</s>`` markers removed and surrounding whitespace stripped.
            An empty string when there is nothing to generate from.
        """
        if not kwargs:
            kwargs = self.default_generate_kwargs

        decoded = []
        for prompt in self._prepare_inputs_for_k2t(keywords):
            # NOTE(review): " </s>" is appended by hand while the tokenizer is
            # also asked to add special tokens - confirm this does not
            # duplicate the end-of-sequence token for the tokenizer in use.
            input_ids = self._tokenize("{} </s>".format(prompt), padding=False)
            outputs = self.model.generate(input_ids.to(self.device), **kwargs)
            decoded.append(self.tokenizer.decode(outputs[0]))

        result = "".join(decoded)
        # Plain string replacement: no regex (and no ``re`` import) needed for
        # two fixed markers.
        for marker in self._SPECIAL_TOKENS:
            result = result.replace(marker, "")
        return result.strip()

    def _prepare_inputs_for_k2t(self, keywords):
        """Convert ``keywords`` into a list of prompt strings.

        The keywords are stringified, list punctuation is removed (commas
        become spaces; quotes and square brackets are dropped), and the text
        is split on ``"."``. Segments that are empty or whitespace-only are
        discarded so that no generation is run on an empty prompt.

        Args:
            keywords: Any object; typically a list of strings or a string.

        Returns:
            A list of non-blank prompt segments, possibly empty.
        """
        text = str(keywords).translate(self._KEYWORD_CLEANUP)
        return [segment for segment in text.split(".") if segment.strip()]

    def _tokenize(
        self,
        inputs,
        padding=True,
        truncation=True,
        add_special_tokens=True,
        max_length=1024,
    ):
        """Encode ``inputs`` with the tokenizer.

        Args:
            inputs: Text to encode.
            padding: Pad to ``max_length`` when true; no padding otherwise.
            truncation: Passed through to the tokenizer.
            add_special_tokens: Passed through to the tokenizer.
            max_length: Maximum length passed to the tokenizer.

        Returns:
            Whatever ``tokenizer.encode`` returns for ``return_tensors="pt"``.
        """
        # ``padding`` alone selects the strategy; the deprecated
        # ``pad_to_max_length`` flag is no longer passed alongside it.
        return self.tokenizer.encode(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            return_tensors="pt",
        )

0 comments on commit 46827d8

Please sign in to comment.