Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1 develop transcription translation summarisation pipeline #7

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
809 changes: 809 additions & 0 deletions notebooks/tts_pipeline_nb.ipynb

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ dependencies = [
"transformers",
"huggingface",
"datasets",
"numpy"
"numpy",
"sentencepiece",
"librosa",
"soundfile"
]

[project.optional-dependencies]
Expand Down
46 changes: 46 additions & 0 deletions scripts/TTS_pipeline_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
An example use of the transcription, translation and summarisation pipeline.
"""

import numpy as np
from datasets import Audio, load_dataset

from arc_spice.pipelines.TTS_pipeline import TTSpipeline


def main(TTS_params):
    """
    Build the pipeline from ``TTS_params`` and run it on one speech sample.

    The sample is the first item of the streamed French test split of
    ``facebook/multilingual_librispeech``. The pipeline configuration is
    printed before the run and the results are printed after it.
    """
    tts_pipeline = TTSpipeline(TTS_params)
    tts_pipeline.print_pipeline()

    speech_dataset = load_dataset(
        "facebook/multilingual_librispeech",
        "french",
        split="test",
        streaming=True,
    )
    # Cast the audio column to a 16 kHz sampling rate.
    speech_dataset = speech_dataset.cast_column(
        "audio", Audio(sampling_rate=16_000)
    )

    # Only the first sample of the stream is used in this example.
    first_sample = next(iter(speech_dataset))
    waveform = first_sample["audio"]["array"]

    tts_pipeline.run_pipeline(waveform)
    tts_pipeline.print_results()


if __name__ == "__main__":
    # (stage name, specific task, model) for each stage, in pipeline order.
    stage_settings = (
        ("transcriber", "automatic-speech-recognition", "openai/whisper-small"),
        (
            "translator",
            "translation_fr_to_en",
            "facebook/mbart-large-50-many-to-many-mmt",
        ),
        ("summariser", "summarization", "facebook/bart-large-cnn"),
    )
    TTS_pars = {
        stage: {"specific_task": task, "model": model}
        for stage, task, model in stage_settings
    }
    main(TTS_params=TTS_pars)
53 changes: 53 additions & 0 deletions src/arc_spice/pipelines/TTS_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
Class for the transcription, translation and summarisation pipeline.
"""

from transformers import pipeline


class TTSpipeline:
    """
    Class for the transcription, translation, summarisation pipeline.

    Three ``transformers`` pipelines are built from ``pars`` and chained: the
    transcriber's text is fed to the translator, and the translator's text is
    fed to the summariser.

    pars:
        - {'top_level_task': {'specific_task': str, 'model': str}}

        ``top_level_task`` is one of ``'transcriber'``, ``'translator'`` and
        ``'summariser'``; all three are required. ``specific_task`` is passed
        unchanged as the first argument of ``transformers.pipeline``, so it
        has to be a task name as that API spells it. ``model`` is passed as
        the second argument.

    The outputs of the most recent run are stored in ``self.results`` under
    the keys ``'transcription'``, ``'translation'`` and ``'summarisation'``.
    """

    def __init__(self, pars) -> None:
        self.pars = pars
        self.transcriber = self._build_stage("transcriber")
        self.translator = self._build_stage("translator")
        self.summariser = self._build_stage("summariser")
        self.results = {}

    def _build_stage(self, stage):
        """Create the ``transformers`` pipeline configured for ``stage``."""
        stage_pars = self.pars[stage]
        return pipeline(stage_pars["specific_task"], stage_pars["model"])

    def print_pipeline(self):
        """Print the models in the pipeline"""
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        print(f"Transcriber model: {self.pars['transcriber']['model']}")
        print(f"Translator model: {self.pars['translator']['model']}")
        print(f"Summariser model: {self.pars['summariser']['model']}")
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")

    def run_pipeline(self, x):
        """
        Run the pipeline on an input x.

        ``x`` is handed directly to the transcriber. Each stage's text output
        is recorded in ``self.results`` as soon as it is available, so if a
        later stage raises, the earlier outputs of this run are still there.
        """
        # Drop outputs of any earlier run so that a failure part-way through
        # cannot leave results from two different inputs side by side.
        self.results.clear()

        transcribed_text = self.transcriber(x)["text"]
        self.results["transcription"] = transcribed_text

        translated_text = self.translator(transcribed_text)[0]["translation_text"]
        self.results["translation"] = translated_text

        summary_text = self.summariser(translated_text)[0]["summary_text"]
        self.results["summarisation"] = summary_text

    def print_results(self):
        """Print the results for quick scanning"""
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        for key, val in self.results.items():
            print("-------------")
            print(f"{key} result is: \n {val}")
            print("-------------")
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
Loading