Merge pull request #236 from EricFillion/ef/seq-to-seq
2.2.4: Text-to-text and more
EricFillion authored Jun 17, 2021
2 parents 89202b2 + 2a0abe7 commit abbcbfc
Showing 19 changed files with 472 additions and 8 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -28,6 +28,7 @@ Happy Transformer is an package built on top of [Hugging Face's transformer libr
| Question Answering |||
| Next Sentence Prediction || |
| Token Classification || |
| Text-to-Text || |

## Quick Start
```sh
@@ -49,6 +50,7 @@ from happytransformer import HappyWordPrediction
- [Ted Brownlow](https://github.com/ted537) Maintainer

## Tutorials
[Text classification (training)](https://www.vennify.ai/train-text-classification-transformers/)

[Text classification (hate speech detection)](https://youtu.be/jti2sPQYzeQ)

2 changes: 2 additions & 0 deletions docs/index.md
@@ -20,6 +20,7 @@ Happy Transformer is a package built on top of [Hugging Face's transformer libra
| Word Prediction |||
| Token Classification || |
| Next Sentence Prediction || |
| Text-to-Text || |

<span class="fs-8">
[GitHub](https://github.com/EricFillion/happy-transformer){: .btn .mr-30 }
@@ -33,6 +34,7 @@ Happy Transformer is a package built on top of [Hugging Face's transformer libra
[Create a text generation web app. Also learn how to fine-tune GPT-Neo](https://www.udemy.com/course/nlp-text-generation-python-web-app/?couponCode=LAUNCH)

## Free Tutorials
[Text classification (training)](https://www.vennify.ai/train-text-classification-transformers/)

[Text classification (hate speech detection)](https://youtu.be/jti2sPQYzeQ)

6 changes: 3 additions & 3 deletions docs/pages/1-text-generation/2-basic-usage.md
@@ -8,9 +8,9 @@ permalink: /text-generation/usage/

## Text Generation Basic Usage
### generate_text()
The method predict_masks() contains 4 arguments:
The method generate_text() contains 2 arguments:
1. text (string): The text prompt for the model -- it will try to continue the text
2. settings (GENSettings): See this [webpage](/text-generation/settings/) for more information
2. args (GENSettings): See this [webpage](/text-generation/settings/) for more information


Returns:
@@ -22,7 +22,7 @@ An object with a single field called "text"

from happytransformer import HappyGeneration, GENSettings
#--------------------------------------#
happy_gen = HappyGeneration() # default uses distilbert-base-uncased
happy_gen = HappyGeneration() # default uses gpt2
args = GENSettings(max_length=15)
result = happy_gen.generate_text("artificial intelligence is ", args=args)
print(result) # GenerationResult(text='\xa0a new field of research that has been gaining momentum in recent years.')
2 changes: 2 additions & 0 deletions docs/pages/2-text-classification/1-instantiation.md
@@ -40,6 +40,8 @@ number of labels, so if you use these models you can set num_labels freely

## Tutorials

[Text classification (training)](https://www.vennify.ai/train-text-classification-transformers/)

[Text classification (hate speech detection)](https://youtu.be/jti2sPQYzeQ)

[Text classification (sentiment analysis)](https://youtu.be/Ew72EAgM7FM)
2 changes: 1 addition & 1 deletion docs/pages/6-next-sentence/1-instantiation.md
@@ -17,7 +17,7 @@ Initialize a HappyNextSentence() object to next sentence prediction
"bert-base-uncased" and "bert-large-uncased" that have not been finetuned


#### Example 4.0:
#### Example 6.0:
```python
from happytransformer import HappyNextSentence
# --------------------------------------#
25 changes: 25 additions & 0 deletions docs/pages/7-text-to-text/1-instantiation.md
@@ -0,0 +1,25 @@
---
title: Text-to-Text
nav_order: 12
layout: page
permalink: /text-to-text/
has_children: true
---

## Text-to-Text

Initialize a HappyTextToText() object to perform text-to-text generation

**Initialization Arguments:**
1. model_type (string): Specify the model type in all caps, such as "T5" or "BART"
2. model_name (string): The following URLs contain potential models:
[standard models](https://huggingface.co/models?pipeline_tag=text2text-generation) and [translation models](https://huggingface.co/models?pipeline_tag=translation)


#### Example 7.0:
```python
from happytransformer import HappyTextToText
# --------------------------------------#
happy_tt = HappyTextToText("T5", "t5-small") # default

```
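Other sequence-to-sequence architectures can be instantiated the same way. Below is a hedged sketch (not part of the original example) using a BART checkpoint; "facebook/bart-base" is assumed here to be an available text2text model on the Hugging Face Hub:

```python
from happytransformer import HappyTextToText
# --------------------------------------#
# Illustrative sketch: load a BART checkpoint instead of the default T5 model.
# "facebook/bart-base" is assumed to be a valid seq-to-seq checkpoint on the Hub.
happy_tt_bart = HappyTextToText("BART", "facebook/bart-base")
```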
32 changes: 32 additions & 0 deletions docs/pages/7-text-to-text/2-basic-usage.md
@@ -0,0 +1,32 @@
---
title: Usage
parent: Text-to-Text
nav_order: 1
layout: page
permalink: /text-to-text/usage/
---

## Text-to-Text Basic Usage
### generate_text()
The method generate_text() contains 2 arguments:
1. text (string): The text prompt for the model.
2. args (TTSettings): See this [webpage](/text-to-text/settings/) for more information


Returns:
An object with a single field called "text"


#### Example 7.1:
```python

from happytransformer import HappyTextToText, TTSettings
#--------------------------------------#
happy_tt = HappyTextToText() # default uses t5-small
top_p_sampling_settings = TTSettings(do_sample=True, top_k=0, top_p=0.8, temperature=0.7, min_length=20, max_length=20, early_stopping=True)
result = happy_tt.generate_text("translate English to French: nlp is a field of artificial intelligence", args=top_p_sampling_settings)
print(result) # TextToTextResult(text="nlp est un domaine de l'intelligence artificielle...")
print(result.text) # nlp est un domaine de l’intelligence artificielle. n

```

79 changes: 79 additions & 0 deletions docs/pages/7-text-to-text/3-settings.md
@@ -0,0 +1,79 @@
---
title: Settings
parent: Text-to-Text
nav_order: 2
layout: page
permalink: /text-to-text/settings/
---

# Text-to-Text Settings

By default, a text generation algorithm called "greedy" is used.
This algorithm simply picks the most likely next word.


A class called TTSettings() is used to control which algorithm is used and its settings.
It is passed to the "args" parameter for HappyTextToText.generate_text().

```python
from happytransformer import TTSettings
```

TTSettings() contains the fields shown in Table 7.0

#### Table 7.0:

| Parameter |Default| Definition |
|----------------------|-------|----------------------------------------------------------------------------|
| min_length | 10 | Minimum number of generated tokens |
| max_length | 50 | Maximum number of generated tokens |
| do_sample | False | When True, picks words based on their conditional probability |
| early_stopping | False | When True, generation finishes if the EOS token is reached |
| num_beams            | 1     | Number of beams used for beam search                                        |
| temperature          | 1.0   | How sensitive the algorithm is to selecting low probability options         |
| top_k                | 50    | Number of highest-probability tokens considered when sampling               |
| top_p                | 1.0   | Only the smallest set of tokens whose cumulative probability exceeds top_p is considered when sampling |
| no_repeat_ngram_size | 0     | The size of an n-gram that cannot occur more than once (0 = no restriction) |


#### Examples 7.2:

```python
from happytransformer import HappyTextToText, TTSettings

#---------------------------------------------------
happy_tt = HappyTextToText("T5", "t5-small")

greedy_settings = TTSettings(no_repeat_ngram_size=2, max_length=20)
output_greedy = happy_tt.generate_text(
"translate English to French: nlp is a field of artificial intelligence ",
args=greedy_settings)

beam_settings = TTSettings(num_beams=5, max_length=20)
output_beam_search = happy_tt.generate_text(
"translate English to French: nlp is a field of artificial intelligence ",
args=beam_settings)

generic_sampling_settings = TTSettings(do_sample=True, top_k=0, temperature=0.7, max_length=20)
output_generic_sampling = happy_tt.generate_text(
"translate English to French: nlp is a field of artificial intelligence ",
args=generic_sampling_settings)

top_k_sampling_settings = TTSettings(do_sample=True, top_k=50, temperature=0.7, max_length=20)
output_top_k_sampling = happy_tt.generate_text(
"translate English to French: nlp is a field of artificial intelligence ",
args=top_k_sampling_settings)

top_p_sampling_settings = TTSettings(do_sample=True, top_k=0, top_p=0.8, temperature=0.7, max_length=20)
output_top_p_sampling = happy_tt.generate_text(
"translate English to French: nlp is a field of artificial intelligence ",
args=top_p_sampling_settings)

print("Greedy:", output_greedy.text) # Greedy: nlp est un domaine de l'intelligence artificielle
print("Beam:", output_beam_search.text) # Beam: nlp est un domaine de l'intelligence artificielle
print("Generic Sampling:", output_generic_sampling.text) # Generic Sampling: nlp est un champ d'intelligence artificielle
print("Top-k Sampling:", output_top_k_sampling.text) # Top-k Sampling: nlp est un domaine de l’intelligence artificielle
print("Top-p Sampling:", output_top_p_sampling.text) # Top-p Sampling: nlp est un domaine de l'intelligence artificielle

```

57 changes: 57 additions & 0 deletions examples/text-to-text/doc_examples.py
@@ -0,0 +1,57 @@
from happytransformer import HappyTextToText, TTSettings

def example_7_0():
# --------------------------------------#
happy_tt = HappyTextToText("T5", "t5-small") # default

def example_7_1():
# --------------------------------------#
happy_tt = HappyTextToText() # default uses t5-small
top_p_sampling_settings = TTSettings(do_sample=True, top_k=0, top_p=0.8, temperature=0.7, min_length=20, max_length=20, early_stopping=True)
result = happy_tt.generate_text("translate English to French: nlp is a field of artificial intelligence", args=top_p_sampling_settings)
print(result) # nlp est un domaine de l’intelligence artificielle. n
print(result.text) # nlp est un domaine de l’intelligence artificielle. n


def example_7_2():
happy_tt = HappyTextToText("T5", "t5-small")

greedy_settings = TTSettings(no_repeat_ngram_size=2, max_length=20)
output_greedy = happy_tt.generate_text(
"translate English to French: nlp is a field of artificial intelligence ",
args=greedy_settings)

beam_settings = TTSettings(num_beams=5, max_length=20)
output_beam_search = happy_tt.generate_text(
"translate English to French: nlp is a field of artificial intelligence ",
args=beam_settings)

generic_sampling_settings = TTSettings(do_sample=True, top_k=0, temperature=0.7, max_length=20)
output_generic_sampling = happy_tt.generate_text(
"translate English to French: nlp is a field of artificial intelligence ",
args=generic_sampling_settings)

top_k_sampling_settings = TTSettings(do_sample=True, top_k=50, temperature=0.7, max_length=20)
output_top_k_sampling = happy_tt.generate_text(
"translate English to French: nlp is a field of artificial intelligence ",
args=top_k_sampling_settings)

top_p_sampling_settings = TTSettings(do_sample=True, top_k=0, top_p=0.8, temperature=0.7, max_length=20)
output_top_p_sampling = happy_tt.generate_text(
"translate English to French: nlp is a field of artificial intelligence ",
args=top_p_sampling_settings)

print("Greedy:", output_greedy.text) # Greedy: nlp est un domaine de l'intelligence artificielle
print("Beam:", output_beam_search.text) # Beam: nlp est un domaine de l'intelligence artificielle
print("Generic Sampling:", output_generic_sampling.text) # Generic Sampling: nlp est un champ d'intelligence artificielle
print("Top-k Sampling:", output_top_k_sampling.text) # Top-k Sampling: nlp est un domaine de l’intelligence artificielle
print("Top-p Sampling:", output_top_p_sampling.text) # Top-p Sampling: nlp est un domaine de l'intelligence artificielle


def main():
# example_7_0()
example_7_1()
# example_7_2()

if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions happytransformer/__init__.py
@@ -4,6 +4,7 @@
from happytransformer.happy_next_sentence import HappyNextSentence
from happytransformer.happy_token_classification import HappyTokenClassification
from happytransformer.happy_generation import HappyGeneration, GENSettings
from happytransformer.happy_text_to_text import HappyTextToText, TTSettings

from happytransformer.gen.default_args import ARGS_GEN_TRAIN, ARGS_GEN_EVAl
from happytransformer.qa.default_args import ARGS_QA_TRAIN, ARGS_QA_EVAl, ARGS_QA_TEST
10 changes: 8 additions & 2 deletions happytransformer/adaptors/berts.py
@@ -10,9 +10,15 @@ def preprocess_mask_text(text: str) -> str:

@staticmethod
def postprocess_mask_prediction_token(text):
return text[1:] if text[0] == "Ġ" else text
if text:
return text[1:] if text[0] == "Ġ" else text
else:
return ""

class AlbertAdaptor(Adaptor):
@staticmethod
def postprocess_mask_prediction_token(text):
return text[1:] if text[0] == "▁" else text
if text:
return text[1:] if text[0] == "▁" else text
else:
return ""
97 changes: 97 additions & 0 deletions happytransformer/happy_text_to_text.py
@@ -0,0 +1,97 @@
"""
Contains a class called HappyTextToText which performs text to text generation
"""
from dataclasses import dataclass

from transformers import Text2TextGenerationPipeline, AutoModelForSeq2SeqLM

from happytransformer.happy_transformer import HappyTransformer
from happytransformer.tt.trainer import TTTrainer
from happytransformer.cuda_detect import detect_cuda_device_number
from happytransformer.adaptors import get_adaptor
from happytransformer.tt.trainer import TTTrainArgs, TTEvalArgs, TTTestArgs


@dataclass
class TextToTextResult:
text: str

@dataclass
class TTSettings:
min_length: int = 10
max_length: int = 50
do_sample: bool = False
early_stopping: bool = False
num_beams: int = 1
temperature: float = 1
top_k: int = 50
no_repeat_ngram_size: int = 0
top_p: float = 1


class HappyTextToText(HappyTransformer):
"""
A user facing class for text to text generation
"""
def __init__(self, model_type: str = "T5", model_name: str = "t5-small", load_path: str = ""):

self.adaptor = get_adaptor(model_type)

if load_path != "":
model = AutoModelForSeq2SeqLM.from_pretrained(load_path)
else:
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)


super().__init__(model_type, model_name, model)

device_number = detect_cuda_device_number()

self._pipeline = Text2TextGenerationPipeline(model=self.model,
tokenizer=self.tokenizer, device=device_number)

self._trainer = TTTrainer(self.model, model_type, self.tokenizer, self._device, self.logger)


def __assert_default_text_is_val(self, text):
"""
Ensures the text input is valid.
Raises a Value Error if the text input is not valid.
:param text: The value the user inputs for the "text" parameter
"""

if not isinstance(text, str):
raise ValueError("The text input must be a string")
if not text:
raise ValueError("The text input must have at least one character")


def generate_text(self, text: str,
args: TTSettings = TTSettings()) -> TextToTextResult:
"""
:param text: starting text that the model uses as a prompt to continue it.
:param args: A TTSettings object
:return: A TextToTextResult() object
"""
self.__assert_default_text_is_val(text)

output = self._pipeline(text, min_length=args.min_length,
max_length=args.max_length,
do_sample=args.do_sample,
early_stopping=args.early_stopping,
num_beams=args.num_beams,
temperature=args.temperature,
top_k=args.top_k,
no_repeat_ngram_size=args.no_repeat_ngram_size,
top_p=args.top_p,
)
return TextToTextResult(text=output[0]['generated_text'])

def train(self, input_filepath, args=TTTrainArgs):
raise NotImplementedError("train() is currently not available")

def eval(self, input_filepath, args=TTEvalArgs):
raise NotImplementedError("eval() is currently not available")

def test(self, input_filepath, args=TTTestArgs):
raise NotImplementedError("test() is currently not available")