Skip to content

Commit 5a7bc3c

Browse files
author
Shivanand Roy
authored
Merge pull request #13 from Shivanandroy/byt5
ByT5 support, transformers upgrade, dropping ONNX support
2 parents 800994f + a9fd063 commit 5a7bc3c

File tree

4 files changed

+37
-40
lines changed

4 files changed

+37
-40
lines changed

README.md

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
<img align="center" src="data/st5.png" alt="simpleT5">
22

33
<p align="center">
4-
<b>Quickly train T5 models in just 3 lines of code with ONNX inference
4+
<b>Quickly train T5/mT5/byT5 models in just 3 lines of code
55
</b>
66
</p>
7-
87
<p align="center">
98
<a href="https://badge.fury.io/py/simplet5"><img src="https://badge.fury.io/py/simplet5.svg" alt="PyPI version" height="18"></a>
109

1110
<a href="https://badge.fury.io/py/simplet5">
1211
<img alt="Stars" src="https://img.shields.io/github/stars/Shivanandroy/simpleT5?color=blue">
1312
</a>
1413
<a href="https://pepy.tech/project/simplet5">
15-
<img alt="Stats" src="https://static.pepy.tech/personalized-badge/simplet5?period=month&units=international_system&left_color=black&right_color=orange&left_text=downloads/month">
14+
<img alt="Stats" src="https://static.pepy.tech/personalized-badge/simplet5?period=total&units=international_system&left_color=black&right_color=brightgreen&left_text=Downloads">
1615
</a>
1716
<a href="https://opensource.org/licenses/MIT">
1817
<img alt="License" src="https://img.shields.io/badge/License-MIT-yellow.svg">
@@ -41,7 +40,7 @@ from simplet5 import SimpleT5
4140
# instantiate
4241
model = SimpleT5()
4342

44-
# load (supports t5, mt5 models)
43+
# load (supports t5, mt5, byT5 models)
4544
model.from_pretrained("t5","t5-base")
4645

4746
# train
@@ -63,10 +62,6 @@ model.load_model("t5","path/to/trained/model/directory", use_gpu=False)
6362
# predict
6463
model.predict("input text for prediction")
6564

66-
# need faster inference on CPU, get ONNX support
67-
model.convert_and_load_onnx_model("path/to/T5 model/directory")
68-
model.onnx_predict("input text for prediction")
69-
7065
```
7166
## Articles
7267
- [Geek Culture: simpleT5 — Train T5 Models in Just 3 Lines of Code](https://medium.com/geekculture/simplet5-train-t5-models-in-just-3-lines-of-code-by-shivanand-roy-2021-354df5ae46ba)

requirements.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
pandas
22
sentencepiece
33
torch>=1.7.0,!=1.8.0
4-
transformers==4.6.1
5-
pytorch-lightning==1.3.3
6-
fastt5==0.0.6
4+
transformers==4.10.0
5+
pytorch-lightning==1.4.5

setup.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
setuptools.setup(
1010
name="simplet5",
11-
version="0.1.2",
11+
version="0.1.3",
1212
license="apache-2.0",
1313
author="Shivanand Roy",
1414
author_email="[email protected]",
@@ -41,9 +41,10 @@
4141
install_requires=[
4242
"sentencepiece",
4343
"torch>=1.7.0,!=1.8.0", # excludes torch v1.8.0
44-
"transformers==4.6.1",
45-
"pytorch-lightning==1.3.3",
46-
"fastt5==0.0.6",
44+
"transformers==4.10.0",
45+
"pytorch-lightning==1.4.5",
46+
"tqdm"
47+
# "fastt5==0.0.7",
4748
],
4849
classifiers=[
4950
"Intended Audience :: Developers",

simplet5/simplet5.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
AdamW,
77
T5ForConditionalGeneration,
88
MT5ForConditionalGeneration,
9+
ByT5Tokenizer,
910
PreTrainedTokenizer,
1011
T5TokenizerFast as T5Tokenizer,
1112
MT5TokenizerFast as MT5Tokenizer,
1213
)
1314
from transformers import AutoTokenizer
14-
from fastT5 import export_and_get_onnx_model
15+
16+
# from fastT5 import export_and_get_onnx_model
1517
from torch.utils.data import Dataset, DataLoader
1618
from transformers import AutoModelWithLMHead, AutoTokenizer
1719
import pytorch_lightning as pl
@@ -246,7 +248,7 @@ def training_epoch_end(self, training_step_outputs):
246248
torch.mean(torch.stack([x["loss"] for x in training_step_outputs])).item(),
247249
4,
248250
)
249-
path = f"{self.outputdir}/SimpleT5-epoch-{self.current_epoch}-train-loss-{str(avg_traning_loss)}"
251+
path = f"{self.outputdir}/simplet5-epoch-{self.current_epoch}-train-loss-{str(avg_traning_loss)}"
250252
self.tokenizer.save_pretrained(path)
251253
self.model.save_pretrained(path)
252254

@@ -282,11 +284,11 @@ def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
282284
self.model = MT5ForConditionalGeneration.from_pretrained(
283285
f"{model_name}", return_dict=True
284286
)
285-
# elif model_type == "byt5":
286-
# self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_name}")
287-
# self.model = T5ForConditionalGeneration.from_pretrained(
288-
# f"{model_name}", return_dict=True
289-
# )
287+
elif model_type == "byt5":
288+
self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_name}")
289+
self.model = T5ForConditionalGeneration.from_pretrained(
290+
f"{model_name}", return_dict=True
291+
)
290292

291293
def train(
292294
self,
@@ -385,9 +387,9 @@ def load_model(
385387
elif model_type == "mt5":
386388
self.model = MT5ForConditionalGeneration.from_pretrained(f"{model_dir}")
387389
self.tokenizer = MT5Tokenizer.from_pretrained(f"{model_dir}")
388-
# elif model_type == "byt5":
389-
# self.model = T5ForConditionalGeneration.from_pretrained(f"{model_dir}")
390-
# self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_dir}")
390+
elif model_type == "byt5":
391+
self.model = T5ForConditionalGeneration.from_pretrained(f"{model_dir}")
392+
self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_dir}")
391393

392394
if use_gpu:
393395
if torch.cuda.is_available():
@@ -459,18 +461,18 @@ def predict(
459461
]
460462
return preds
461463

462-
def convert_and_load_onnx_model(self, model_dir: str):
463-
""" returns ONNX model """
464-
self.onnx_model = export_and_get_onnx_model(model_dir)
465-
self.onnx_tokenizer = AutoTokenizer.from_pretrained(model_dir)
466-
467-
def onnx_predict(self, source_text: str):
468-
""" generates prediction from ONNX model """
469-
token = self.onnx_tokenizer(source_text, return_tensors="pt")
470-
tokens = self.onnx_model.generate(
471-
input_ids=token["input_ids"],
472-
attention_mask=token["attention_mask"],
473-
num_beams=2,
474-
)
475-
output = self.onnx_tokenizer.decode(tokens.squeeze(), skip_special_tokens=True)
476-
return output
464+
# def convert_and_load_onnx_model(self, model_dir: str):
465+
# """ returns ONNX model """
466+
# self.onnx_model = export_and_get_onnx_model(model_dir)
467+
# self.onnx_tokenizer = AutoTokenizer.from_pretrained(model_dir)
468+
469+
# def onnx_predict(self, source_text: str):
470+
# """ generates prediction from ONNX model """
471+
# token = self.onnx_tokenizer(source_text, return_tensors="pt")
472+
# tokens = self.onnx_model.generate(
473+
# input_ids=token["input_ids"],
474+
# attention_mask=token["attention_mask"],
475+
# num_beams=2,
476+
# )
477+
# output = self.onnx_tokenizer.decode(tokens.squeeze(), skip_special_tokens=True)
478+
# return output

0 commit comments

Comments (0)