Merge pull request #254 from EricFillion/ef/t-t-finetuning
Text-to-Text Finetuning
EricFillion authored Aug 15, 2021
2 parents d48bf44 + aeabea7 commit da9929b
Showing 21 changed files with 517 additions and 76 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -17,6 +17,9 @@ examples/question_answering/eval.txt
examples/generation/train.txt
examples/generation/eval.txt

examples/text-to-text/train.csv
examples/text-to-text/eval.csv

main.py
tests/model/

4 changes: 2 additions & 2 deletions README.md
@@ -10,7 +10,7 @@
New Course: Create a text generation web app. Also learn how to fine-tune GPT-Neo [link](https://www.udemy.com/course/nlp-text-generation-python-web-app/?couponCode=LAUNCH)


Join our brand new Discord server: [![Support Server](https://img.shields.io/discord/839263772312862740.svg?label=Discord&logo=Discord&colorB=7289da&style=?style=flat-square&logo=appveyor)](https://discord.gg/psVwe3wfTb)
Join our Discord server: [![Support Server](https://img.shields.io/discord/839263772312862740.svg?label=Discord&logo=Discord&colorB=7289da&style=?style=flat-square&logo=appveyor)](https://discord.gg/psVwe3wfTb)



@@ -26,9 +26,9 @@ Happy Transformer is a package built on top of Hugging Face's transformer library
| Text Classification |||
| Word Prediction |||
| Question Answering |||
| Text-to-Text |||
| Next Sentence Prediction || |
| Token Classification || |
| Text-to-Text || |

## Quick Start
```sh
5 changes: 5 additions & 0 deletions data/tt/train-eval-grammar.csv
@@ -0,0 +1,5 @@
input,target
grammar: This sentences has bad grammar's error and overall is quality low?,This sentence has bad grammar errors and overall is low quality.
grammar: This sentence: had bad grammar?,This sentence has bad grammar.
grammar: not correct sentence?,This sentence is not correct.
grammar: You're laptop is broken,Your laptop is broken
2 changes: 2 additions & 0 deletions data/tt/train-eval-translate.csv
@@ -0,0 +1,2 @@
input,target
"translate English to Spanish: Hello, I like to eat apples.","Hola, me gusta comer manzanas."
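Note that the fields in this row contain commas, so they must be quoted. Python's standard csv module handles the quoting automatically; a minimal sketch (not part of this commit):

```python
import csv
import io

# Write a row whose fields contain commas; csv.writer adds the quotes for us.
buf = io.StringIO()
writer = csv.writer(buf)
writer.writerow(["input", "target"])
writer.writerow(["translate English to Spanish: Hello, I like to eat apples.",
                 "Hola, me gusta comer manzanas."])

# Reading it back recovers the original fields intact.
rows = list(csv.reader(io.StringIO(buf.getvalue())))
print(rows[1][1])  # Hola, me gusta comer manzanas.
```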
2 changes: 1 addition & 1 deletion docs/index.md
@@ -18,9 +18,9 @@ Happy Transformer is a package built on top of Hugging Face's transformer library
| Text Classification |||
| Question Answering |||
| Word Prediction |||
| Text-to-Text |||
| Token Classification || |
| Next Sentence Prediction || |
| Text-to-Text || |

<span class="fs-8">
[GitHub](https://github.com/EricFillion/happy-transformer){: .btn .mr-30 }
2 changes: 1 addition & 1 deletion docs/pages/1-text-generation/4-finetuning.md
@@ -8,7 +8,7 @@ permalink: /text-generation/finetuning/

## Text Generation Finetuning

TextGeneration contains two methods for training
HappyTextGeneration contains two methods for training
- train(): fine-tune the model to understand a body of text better
- eval(): determine how well the model performs

8 changes: 8 additions & 0 deletions docs/pages/3-news.md
@@ -5,6 +5,14 @@ permalink: /news/
nav_order: 4

---
## News:
### August 14th, 2021
**Introducing Version 2.3.0!**

New Features:
- Text-to-text fine-tuning is now available!


### May 4th, 2021
**Introducing Version 2.2.0!**
94 changes: 94 additions & 0 deletions docs/pages/7-text-to-text/4-finetuning.md
@@ -0,0 +1,94 @@
---
title: Finetuning
parent: Text-to-Text
nav_order: 3
layout: page
permalink: /text-to-text/finetuning/
---

## Text-to-Text Finetuning

HappyTextToText contains two methods for training:
- train(): fine-tunes the model to convert one standalone piece of text into another
- eval(): determines how well the model performs

### train()

Inputs:
1. input_filepath (string): a path to a CSV file in the format shown in Table 7.1
2. args (TTTrainArgs): a dataclass with the fields shown in Table 7.2.


#### Table 7.1
Contains two columns with the following header values: input and target

| input |target |
|-------------------------------|---------------------|
| grammar: I has poor grammars | I have poor grammar |
| grammar: I wants too plays | I want to play |
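A training file in the Table 7.1 format can be produced and checked with Python's standard csv module; a minimal sketch (the temporary file path is arbitrary):

```python
import csv
import tempfile

pairs = [
    ("grammar: I has poor grammars", "I have poor grammar"),
    ("grammar: I wants too plays", "I want to play"),
]

# Write the two-column CSV with the required "input" and "target" header.
with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False, newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["input", "target"])
    writer.writerows(pairs)
    csv_path = f.name

# Read it back as dictionaries keyed by the header row.
with open(csv_path, newline="") as f:
    rows = list(csv.DictReader(f))

print(rows[0]["target"])  # I have poor grammar
```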


#### Table 7.2

| Parameter |Default|
|-------------------------------|-------|
| learning_rate | 5e-5 |
| num_train_epochs | 3 |
| batch_size | 1 |
| weight_decay | 0 |
| adam_beta1 | 0.9 |
| adam_beta2 | 0.999 |
| adam_epsilon | 1e-8 |
| max_grad_norm | 1.0 |
| preprocessing_processes | 1 |
| max_input_length | 1024 |
| max_output_length | 1024 |


Information about the learning parameters can be found [here](/learning-parameters/)


- preprocessing_processes: the number of processes used for preprocessing. We recommend 1-4.
- max_input_length: the maximum number of tokens for the input; longer inputs are truncated.
- max_output_length: the maximum number of tokens for the output; longer outputs are truncated.
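For reference, the defaults in Table 7.2 can be mirrored as a plain dataclass. This is an illustrative sketch only; the real TTTrainArgs lives in happytransformer and may differ in implementation:

```python
from dataclasses import dataclass

# Hypothetical mirror of the TTTrainArgs defaults from Table 7.2,
# for illustration only (not the library's actual class).
@dataclass
class TTTrainArgsSketch:
    learning_rate: float = 5e-5
    num_train_epochs: int = 3
    batch_size: int = 1
    weight_decay: float = 0.0
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0
    preprocessing_processes: int = 1
    max_input_length: int = 1024
    max_output_length: int = 1024

# Override only what differs from the defaults, as in Example 7.3 below.
args = TTTrainArgsSketch(num_train_epochs=1)
print(args.num_train_epochs, args.learning_rate)  # 1 5e-05
```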


#### Example 7.3:
```python
from happytransformer import HappyTextToText, TTTrainArgs
# --------------------------------------#
happy_tt = HappyTextToText()
args = TTTrainArgs(num_train_epochs=1)
happy_tt.train("../../data/tt/train-eval-grammar.csv", args=args)
```

### eval()
Inputs:
1. input_filepath (string): a path to a CSV file with the same format as the training data described in Table 7.1
2. args (TTEvalArgs): a dataclass with the fields shown in Table 7.3.

#### Table 7.3

| Parameter |Default|
|-------------------------------|-------|
| preprocessing_processes | 1 |
| max_input_length | 1024 |
| max_output_length | 1024 |

See the train() section above for descriptions of these parameters.


Output: An object with a single field called "loss"

#### Example 7.4
```python
from happytransformer import HappyTextToText, TTEvalArgs
# --------------------------------------#
happy_tt = HappyTextToText()
args = TTEvalArgs(preprocessing_processes=1)
result = happy_tt.eval("../../data/tt/train-eval-grammar.csv", args=args)
print(type(result)) # <class 'happytransformer.happy_trainer.EvalResult'>
print(result) # EvalResult(loss=3.2277376651763916)
print(result.loss) # 3.2277376651763916

```
19 changes: 17 additions & 2 deletions examples/text-to-text/doc_examples.py
@@ -1,4 +1,4 @@
from happytransformer import HappyTextToText, TTSettings
from happytransformer import HappyTextToText, TTSettings, TTTrainArgs, TTEvalArgs

def example_7_0():
# --------------------------------------#
@@ -47,11 +47,26 @@ def example_7_2():
print("Top-k Sampling:", output_top_k_sampling.text) # Top-k Sampling: nlp est un domaine de l’intelligence artificielle
print("Top-p Sampling:", output_top_p_sampling.text) # Top-p Sampling: nlp est un domaine de l'intelligence artificielle

def example_7_3():
happy_tt = HappyTextToText()
args = TTTrainArgs(num_train_epochs=1)
happy_tt.train("../../data/tt/train-eval-grammar.csv", args=args)

def example_7_4():
happy_tt = HappyTextToText()
args = TTEvalArgs(preprocessing_processes=1)
result = happy_tt.eval("../../data/tt/train-eval-grammar.csv", args=args)
print(type(result)) # <class 'happytransformer.happy_trainer.EvalResult'>
print(result) # EvalResult(loss=3.2277376651763916)
print(result.loss) # 3.2277376651763916


def main():
# example_7_0()
example_7_1()
#example_7_1()
# example_7_2()
# example_7_3()
example_7_4()

if __name__ == "__main__":
main()
48 changes: 48 additions & 0 deletions examples/text-to-text/training_grammar.py
@@ -0,0 +1,48 @@
import csv
from happytransformer.happy_text_to_text import HappyTextToText, TTTrainArgs
from datasets import load_dataset

def main():
happy_tt = HappyTextToText("T5", "t5-base")
input_text = "grammar: This sentences had bad grammars and spelling. "
before_text = happy_tt.generate_text(input_text).text

# There's no training split, just eval and test, so we'll use eval for training and test for evaluation.
# 755 cases, but each case has 4 corrections, so there are really 3020.
train_dataset = load_dataset("jfleg", split='validation[:]')

# 748 cases, but again, each case has 4 corrections, so there are really 2992.
eval_dataset = load_dataset("jfleg", split='test[:]')

generate_csv("train.csv", train_dataset)
generate_csv("eval.csv", eval_dataset)

before_loss = happy_tt.eval("eval.csv").loss

happy_tt.train("train.csv")

after_text = happy_tt.generate_text(input_text).text
after_loss = happy_tt.eval("eval.csv").loss

print("before loss:", before_loss)
print("after loss:", after_loss)
print("------------------------------------")

print("input text:", input_text)
print("before text:", before_text)
print("after text:", after_text)


def generate_csv(csv_path, dataset):
with open(csv_path, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(["input", "target"])
for case in dataset:
input_text = "grammar: " + case["sentence"]
for correction in case["corrections"]:
# A few of the cases have empty values, so we skip them.
if input_text and correction:
writer.writerow([input_text, correction])

if __name__ == "__main__":
main()
43 changes: 43 additions & 0 deletions examples/text-to-text/training_summarization.py
@@ -0,0 +1,43 @@
import csv
from happytransformer.happy_text_to_text import HappyTextToText, TTTrainArgs
from datasets import load_dataset

def main():
happy_tt = HappyTextToText("T5", "t5-base")
# source: https://en.wikipedia.org/wiki/Transformer_(machine_learning_model)
text = "A transformer is a deep learning model that adopts the mechanism of attention, differentially weighing the significance of each part of the input data. It is used primarily in the field of natural language processing (NLP)"
before_text = happy_tt.generate_text("summarize: " + text).text

train_dataset = load_dataset("xsum", split='train[0:1999]')
eval_dataset = load_dataset("xsum", split='validation[0:499]')

generate_csv("train.csv", train_dataset)
generate_csv("eval.csv", eval_dataset)

before_result = happy_tt.eval("eval.csv")

args = TTTrainArgs(max_input_length=1024, max_output_length=128)

happy_tt.train("train.csv", args=args)
after_text = happy_tt.generate_text("summarize: " + text).text
after_result = happy_tt.eval("eval.csv")

print("before result:", before_result)
print("after result:", after_result)

print("before text:", before_text)
print("after text:", after_text)

def generate_csv(csv_path, dataset):
with open(csv_path, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(["input", "target"])
for case in dataset:
long_text = "summarize: " + case["document"]
short_text = case["summary"]
writer.writerow([long_text, short_text])



if __name__ == "__main__":
main()
45 changes: 45 additions & 0 deletions examples/text-to-text/training_translation.py
@@ -0,0 +1,45 @@
import csv
from happytransformer.happy_text_to_text import HappyTextToText, TTTrainArgs
from datasets import load_dataset

def main():
happy_tt = HappyTextToText("MT5", "google/mt5-base")
text = "Hello, I like to eat apples."
# Google translate translation: سلام من عاشق خوردن سیب هستم.
before_text = happy_tt.generate_text("translate English to Persian: " + text).text

train_dataset = load_dataset("persiannlp/parsinlu_translation_en_fa", split='train[0:3999]')
eval_dataset = load_dataset("persiannlp/parsinlu_translation_en_fa", split='validation[0:399]')

generate_csv("train.csv", train_dataset)
generate_csv("eval.csv", eval_dataset)

eval_args = TTTrainArgs(max_input_length=1024, max_output_length=1024)
before_loss = happy_tt.eval("eval.csv", args=eval_args).loss

train_args = TTTrainArgs(num_train_epochs=1, max_input_length=1024, max_output_length=1024)
happy_tt.train("train.csv", args=train_args)

after_text = happy_tt.generate_text("translate English to Persian: " + text).text

after_loss = happy_tt.eval("eval.csv", args=eval_args).loss

print("before loss:", before_loss)
print("after loss:", after_loss)

print("before text:", before_text)
print("after text:", after_text)


def generate_csv(csv_path, dataset):
with open(csv_path, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(["input", "target"])
for case in dataset:
english_text = "translate English to Persian: " + case["source"]
persian_text = case["targets"][0]
writer.writerow([english_text, persian_text])


if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion happytransformer/__init__.py
@@ -11,11 +11,12 @@
from happytransformer.tc.default_args import ARGS_TC_TRAIN, ARGS_TC_EVAL, ARGS_TC_TEST
from happytransformer.wp.default_args import ARGS_WP_TRAIN, ARGS_WP_EVAl


from happytransformer.gen.trainer import GENTrainArgs, GENEvalArgs
from happytransformer.qa.trainer import QATestArgs, QAEvalArgs, QATrainArgs
from happytransformer.tc.trainer import TCTrainArgs, TCEvalArgs, TCTestArgs
from happytransformer.wp.trainer import WPTrainArgs, WPEvalArgs
from happytransformer.tt.trainer import TTTrainArgs, TTEvalArgs


from happytransformer.happy_generation import (
HappyGeneration, GENSettings
1 change: 0 additions & 1 deletion happytransformer/gen/trainer.py
@@ -102,7 +102,6 @@ def eval(self, input_filepath, dataclass_args: GENEvalArgs):
self.logger.info("Evaluating...")

result = self._run_eval(tokenized_dataset['eval'], default_data_collator, dataclass_args)

return EvalResult(loss=result["eval_loss"])

def test(self, input_filepath, solve, args):