@@ -6,12 +6,14 @@
     AdamW,
     T5ForConditionalGeneration,
     MT5ForConditionalGeneration,
+    ByT5Tokenizer,
     PreTrainedTokenizer,
     T5TokenizerFast as T5Tokenizer,
     MT5TokenizerFast as MT5Tokenizer,
 )
 from transformers import AutoTokenizer
-from fastT5 import export_and_get_onnx_model
+
+# from fastT5 import export_and_get_onnx_model
 from torch.utils.data import Dataset, DataLoader
 from transformers import AutoModelWithLMHead, AutoTokenizer
 import pytorch_lightning as pl
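This hunk imports ByT5Tokenizer for the new byt5 branches below and comments out the fastT5 import that the (now disabled) ONNX helpers relied on. A minimal sketch of what the new import enables (the google/byt5-small checkpoint name is an assumption, not taken from this diff): ByT5 pairs its byte-level tokenizer with the plain T5ForConditionalGeneration class, which is exactly how the branches below use it.

```python
# Sketch only: the checkpoint name is an assumption, not part of this diff.
from transformers import ByT5Tokenizer, T5ForConditionalGeneration

tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")
model = T5ForConditionalGeneration.from_pretrained("google/byt5-small", return_dict=True)

inputs = tokenizer("translate English to German: hello", return_tensors="pt")
generated = model.generate(
    input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```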
@@ -246,7 +248,7 @@ def training_epoch_end(self, training_step_outputs):
             torch.mean(torch.stack([x["loss"] for x in training_step_outputs])).item(),
             4,
         )
-        path = f"{self.outputdir}/SimpleT5-epoch-{self.current_epoch}-train-loss-{str(avg_traning_loss)}"
+        path = f"{self.outputdir}/simplet5-epoch-{self.current_epoch}-train-loss-{str(avg_traning_loss)}"
         self.tokenizer.save_pretrained(path)
         self.model.save_pretrained(path)

@@ -282,11 +284,11 @@ def from_pretrained(self, model_type="t5", model_name="t5-base") -> None:
             self.model = MT5ForConditionalGeneration.from_pretrained(
                 f"{model_name}", return_dict=True
             )
-        # elif model_type == "byt5":
-        #     self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_name}")
-        #     self.model = T5ForConditionalGeneration.from_pretrained(
-        #         f"{model_name}", return_dict=True
-        #     )
+        elif model_type == "byt5":
+            self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_name}")
+            self.model = T5ForConditionalGeneration.from_pretrained(
+                f"{model_name}", return_dict=True
+            )

     def train(
         self,
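With the byt5 branch uncommented, from_pretrained now accepts model_type="byt5" alongside t5 and mt5. A hedged usage sketch against the patched class (the simplet5 import path is inferred from the surrounding file, and the checkpoint name is an assumption):

```python
# Usage sketch; "google/byt5-small" is an assumed checkpoint name.
from simplet5 import SimpleT5

model = SimpleT5()
model.from_pretrained(model_type="byt5", model_name="google/byt5-small")
```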
@@ -385,9 +387,9 @@ def load_model(
         elif model_type == "mt5":
             self.model = MT5ForConditionalGeneration.from_pretrained(f"{model_dir}")
             self.tokenizer = MT5Tokenizer.from_pretrained(f"{model_dir}")
-        # elif model_type == "byt5":
-        #     self.model = T5ForConditionalGeneration.from_pretrained(f"{model_dir}")
-        #     self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_dir}")
+        elif model_type == "byt5":
+            self.model = T5ForConditionalGeneration.from_pretrained(f"{model_dir}")
+            self.tokenizer = ByT5Tokenizer.from_pretrained(f"{model_dir}")

         if use_gpu:
             if torch.cuda.is_available():
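load_model gets the matching byt5 branch, so a checkpoint saved by training_epoch_end can be reloaded for inference. A sketch, assuming load_model exposes the use_gpu flag visible in the context above, with an illustrative directory name following the simplet5-epoch-... save pattern from the earlier hunk:

```python
# Directory name is illustrative; it mimics the save-path format above.
model.load_model(
    model_type="byt5",
    model_dir="outputs/simplet5-epoch-2-train-loss-0.1234",
    use_gpu=False,
)
```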
@@ -459,18 +461,18 @@ def predict(
         ]
         return preds

-    def convert_and_load_onnx_model(self, model_dir: str):
-        """ returns ONNX model """
-        self.onnx_model = export_and_get_onnx_model(model_dir)
-        self.onnx_tokenizer = AutoTokenizer.from_pretrained(model_dir)
-
-    def onnx_predict(self, source_text: str):
-        """ generates prediction from ONNX model """
-        token = self.onnx_tokenizer(source_text, return_tensors="pt")
-        tokens = self.onnx_model.generate(
-            input_ids=token["input_ids"],
-            attention_mask=token["attention_mask"],
-            num_beams=2,
-        )
-        output = self.onnx_tokenizer.decode(tokens.squeeze(), skip_special_tokens=True)
-        return output
+    # def convert_and_load_onnx_model(self, model_dir: str):
+    #     """ returns ONNX model """
+    #     self.onnx_model = export_and_get_onnx_model(model_dir)
+    #     self.onnx_tokenizer = AutoTokenizer.from_pretrained(model_dir)
+
+    # def onnx_predict(self, source_text: str):
+    #     """ generates prediction from ONNX model """
+    #     token = self.onnx_tokenizer(source_text, return_tensors="pt")
+    #     tokens = self.onnx_model.generate(
+    #         input_ids=token["input_ids"],
+    #         attention_mask=token["attention_mask"],
+    #         num_beams=2,
+    #     )
+    #     output = self.onnx_tokenizer.decode(tokens.squeeze(), skip_special_tokens=True)
+    #     return output
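The fastT5-based ONNX helpers are commented out rather than deleted, matching the disabled import at the top of the file. Anyone who still needs ONNX inference can call fastT5 directly; this sketch mirrors the commented-out methods and assumes fastT5 is installed, with "t5-small" standing in for the model_dir argument the methods took:

```python
# Mirrors the disabled helpers; assumes `pip install fastt5`.
# "t5-small" is a stand-in for the removed methods' model_dir argument.
from fastT5 import export_and_get_onnx_model
from transformers import AutoTokenizer

onnx_model = export_and_get_onnx_model("t5-small")
onnx_tokenizer = AutoTokenizer.from_pretrained("t5-small")

token = onnx_tokenizer("summarize: ...", return_tensors="pt")
tokens = onnx_model.generate(
    input_ids=token["input_ids"],
    attention_mask=token["attention_mask"],
    num_beams=2,
)
print(onnx_tokenizer.decode(tokens.squeeze(), skip_special_tokens=True))
```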