Allow downloading extra formats in the demo #617

Merged
merged 8 commits on Jan 14, 2025
31 changes: 25 additions & 6 deletions everyvoice/cli.py
@@ -1,4 +1,3 @@
import enum
import json
import platform
import subprocess
@@ -37,6 +36,9 @@
from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.cli.train import (
train as train_fs2,
)
from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.type_definitions import (
SynthesizeOutputFormats,
)
from everyvoice.model.vocoder.HiFiGAN_iSTFT_lightning.hfgl.cli import (
HFG_EXPORT_LONG_HELP,
HFG_EXPORT_SHORT_HELP,
@@ -558,7 +560,7 @@ def check_data(
)(inspect_checkpoint)


TestSuites = enum.Enum("TestSuites", {name: name for name in SUITE_NAMES}) # type: ignore
TestSuites = Enum("TestSuites", {name: name for name in SUITE_NAMES}) # type: ignore


@app.command(hidden=True)
@@ -571,6 +573,12 @@ def test(suite: TestSuites = typer.Argument("dev")):
SCHEMAS_TO_OUTPUT: dict[str, Any] = {} # dict[str, type[BaseModel]]


AllowedDemoOutputFormats = Enum( # type: ignore
"AllowedDemoOutputFormats",
[("all", "all")] + [(i.name, i.value) for i in SynthesizeOutputFormats],
)
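
For readers less familiar with the functional `Enum` API used here, a minimal sketch of the same construction pattern; `FakeFormats` below is a hypothetical stand-in for `SynthesizeOutputFormats` (only `wav` and `textgrid` are format names that appear elsewhere in this PR, the rest is illustrative):

```python
from enum import Enum

# Hypothetical stand-in for SynthesizeOutputFormats; the real member set lives
# in the FastSpeech2 type definitions.
class FakeFormats(Enum):
    wav = "wav"
    textgrid = "textgrid"

# Same construction pattern as AllowedDemoOutputFormats above: an "all" member
# plus one member per synthesis output format.
Formats = Enum("Formats", [("all", "all")] + [(f.name, f.value) for f in FakeFormats])

print([m.name for m in Formats])  # ['all', 'wav', 'textgrid']
print(Formats["wav"].value)       # 'wav'
```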


@app.command()
def demo(
text_to_spec_model: Path = typer.Argument(
@@ -608,13 +616,19 @@ def demo(
["all"],
"--language",
"-l",
help="Specify languages to be included in the demo. Example: everyvoice demo <path_to_text_to_spec_model> <path_to_spec_to_wav_model> --language eng --language fin",
help="Specify languages to be included in the demo. Must be supported by your model. Example: everyvoice demo TEXT_TO_SPEC_MODEL SPEC_TO_WAV_MODEL --language eng --language fin",
),
speakers: List[str] = typer.Option(
["all"],
"--speaker",
"-s",
help="Specify speakers to be included in the demo. Example: everyvoice demo <path_to_text_to_spec_model> <path_to_spec_to_wav_model> --speaker speaker_1 --speaker Sue",
help="Specify speakers to be included in the demo. Must be supported by your model. Example: everyvoice demo TEXT_TO_SPEC_MODEL SPEC_TO_WAV_MODEL --speaker speaker_1 --speaker Sue",
),
outputs: list[AllowedDemoOutputFormats] = typer.Option(
["all"],
"--output-format",
"-O",
help="Specify output formats to be included in the demo. Example: everyvoice demo TEXT_TO_SPEC_MODEL SPEC_TO_WAV_MODEL --output-format wav --output-format readalong-html",
),
output_dir: Path = typer.Option(
"synthesis_output",
@@ -625,9 +639,13 @@
help="The directory where your synthesized audio should be written",
shell_complete=complete_path,
),
accelerator: str = typer.Option("auto", "--accelerator", "-a"),
accelerator: str = typer.Option(
"auto",
"--accelerator",
"-a",
help="Specify the Pytorch Lightning accelerator to use",
),
):

if allowlist and denylist:
raise ValueError(
"You provided a value for both the allowlist and the denylist but you can only provide one."
@@ -652,6 +670,7 @@
spec_to_wav_model_path=spec_to_wav_model,
languages=languages,
speakers=speakers,
outputs=outputs,
output_dir=output_dir,
accelerator=accelerator,
allowlist=allowlist_data,
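
As a quick illustration of how Typer handles the repeated `--output-format` option added above, a minimal standalone sketch; the `Fmt` enum and `demo` command below are simplified placeholders, not the EveryVoice CLI:

```python
# Minimal Typer sketch: a repeated --output-format option collected into a list
# of enum members. Placeholder names; not the EveryVoice implementation.
import enum
from typing import List

import typer

class Fmt(str, enum.Enum):
    all = "all"
    wav = "wav"
    textgrid = "textgrid"

app = typer.Typer()

@app.command()
def demo(
    outputs: List[Fmt] = typer.Option(
        ["all"], "--output-format", "-O", help="May be given multiple times."
    ),
):
    # e.g. `python sketch.py -O wav -O textgrid` prints ['wav', 'textgrid']
    typer.echo([o.value for o in outputs])

if __name__ == "__main__":
    app()
```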
104 changes: 71 additions & 33 deletions everyvoice/demo/app.py
@@ -7,7 +7,6 @@

import gradio as gr
import torch
from gradio.processing_utils import convert_to_16_bit_wav
from loguru import logger

from everyvoice.config.type_definitions import TargetTrainingTextRepresentationLevel
@@ -17,12 +16,16 @@
from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.model import (
FastSpeech2,
)
from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.prediction_writing_callback import (
PredictionWritingWavCallback,
from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.type_definitions import (
SynthesizeOutputFormats,
)
from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.utils import (
truncate_basename,
)
from everyvoice.model.vocoder.HiFiGAN_iSTFT_lightning.hfgl.utils import (
load_hifigan_from_checkpoint,
)
from everyvoice.utils import slugify
from everyvoice.utils.heavy import get_device_from_accelerator

os.environ["no_proxy"] = "localhost,127.0.0.1,::1"
@@ -33,6 +36,7 @@
duration_control,
language,
speaker,
output_format,
text_to_spec_model,
vocoder_model,
vocoder_config,
@@ -47,6 +51,7 @@
"Text for synthesis was not provided. Please type the text you want to be synthesized into the textfield."
)
norm_text = normalize_text(text)
basename = truncate_basename(slugify(text))

if allowlist and norm_text not in allowlist:
raise gr.Error(
f"Oops, the word {text} is not allowed to be synthesized by this model. Please contact the model owner."
@@ -62,7 +67,9 @@
raise gr.Error("Language is not selected. Please select a language.")
if speaker is None:
raise gr.Error("Speaker is not selected. Please select a speaker.")
config, device, predictions = synthesize_helper(
if output_format is None:
raise gr.Error("Speaker is not selected. Please select an output format.")
config, device, predictions, callbacks = synthesize_helper(

model=text_to_spec_model,
vocoder_model=vocoder_model,
vocoder_config=vocoder_config,
Expand All @@ -71,9 +78,9 @@
accelerator=accelerator,
devices="1",
device=device,
global_step=1,
vocoder_global_step=1, # dummy value since the vocoder step is not used
output_type=[],
global_step=text_to_spec_model.config.training.max_steps,
vocoder_global_step=vocoder_model.config.training.max_steps,
output_type=(output_format, SynthesizeOutputFormats.wav),
text_representation=TargetTrainingTextRepresentationLevel.characters,
output_dir=output_dir,
speaker=speaker,
@@ -83,24 +90,16 @@
batch_size=1,
num_workers=1,
)
output_key = (
"postnet_output" if text_to_spec_model.config.model.use_postnet else "output"
)
wav_writer = PredictionWritingWavCallback(
output_dir=output_dir,
config=config,
output_key=output_key,
device=device,
global_step=1,
vocoder_global_step=1, # dummy value since the vocoder step is not used
vocoder_model=vocoder_model,
vocoder_config=vocoder_config,
)
# move to device because lightning accumulates predictions on cpu
predictions[0][output_key] = predictions[0][output_key].to(device)
wav, sr = wav_writer.synthesize_audio(predictions[0])

return sr, convert_to_16_bit_wav(wav.numpy())
wav_writer = callbacks[SynthesizeOutputFormats.wav]
wav_output = wav_writer.get_filename(basename, speaker, language)


file_output = None

if output_format != SynthesizeOutputFormats.wav:
file_writer = callbacks[output_format]
file_output = file_writer.get_filename(basename, speaker, language)


return wav_output, file_output



def require_ffmpeg():
@@ -158,15 +157,38 @@


def create_demo_app(
text_to_spec_model_path,
spec_to_wav_model_path,
languages,
speakers,
output_dir,
accelerator,
text_to_spec_model_path: os.PathLike,
spec_to_wav_model_path: os.PathLike,
languages: list[str],
speakers: list[str],
outputs: list, # list[str | AllowedDemoOutputFormats]
output_dir: os.PathLike,
accelerator: str,
allowlist: list[str] = [],
denylist: list[str] = [],
) -> gr.Blocks:
# Early argument validation where possible
possible_outputs = [x.value for x in SynthesizeOutputFormats]

# This used to be `if outputs == ["all"]:`, but the Enum() constructor for
# AllowedDemoOutputFormats breaks that comparison, and enum.StrEnum is only
# available in Python 3.11+, so it cannot be used here.
if len(outputs) == 1 and getattr(outputs[0], "value", outputs[0]) == "all":
output_list = possible_outputs

else:
if not outputs:
raise ValueError(
f"Empty outputs list. Please specify ['all'] or one or more of {possible_outputs}"
)
output_list = []
for output in outputs:
value = getattr(output, "value", output) # Enum->value as str / str->str
if value not in possible_outputs:
raise ValueError(
f"Unknown output format '{value}'. Valid outputs values are ['all'] or one or more of {possible_outputs}"
)
output_list.append(value)
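
To make the comment above concrete, a standalone sketch (not EveryVoice code) of why the `== ["all"]` comparison fails for enum members and what the `getattr(..., "value", ...)` normalization does:

```python
from enum import Enum

Fmt = Enum("Fmt", [("all", "all"), ("wav", "wav")])

print([Fmt.all] == ["all"])                # False: plain Enum members never equal their string values
print(getattr(Fmt.all, "value", Fmt.all))  # 'all', the Enum member normalized to its underlying string
print(getattr("all", "value", "all"))      # 'all', a plain string passes through unchanged
```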

require_ffmpeg()
device = get_device_from_accelerator(accelerator)
vocoder_ckpt = torch.load(spec_to_wav_model_path, map_location=device)
@@ -215,6 +237,7 @@
print(
f"Attention: The model have not been trained for speech synthesis with '{speaker}' speaker. The '{speaker}' speaker option will not be available for selection."
)

if lang_list == []:
raise ValueError(
f"Language option has been activated, but valid languages have not been provided. The model has been trained in {model_languages} languages. Please select either 'all' or at least some of them."
@@ -227,6 +250,8 @@
interactive_lang = len(lang_list) > 1
default_speak = speak_list[0]
interactive_speak = len(speak_list) > 1
default_output = output_list[0]
interactive_output = len(output_list) > 1

with gr.Blocks() as demo:
gr.Markdown(
"""
@@ -255,12 +280,25 @@
interactive=interactive_speak,
label="Speaker",
)
with gr.Row():
output_format = gr.Dropdown(

choices=output_list,
value=default_output,
interactive=interactive_output,
label="Output Format",
)
btn = gr.Button("Synthesize")
with gr.Column():
out_audio = gr.Audio(format="mp3")
out_audio = gr.Audio(format="wav")

if output_list == [SynthesizeOutputFormats.wav]:
# When the only output option is wav, don't show the File Output box
outputs = [out_audio]

else:
out_file = gr.File(label="File Output")
outputs = [out_audio, out_file]

Check warning on line 298 in everyvoice/demo/app.py

View check run for this annotation

Codecov / codecov/patch

everyvoice/demo/app.py#L297-L298

Added lines #L297 - L298 were not covered by tests
btn.click(
synthesize_audio_preset,
inputs=[inp_text, inp_slider, inp_lang, inp_speak],
outputs=[out_audio],
inputs=[inp_text, inp_slider, inp_lang, inp_speak, output_format],
outputs=outputs,
)
return demo
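
For reference, a minimal standalone Gradio sketch of the same wiring pattern: an output-format dropdown feeding `btn.click`, with the File box created only when more than one format is offered. The synthesis function and format list are placeholders, not the EveryVoice demo:

```python
import gradio as gr

formats = ["wav", "textgrid"]  # placeholder; the real list comes from the model/config

def fake_synthesize(text, fmt):
    # Placeholder: a real app would return (sample_rate, waveform) or a wav path,
    # plus a path to the extra file when fmt != "wav".
    return None, None

with gr.Blocks() as demo:
    inp_text = gr.Textbox(label="Text")
    fmt = gr.Dropdown(choices=formats, value=formats[0], label="Output Format")
    btn = gr.Button("Synthesize")
    out_audio = gr.Audio(format="wav")
    if formats == ["wav"]:
        outputs = [out_audio]  # no extra File box when wav is the only option
    else:
        out_file = gr.File(label="File Output")
        outputs = [out_audio, out_file]
    btn.click(fake_synthesize, inputs=[inp_text, fmt], outputs=outputs)

# demo.launch()  # uncomment to serve locally
```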
45 changes: 45 additions & 0 deletions everyvoice/tests/test_cli.py
@@ -1,3 +1,4 @@
import enum
import json
import os
import subprocess
@@ -21,6 +22,7 @@
from everyvoice.base_cli.helpers import save_configuration_to_log_dir
from everyvoice.cli import SCHEMAS_TO_OUTPUT, app
from everyvoice.config.shared_types import ContactInformation
from everyvoice.demo.app import create_demo_app
from everyvoice.model.feature_prediction.FastSpeech2_lightning.fs2.config import (
FastSpeech2Config,
)
@@ -56,6 +58,7 @@ def setUp(self) -> None:
"preprocess",
"inspect-checkpoint",
"evaluate",
"demo",
]

def test_version(self):
@@ -323,6 +326,48 @@ def test_expensive_imports_are_tucked_away(self):
self.assertNotIn(b"shared_types", result.stderr, msg.format("shared_types.py"))
self.assertNotIn(b"pydantic", result.stderr, msg.format("pydantic"))

def test_demo_with_bad_args(self):
result = self.runner.invoke(app, ["demo"])
self.assertNotEqual(result.exit_code, 0)
self.assertIn("Missing argument", result.output)

result = self.runner.invoke(
app, ["demo", os.devnull, os.devnull, "--output-format", "not-a-format"]
)
self.assertNotEqual(result.exit_code, 0)
self.assertIn("Invalid value", result.output)

def test_create_demo_app_with_errors(self):
# outputs is the first thing to get checked, because it can be done as
# a quick check before loading any models.
with self.assertRaises(ValueError) as cm:
create_demo_app(
text_to_spec_model_path=None,
spec_to_wav_model_path=None,
languages=[],
speakers=[],
outputs=[],
output_dir=None,
accelerator=None,
)
self.assertIn("Empty outputs list", str(cm.exception))

class WrongEnum(str, enum.Enum):
foo = "foo"

for outputs in (["wav", WrongEnum.foo], ["textgrid", "foo"]):
with self.assertRaises(ValueError) as cm:
create_demo_app(
text_to_spec_model_path=None,
spec_to_wav_model_path=None,
languages=[],
speakers=[],
outputs=outputs,
output_dir=None,
accelerator=None,
)
self.assertIn("Unknown output format 'foo'", str(cm.exception))


class TestBaseCLIHelper(TestCase):
def test_save_configuration_to_log_dir(self):