Merge pull request #254 from EricFillion/ef/t-t-finetuning
Text-to-Text Finetuning
Showing 21 changed files with 517 additions and 76 deletions.
@@ -0,0 +1,5 @@
input,target
grammar: This sentences has bad grammar's error and overall is quality low?,This sentence has bad grammar errors and overall is low quality.
grammar: This sentence: had bad grammar?,This sentence has bad grammar.
grammar: not correct sentence?,This sentence is not correct.
grammar: You're laptop is broken,Your laptop is broken
@@ -0,0 +1,2 @@
input,target
"translate English to Spanish: Hello, I like to eat apples.","Hola, me gusta comer manzanas."
@@ -0,0 +1,94 @@
---
title: Finetuning
parent: Text-to-Text
nav_order: 3
layout: page
permalink: /text-to-text/finetuning/
---

## Text-to-Text Finetuning

HappyTextToText provides two methods for training:
- train(): fine-tunes the model to convert one standalone piece of text into another
- eval(): determines how well the model performs on a given dataset

### train()

inputs:
1. input_filepath (string): a file path to a CSV file as shown in Table 7.1
2. args (TTTrainArgs): a dataclass with the fields shown in Table 7.2

#### Table 7.1

Contains two columns with the following header values: input and target

| input                        | target              |
|------------------------------|---------------------|
| grammar: I has poor grammars | I have poor grammar |
| grammar: I wants too plays   | I want to play      |

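A CSV file in this format can be produced with Python's standard csv module. The sketch below is illustrative: the file name and rows are placeholders, but the header must be exactly `input,target` as train() and eval() expect.

```python
import csv

# Hypothetical rows; any (input, target) text pairs work.
rows = [
    ("grammar: I has poor grammars", "I have poor grammar"),
    ("grammar: I wants too plays", "I want to play"),
]

# newline="" prevents the csv module from inserting blank lines on Windows.
with open("train.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["input", "target"])  # required header
    writer.writerows(rows)
```

csv.writer also handles quoting automatically when a field contains a comma, as in the translation example above.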
#### Table 7.2

| Parameter               | Default |
|-------------------------|---------|
| learning_rate           | 5e-5    |
| num_train_epochs        | 3       |
| batch_size              | 1       |
| weight_decay            | 0       |
| adam_beta1              | 0.9     |
| adam_beta2              | 0.999   |
| adam_epsilon            | 1e-8    |
| max_grad_norm           | 1.0     |
| preprocessing_processes | 1       |
| max_input_length        | 1024    |
| max_output_length       | 1024    |

Information about the learning parameters can be found [here](/learning-parameters/).

- preprocessing_processes: the number of processes used for preprocessing. We recommend 1-4.
- max_input_length: the maximum number of tokens for the input; the rest are truncated.
- max_output_length: the maximum number of tokens for the output; the rest are truncated.

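Note that both length caps count tokens, not characters. As a rough sketch of the truncation behavior, assuming a naive whitespace "tokenizer" (the model actually uses its own subword tokenizer, so real token counts differ):

```python
def truncate(text, max_length):
    # Illustrative only: splits on whitespace, whereas the model
    # truncates after applying its subword tokenizer.
    tokens = text.split()
    return " ".join(tokens[:max_length])

print(truncate("one two three four five", 3))  # one two three
```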
#### Example 7.3

```python
from happytransformer import HappyTextToText, TTTrainArgs

happy_tt = HappyTextToText()
args = TTTrainArgs(num_train_epochs=1)
happy_tt.train("../../data/tt/train-eval-grammar.csv", args=args)
```

### eval()

Input:
1. input_filepath (string): a file path to a CSV file with the same format as the training data described in Table 7.1
2. args (TTEvalArgs): a dataclass with the fields shown in Table 7.3

#### Table 7.3

| Parameter               | Default |
|-------------------------|---------|
| preprocessing_processes | 1       |
| max_input_length        | 1024    |
| max_output_length       | 1024    |

See Table 7.2 for more information about these parameters.

Output: an object with a single field called "loss"

#### Example 7.4

```python
from happytransformer import HappyTextToText, TTEvalArgs

happy_tt = HappyTextToText()
args = TTEvalArgs(preprocessing_processes=1)
result = happy_tt.eval("../../data/tt/train-eval-grammar.csv", args=args)
print(type(result))  # <class 'happytransformer.happy_trainer.EvalResult'>
print(result)  # EvalResult(loss=3.2277376651763916)
print(result.loss)  # 3.2277376651763916
```
@@ -0,0 +1,48 @@
import csv

from happytransformer.happy_text_to_text import HappyTextToText
from datasets import load_dataset


def main():
    happy_tt = HappyTextToText("T5", "t5-base")
    input_text = "grammar: This sentences had bad grammars and spelling. "
    before_text = happy_tt.generate_text(input_text).text

    # There's no training split, just validation and test, so we use
    # validation for training and test for evaluation.
    # 755 cases, but each case has 4 corrections, so there are really 3020.
    train_dataset = load_dataset("jfleg", split='validation[:]')

    # 748 cases, but again, each case has 4 corrections, so there are really 2992.
    eval_dataset = load_dataset("jfleg", split='test[:]')

    generate_csv("train.csv", train_dataset)
    generate_csv("eval.csv", eval_dataset)

    before_loss = happy_tt.eval("eval.csv").loss

    happy_tt.train("train.csv")

    after_text = happy_tt.generate_text(input_text).text
    after_loss = happy_tt.eval("eval.csv").loss

    print("before loss:", before_loss)
    print("after loss:", after_loss)
    print("------------------------------------")

    print("input text:", input_text)
    print("before text:", before_text)
    print("after text:", after_text)


def generate_csv(csv_path, dataset):
    with open(csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["input", "target"])
        for case in dataset:
            input_text = "grammar: " + case["sentence"]
            for correction in case["corrections"]:
                # A few of the cases have None values; skip them.
                if input_text and correction:
                    writer.writerow([input_text, correction])


if __name__ == "__main__":
    main()
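The per-case flattening in generate_csv() (one CSV row per correction, skipping empty or None values) can be sketched on in-memory data. The mock cases below are made up; only the dictionary keys "sentence" and "corrections" mirror the JFLEG layout used above:

```python
def flatten(cases):
    # One (input, target) pair per non-None correction.
    rows = []
    for case in cases:
        input_text = "grammar: " + case["sentence"]
        for correction in case["corrections"]:
            if input_text and correction:
                rows.append((input_text, correction))
    return rows

# Hypothetical cases for illustration.
mock = [
    {"sentence": "I has a dog", "corrections": ["I have a dog", None]},
    {"sentence": "She go home", "corrections": ["She goes home"]},
]
print(flatten(mock))
```

This is why the script's row counts are roughly four times the case counts: each JFLEG case carries multiple reference corrections.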
@@ -0,0 +1,43 @@
import csv

from happytransformer.happy_text_to_text import HappyTextToText, TTTrainArgs
from datasets import load_dataset


def main():
    happy_tt = HappyTextToText("T5", "t5-base")
    # Source: https://en.wikipedia.org/wiki/Transformer_(machine_learning_model)
    text = "A transformer is a deep learning model that adopts the mechanism of attention, differentially weighing the significance of each part of the input data. It is used primarily in the field of natural language processing (NLP)"
    before_text = happy_tt.generate_text("summarize: " + text).text

    train_dataset = load_dataset("xsum", split='train[0:1999]')
    eval_dataset = load_dataset("xsum", split='validation[0:499]')

    generate_csv("train.csv", train_dataset)
    generate_csv("eval.csv", eval_dataset)

    before_result = happy_tt.eval("eval.csv")

    args = TTTrainArgs(max_input_length=1024, max_output_length=128)

    happy_tt.train("train.csv", args=args)
    after_text = happy_tt.generate_text("summarize: " + text).text
    after_result = happy_tt.eval("eval.csv")

    print("before result:", before_result)
    print("after result:", after_result)

    print("before text:", before_text)
    print("after text:", after_text)


def generate_csv(csv_path, dataset):
    with open(csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["input", "target"])
        for case in dataset:
            long_text = "summarize: " + case["document"]
            short_text = case["summary"]
            writer.writerow([long_text, short_text])


if __name__ == "__main__":
    main()
@@ -0,0 +1,45 @@
import csv

from happytransformer import TTEvalArgs
from happytransformer.happy_text_to_text import HappyTextToText, TTTrainArgs
from datasets import load_dataset


def main():
    happy_tt = HappyTextToText("MT5", "google/mt5-base")
    text = "Hello, I like to eat apples."
    # Google Translate's translation: سلام من عاشق خوردن سیب هستم.
    before_text = happy_tt.generate_text("translate English to Persian: " + text).text

    train_dataset = load_dataset("persiannlp/parsinlu_translation_en_fa", split='train[0:3999]')
    eval_dataset = load_dataset("persiannlp/parsinlu_translation_en_fa", split='validation[0:399]')

    generate_csv("train.csv", train_dataset)
    generate_csv("eval.csv", eval_dataset)

    eval_args = TTEvalArgs(max_input_length=1024, max_output_length=1024)
    before_loss = happy_tt.eval("eval.csv", args=eval_args).loss

    train_args = TTTrainArgs(num_train_epochs=1, max_input_length=1024, max_output_length=1024)
    happy_tt.train("train.csv", args=train_args)

    after_text = happy_tt.generate_text("translate English to Persian: " + text).text

    after_loss = happy_tt.eval("eval.csv", args=eval_args).loss

    print("before loss:", before_loss)
    print("after loss:", after_loss)

    print("before text:", before_text)
    print("after text:", after_text)


def generate_csv(csv_path, dataset):
    with open(csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["input", "target"])
        for case in dataset:
            english_text = "translate English to Persian: " + case["source"]
            persian_text = case["targets"][0]
            writer.writerow([english_text, persian_text])


if __name__ == "__main__":
    main()