Skip to content

Commit dc6e4da

Browse files
author
Shivanand Roy
committed
updated transformers, removed fastt5 dependency
1 parent 800994f commit dc6e4da

File tree

2 files changed

+31
-28
lines changed

2 files changed

+31
-28
lines changed

setup.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
setuptools.setup(
1010
name="simplet5",
11-
version="0.1.2",
11+
version="0.1.3",
1212
license="apache-2.0",
1313
author="Shivanand Roy",
1414
author_email="[email protected]",
@@ -41,9 +41,10 @@
4141
install_requires=[
4242
"sentencepiece",
4343
"torch>=1.7.0,!=1.8.0", # excludes torch v1.8.0
44-
"transformers==4.6.1",
45-
"pytorch-lightning==1.3.3",
46-
"fastt5==0.0.6",
44+
"transformers==4.10.0",
45+
"pytorch-lightning==1.4.5",
46+
"tqdm"
47+
# "fastt5==0.0.7",
4748
],
4849
classifiers=[
4950
"Intended Audience :: Developers",

simplet5/simplet5.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
AdamW,
77
T5ForConditionalGeneration,
88
MT5ForConditionalGeneration,
9+
ByT5Tokenizer,
910
PreTrainedTokenizer,
1011
T5TokenizerFast as T5Tokenizer,
1112
MT5TokenizerFast as MT5Tokenizer,
1213
)
1314
from transformers import AutoTokenizer
14-
from fastT5 import export_and_get_onnx_model
15+
16+
# from fastT5 import export_and_get_onnx_model
1517
from torch.utils.data import Dataset, DataLoader
1618
from transformers import AutoModelWithLMHead, AutoTokenizer
1719
import pytorch_lightning as pl
@@ -282,11 +284,11 @@ def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
282284
self.model = MT5ForConditionalGeneration.from_pretrained(
283285
f"{model_name}", return_dict=True
284286
)
285-
# elif model_type == "byt5":
286-
# self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_name}")
287-
# self.model = T5ForConditionalGeneration.from_pretrained(
288-
# f"{model_name}", return_dict=True
289-
# )
287+
elif model_type == "byt5":
288+
self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_name}")
289+
self.model = T5ForConditionalGeneration.from_pretrained(
290+
f"{model_name}", return_dict=True
291+
)
290292

291293
def train(
292294
self,
@@ -385,9 +387,9 @@ def load_model(
385387
elif model_type == "mt5":
386388
self.model = MT5ForConditionalGeneration.from_pretrained(f"{model_dir}")
387389
self.tokenizer = MT5Tokenizer.from_pretrained(f"{model_dir}")
388-
# elif model_type == "byt5":
389-
# self.model = T5ForConditionalGeneration.from_pretrained(f"{model_dir}")
390-
# self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_dir}")
390+
elif model_type == "byt5":
391+
self.model = T5ForConditionalGeneration.from_pretrained(f"{model_dir}")
392+
self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_dir}")
391393

392394
if use_gpu:
393395
if torch.cuda.is_available():
@@ -459,18 +461,18 @@ def predict(
459461
]
460462
return preds
461463

462-
def convert_and_load_onnx_model(self, model_dir: str):
463-
""" returns ONNX model """
464-
self.onnx_model = export_and_get_onnx_model(model_dir)
465-
self.onnx_tokenizer = AutoTokenizer.from_pretrained(model_dir)
466-
467-
def onnx_predict(self, source_text: str):
468-
""" generates prediction from ONNX model """
469-
token = self.onnx_tokenizer(source_text, return_tensors="pt")
470-
tokens = self.onnx_model.generate(
471-
input_ids=token["input_ids"],
472-
attention_mask=token["attention_mask"],
473-
num_beams=2,
474-
)
475-
output = self.onnx_tokenizer.decode(tokens.squeeze(), skip_special_tokens=True)
476-
return output
464+
# def convert_and_load_onnx_model(self, model_dir: str):
465+
# """ returns ONNX model """
466+
# self.onnx_model = export_and_get_onnx_model(model_dir)
467+
# self.onnx_tokenizer = AutoTokenizer.from_pretrained(model_dir)
468+
469+
# def onnx_predict(self, source_text: str):
470+
# """ generates prediction from ONNX model """
471+
# token = self.onnx_tokenizer(source_text, return_tensors="pt")
472+
# tokens = self.onnx_model.generate(
473+
# input_ids=token["input_ids"],
474+
# attention_mask=token["attention_mask"],
475+
# num_beams=2,
476+
# )
477+
# output = self.onnx_tokenizer.decode(tokens.squeeze(), skip_special_tokens=True)
478+
# return output

0 commit comments

Comments
 (0)