Merge pull request #236 from EricFillion/ef/seq-to-seq
2.2.4: Text-to-text and more
Showing 19 changed files with 472 additions and 8 deletions.
@@ -0,0 +1,25 @@
---
title: Text-to-Text
nav_order: 12
layout: page
permalink: /text-to-text/
has_children: true
---

## Text-to-Text

Initialize a HappyTextToText() object to perform text-to-text generation.

**Initialization Arguments:**
1. model_type (string): Specify the model type in all caps, such as "T5" or "BART".
2. model_name (string): The following pages list potential models:
[standard models](https://huggingface.co/models?pipeline_tag=text2text-generation) and [translation models](https://huggingface.co/models?pipeline_tag=translation)

#### Example 7.0:
```python
from happytransformer import HappyTextToText
# --------------------------------------#
happy_tt = HappyTextToText("T5", "t5-small")  # default
```
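The same constructor works for the other model families listed at those URLs. As a sketch, loading a BART checkpoint might look like the following; the model name "facebook/bart-base" is an assumption taken from the Hugging Face model hub, not from this commit.

```python
from happytransformer import HappyTextToText
# --------------------------------------#
# Assumption: "facebook/bart-base" is one of the text2text-generation models on the hub.
happy_tt_bart = HappyTextToText("BART", "facebook/bart-base")
```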
@@ -0,0 +1,32 @@
---
title: Usage
parent: Text-to-Text
nav_order: 1
layout: page
permalink: /text-to-text/usage/
---

## Text-to-Text Basic Usage
### generate_text()
The method generate_text() takes two arguments:
1. text (string): The text prompt for the model.
2. args (TTSettings): See this [webpage](/text-to-text/settings/) for more information.

Returns:
An object with a single field called "text".

#### Example 7.1:
```python
from happytransformer import HappyTextToText, TTSettings
# --------------------------------------#
happy_tt = HappyTextToText()  # default uses t5-small
top_p_sampling_settings = TTSettings(do_sample=True, top_k=0, top_p=0.8, temperature=0.7, min_length=20, max_length=20, early_stopping=True)
result = happy_tt.generate_text("translate English to French: nlp is a field of artificial intelligence", args=top_p_sampling_settings)
print(result)  # TextToTextResult(text="nlp est un domaine de l'intelligence artificielle...")
print(result.text)  # nlp est un domaine de l'intelligence artificielle. n
```
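The args parameter is optional: in the source added by this commit, it defaults to TTSettings(), so a prompt alone is enough. A minimal sketch of that default (greedy) call:

```python
from happytransformer import HappyTextToText

happy_tt = HappyTextToText("T5", "t5-small")
# No args: generate_text() falls back to TTSettings(), i.e. greedy decoding with default lengths.
result = happy_tt.generate_text("translate English to French: nlp is a field of artificial intelligence")
print(result.text)
```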
@@ -0,0 +1,79 @@
---
title: Settings
parent: Text-to-Text
nav_order: 2
layout: page
permalink: /text-to-text/settings/
---

# Text-to-Text Settings

By default, a text generation algorithm called "greedy" is used.
This algorithm simply picks the most likely next word.
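For illustration only, here is a minimal sketch of what greedy decoding does, written against the underlying transformers library rather than Happy Transformer. The t5-small checkpoint and the 20-token cap are assumptions chosen to match the examples below.

```python
# Greedy decoding sketch: at each step, keep the single most likely next token.
# Requires torch and transformers; uses the same t5-small checkpoint as these docs.
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

input_ids = tokenizer(
    "translate English to French: nlp is a field of artificial intelligence",
    return_tensors="pt",
).input_ids
decoder_ids = torch.tensor([[model.config.decoder_start_token_id]])

with torch.no_grad():
    for _ in range(20):  # cap at 20 generated tokens
        logits = model(input_ids=input_ids, decoder_input_ids=decoder_ids).logits
        next_id = logits[0, -1].argmax()  # greedy choice: highest-probability token
        decoder_ids = torch.cat([decoder_ids, next_id.view(1, 1)], dim=1)
        if next_id.item() == model.config.eos_token_id:
            break

print(tokenizer.decode(decoder_ids[0], skip_special_tokens=True))
```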
A class called TTSettings() is used to control which algorithm is used and its settings.
It is passed to the "args" parameter of HappyTextToText.generate_text().

```python
from happytransformer import TTSettings
```

TTSettings() contains the fields shown in Table 7.0.

#### Table 7.0:

| Parameter            | Default | Definition                                                                      |
|----------------------|---------|---------------------------------------------------------------------------------|
| min_length           | 10      | Minimum number of generated tokens                                              |
| max_length           | 50      | Maximum number of generated tokens                                              |
| do_sample            | False   | When True, picks words based on their conditional probability                   |
| early_stopping       | False   | When True, generation finishes if the EOS token is reached                      |
| num_beams            | 1       | Number of beams used for beam search (1 disables beam search)                   |
| temperature          | 1.0     | How sensitive the algorithm is to selecting low-probability options             |
| top_k                | 50      | How many potential answers are considered when performing sampling              |
| top_p                | 1.0     | Sampling is restricted to the smallest set of tokens whose probabilities add up to top_p |
| no_repeat_ngram_size | 0       | The size of an n-gram that cannot occur more than once (0 = no restriction)     |
#### Examples 7.2:

```python
from happytransformer import HappyTextToText, TTSettings

# ---------------------------------------------------
happy_tt = HappyTextToText("T5", "t5-small")

greedy_settings = TTSettings(no_repeat_ngram_size=2, max_length=20)
output_greedy = happy_tt.generate_text(
    "translate English to French: nlp is a field of artificial intelligence ",
    args=greedy_settings)

beam_settings = TTSettings(num_beams=5, max_length=20)
output_beam_search = happy_tt.generate_text(
    "translate English to French: nlp is a field of artificial intelligence ",
    args=beam_settings)

generic_sampling_settings = TTSettings(do_sample=True, top_k=0, temperature=0.7, max_length=20)
output_generic_sampling = happy_tt.generate_text(
    "translate English to French: nlp is a field of artificial intelligence ",
    args=generic_sampling_settings)

top_k_sampling_settings = TTSettings(do_sample=True, top_k=50, temperature=0.7, max_length=20)
output_top_k_sampling = happy_tt.generate_text(
    "translate English to French: nlp is a field of artificial intelligence ",
    args=top_k_sampling_settings)

top_p_sampling_settings = TTSettings(do_sample=True, top_k=0, top_p=0.8, temperature=0.7, max_length=20)
output_top_p_sampling = happy_tt.generate_text(
    "translate English to French: nlp is a field of artificial intelligence ",
    args=top_p_sampling_settings)

print("Greedy:", output_greedy.text)  # Greedy: nlp est un domaine de l'intelligence artificielle
print("Beam:", output_beam_search.text)  # Beam: nlp est un domaine de l'intelligence artificielle
print("Generic Sampling:", output_generic_sampling.text)  # Generic Sampling: nlp est un champ d'intelligence artificielle
print("Top-k Sampling:", output_top_k_sampling.text)  # Top-k Sampling: nlp est un domaine de l'intelligence artificielle
print("Top-p Sampling:", output_top_p_sampling.text)  # Top-p Sampling: nlp est un domaine de l'intelligence artificielle
```
@@ -0,0 +1,57 @@
from happytransformer import HappyTextToText, TTSettings


def example_7_0():
    # --------------------------------------#
    happy_tt = HappyTextToText("T5", "t5-small")  # default


def example_7_1():
    # --------------------------------------#
    happy_tt = HappyTextToText()  # default uses t5-small
    top_p_sampling_settings = TTSettings(do_sample=True, top_k=0, top_p=0.8, temperature=0.7, min_length=20, max_length=20, early_stopping=True)
    result = happy_tt.generate_text("translate English to French: nlp is a field of artificial intelligence", args=top_p_sampling_settings)
    print(result)  # TextToTextResult(text="nlp est un domaine de l'intelligence artificielle...")
    print(result.text)  # nlp est un domaine de l'intelligence artificielle. n


def example_7_2():
    happy_tt = HappyTextToText("T5", "t5-small")

    greedy_settings = TTSettings(no_repeat_ngram_size=2, max_length=20)
    output_greedy = happy_tt.generate_text(
        "translate English to French: nlp is a field of artificial intelligence ",
        args=greedy_settings)

    beam_settings = TTSettings(num_beams=5, max_length=20)
    output_beam_search = happy_tt.generate_text(
        "translate English to French: nlp is a field of artificial intelligence ",
        args=beam_settings)

    generic_sampling_settings = TTSettings(do_sample=True, top_k=0, temperature=0.7, max_length=20)
    output_generic_sampling = happy_tt.generate_text(
        "translate English to French: nlp is a field of artificial intelligence ",
        args=generic_sampling_settings)

    top_k_sampling_settings = TTSettings(do_sample=True, top_k=50, temperature=0.7, max_length=20)
    output_top_k_sampling = happy_tt.generate_text(
        "translate English to French: nlp is a field of artificial intelligence ",
        args=top_k_sampling_settings)

    top_p_sampling_settings = TTSettings(do_sample=True, top_k=0, top_p=0.8, temperature=0.7, max_length=20)
    output_top_p_sampling = happy_tt.generate_text(
        "translate English to French: nlp is a field of artificial intelligence ",
        args=top_p_sampling_settings)

    print("Greedy:", output_greedy.text)  # Greedy: nlp est un domaine de l'intelligence artificielle
    print("Beam:", output_beam_search.text)  # Beam: nlp est un domaine de l'intelligence artificielle
    print("Generic Sampling:", output_generic_sampling.text)  # Generic Sampling: nlp est un champ d'intelligence artificielle
    print("Top-k Sampling:", output_top_k_sampling.text)  # Top-k Sampling: nlp est un domaine de l'intelligence artificielle
    print("Top-p Sampling:", output_top_p_sampling.text)  # Top-p Sampling: nlp est un domaine de l'intelligence artificielle


def main():
    # example_7_0()
    example_7_1()
    # example_7_2()


if __name__ == "__main__":
    main()
@@ -0,0 +1,97 @@
"""
Contains a class called HappyTextToText which performs text-to-text generation
"""
from dataclasses import dataclass

from transformers import Text2TextGenerationPipeline, AutoModelForSeq2SeqLM

from happytransformer.happy_transformer import HappyTransformer
from happytransformer.tt.trainer import TTTrainer
from happytransformer.cuda_detect import detect_cuda_device_number
from happytransformer.adaptors import get_adaptor
from happytransformer.tt.trainer import TTTrainArgs, TTEvalArgs, TTTestArgs


@dataclass
class TextToTextResult:
    text: str


@dataclass
class TTSettings:
    min_length: int = 10
    max_length: int = 50
    do_sample: bool = False
    early_stopping: bool = False
    num_beams: int = 1
    temperature: float = 1
    top_k: int = 50
    no_repeat_ngram_size: int = 0
    top_p: float = 1


class HappyTextToText(HappyTransformer):
    """
    A user-facing class for text-to-text generation
    """
    def __init__(self, model_type: str = "T5", model_name: str = "t5-small", load_path: str = ""):

        self.adaptor = get_adaptor(model_type)

        if load_path != "":
            model = AutoModelForSeq2SeqLM.from_pretrained(load_path)
        else:
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

        super().__init__(model_type, model_name, model)

        device_number = detect_cuda_device_number()

        self._pipeline = Text2TextGenerationPipeline(model=self.model,
                                                     tokenizer=self.tokenizer, device=device_number)

        self._trainer = TTTrainer(self.model, model_type, self.tokenizer, self._device, self.logger)

    def __assert_default_text_is_val(self, text):
        """
        Ensures the text input is valid.
        Raises a ValueError if the text input is not valid.
        :param text: The value the user inputs for the "text" parameter
        """
        if not isinstance(text, str):
            raise ValueError("The text input must be a string")
        if not text:
            raise ValueError("The text input must have at least one character")

    def generate_text(self, text: str,
                      args: TTSettings = TTSettings()) -> TextToTextResult:
        """
        :param text: starting text that the model uses as a prompt to continue it.
        :param args: A TTSettings object
        :return: A TextToTextResult() object
        """
        self.__assert_default_text_is_val(text)

        output = self._pipeline(text, min_length=args.min_length,
                                max_length=args.max_length,
                                do_sample=args.do_sample,
                                early_stopping=args.early_stopping,
                                num_beams=args.num_beams,
                                temperature=args.temperature,
                                top_k=args.top_k,
                                no_repeat_ngram_size=args.no_repeat_ngram_size,
                                top_p=args.top_p,
                                )
        return TextToTextResult(text=output[0]['generated_text'])

    def train(self, input_filepath, args=TTTrainArgs):
        raise NotImplementedError("train() is currently not available")

    def eval(self, input_filepath, args=TTEvalArgs):
        raise NotImplementedError("eval() is currently not available")

    def test(self, input_filepath, args=TTTestArgs):
        raise NotImplementedError("test() is currently not available")