Skip to content

Commit

Permalink
feat: add scoring option to synthesize cli command
Browse files Browse the repository at this point in the history
  • Loading branch information
roedoejet committed Jan 20, 2025
1 parent 8d2ebc3 commit 417b616
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 27 deletions.
40 changes: 40 additions & 0 deletions fs2/cli/synthesize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import textwrap
from collections import Counter
from pathlib import Path
from typing import Any, Optional

Expand All @@ -11,6 +12,7 @@
)
from everyvoice.utils import spinner
from loguru import logger
from tqdm import tqdm

from ..type_definitions import SynthesizeOutputFormats
from ..utils import truncate_basename
Expand Down Expand Up @@ -221,6 +223,7 @@ def synthesize_helper(
vocoder_global_step: Optional[int] = None,
vocoder_model=None,
vocoder_config=None,
return_scores=False,
):
"""This is a helper to perform synthesis once the model has been loaded.
It allows us to use the same command for synthesis via the CLI and
Expand Down Expand Up @@ -258,6 +261,28 @@ def synthesize_helper(
text_representation=text_representation,
style_reference=style_reference,
)
if return_scores:
from nltk.util import ngrams

token_counter = Counter()
trigram_counter = Counter()
for line in tqdm(
data, desc="calculating filelist statistics for score calculation"
):
tokens = line[f"{text_representation.value[:-1]}_tokens"].split("/")
for t in tokens:
token_counter[t] += 1
tokens.insert(0, "<BOS>")
tokens.append("<EOS>")
for trigram in ngrams(tokens, 3):
trigram_counter[trigram] += 1
for line in tqdm(data, desc="scoring utterances"):
tokens = line[f"{text_representation.value[:-1]}_tokens"].split("/")
line["phone_coverage_score"] = sum((1 / token_counter[t]) for t in tokens)
line["trigram_coverage_score"] = sum(
(1 / trigram_counter[n]) for n in ngrams(tokens, 3)
)

from pytorch_lightning import Trainer

from ..prediction_writing_callback import get_synthesis_output_callbacks
Expand All @@ -272,6 +297,7 @@ def synthesize_helper(
vocoder_model=vocoder_model,
vocoder_config=vocoder_config,
vocoder_global_step=vocoder_global_step,
return_scores=return_scores,
)
trainer = Trainer(
logger=False, # We don't need to log things to tensorboard during inference
Expand All @@ -284,6 +310,10 @@ def synthesize_helper(
teacher_forcing = True
model.config.preprocessing.save_dir = teacher_forcing_directory
else:
if return_scores:
raise ValueError(
"In order to return the scores, we also need access to the directory containing your ground truth audio. Please pass this in using the --teacher-forcing-directory option. e.g. --teacher-forcing-directory ./preprocessed"
)
teacher_forcing = False
# overwrite batch_size and num_workers
model.config.training.batch_size = batch_size
Expand Down Expand Up @@ -401,6 +431,12 @@ def synthesize( # noqa: C901
'**readalong-html**' will generate a single file Offline HTML ReadAlong that can be further edited in the ReadAlong Studio Editor, and opened by itself. Also implies '--output-type wav'. Requires --vocoder-path.
""",
),
return_scores: bool = typer.Option(
False,
"--return-scores",
"-R",
help="ADVANCED. Setting this to True will change your batch size to 1 and output a PSV file with the losses for each synthesized audio along with a score of trigram density to measure the phonological importance of the utterance.",
),
teacher_forcing_directory: Path = typer.Option(
None,
"--teacher-forcing-directory",
Expand Down Expand Up @@ -466,6 +502,9 @@ def synthesize( # noqa: C901

from ..model import FastSpeech2

if return_scores:
batch_size = 1

output_dir.mkdir(exist_ok=True, parents=True)
# NOTE: We want to be able to put the vocoder on the proper accelerator for
# it to be compatible with the vocoder's input device.
Expand Down Expand Up @@ -530,4 +569,5 @@ def synthesize( # noqa: C901
vocoder_model=vocoder_model,
vocoder_config=vocoder_config,
vocoder_global_step=vocoder_global_step,
return_scores=return_scores,
)
12 changes: 11 additions & 1 deletion fs2/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ def __getitem__(self, index):
else:
energy = None
pitch = None
return {

loaded_data = {
"mel": mel,
"mel_style_reference": mel_style_reference,
"duration": duration,
Expand All @@ -203,6 +204,15 @@ def __getitem__(self, index):
"pitch": pitch,
}

# used when returning scores
if "phone_coverage_score" in item:
loaded_data['phone_coverage_score'] = item['phone_coverage_score']

if "trigram_coverage_score" in item:
loaded_data['trigram_coverage_score'] = item['trigram_coverage_score']

return loaded_data

def __len__(self):
return len(self.dataset)

Expand Down
58 changes: 32 additions & 26 deletions fs2/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,39 +37,45 @@ def forward(self, output, batch, current_epoch, frozen_components=None):

# Don't calculate grad on target
duration_target.requires_grad = False
energy_target.requires_grad = False
spec_target.requires_grad = False
pitch_target.requires_grad = False

losses = {}

# Calculate pitch loss
if self.config.model.variance_predictors.pitch.level == "phone":
pitch_mask = src_mask
else:
pitch_mask = tgt_mask

pitch_prediction = pitch_prediction * pitch_mask
pitch_target = pitch_target * pitch_mask
pitch_loss_fn = self.config.model.variance_predictors.pitch.loss
losses["pitch"] = (
self.loss_fns[pitch_loss_fn](pitch_prediction, pitch_target)
* self.config.training.pitch_loss_weight
)
if pitch_target is not None:

pitch_target.requires_grad = False

if self.config.model.variance_predictors.pitch.level == "phone":
pitch_mask = src_mask
else:
pitch_mask = tgt_mask

pitch_prediction = pitch_prediction * pitch_mask
pitch_target = pitch_target * pitch_mask
pitch_loss_fn = self.config.model.variance_predictors.pitch.loss
losses["pitch"] = (
self.loss_fns[pitch_loss_fn](pitch_prediction, pitch_target)
* self.config.training.pitch_loss_weight
)

# Calculate energy loss
if self.config.model.variance_predictors.energy.level == "phone":
energy_mask = src_mask
else:
energy_mask = tgt_mask

energy_prediction = energy_prediction * energy_mask
energy_target = energy_target * energy_mask
energy_loss_fn = self.config.model.variance_predictors.energy.loss
losses["energy"] = (
self.loss_fns[energy_loss_fn](energy_prediction, energy_target)
* self.config.training.energy_loss_weight
)
if energy_target is not None:

energy_target.requires_grad = False

if self.config.model.variance_predictors.energy.level == "phone":
energy_mask = src_mask
else:
energy_mask = tgt_mask

energy_prediction = energy_prediction * energy_mask
energy_target = energy_target * energy_mask
energy_loss_fn = self.config.model.variance_predictors.energy.loss
losses["energy"] = (
self.loss_fns[energy_loss_fn](energy_prediction, energy_target)
* self.config.training.energy_loss_weight
)

# Calculate duration loss
log_duration_target = torch.log(duration_target.float() + 1) * src_mask
Expand Down
85 changes: 85 additions & 0 deletions fs2/prediction_writing_callback.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from csv import DictWriter
from pathlib import Path
from typing import Any, Optional, Sequence

Expand Down Expand Up @@ -32,12 +33,20 @@ def get_synthesis_output_callbacks(
vocoder_model: Optional[HiFiGAN] = None,
vocoder_config: Optional[HiFiGANConfig] = None,
vocoder_global_step: Optional[int] = None,
return_scores=False,
) -> dict[SynthesizeOutputFormats, Callback]:
"""
Given a list of desired output file formats, return the proper callbacks
that will generate those files.
"""
callbacks: dict[SynthesizeOutputFormats, Callback] = {}
if return_scores:
callbacks['score'] = ScorerCallback(
config=config,
global_step=global_step,
output_dir=output_dir,
output_key=output_key,
)
if (
SynthesizeOutputFormats.wav in output_type
or SynthesizeOutputFormats.readalong_html in output_type
Expand Down Expand Up @@ -135,6 +144,82 @@ def get_filename(
return str(path)


class ScorerCallback(Callback):
    """
    Collect per-utterance losses during prediction and save them to disk.

    After every predict batch the model's loss is computed for that batch and
    stored together with the utterance metadata and coverage scores found in
    the batch. When the predict epoch ends, the collected rows are sorted and
    written to a pipe-separated (PSV) file named ``scores-<global_step>.psv``
    inside ``output_dir``.

    Only the first item of each batch is recorded, so this callback is meant
    to be used with a batch size of 1.
    """

    def __init__(
        self,
        config: FastSpeech2Config,
        global_step: int,
        output_dir: Path,
        output_key: str,
    ):
        """
        Args:
            config: model configuration; stored but not read by this class.
            global_step: training step of the checkpoint, used in the
                name of the output file.
            output_dir: directory the scores file is written to.
            output_key: stored but not read by this class.
        """
        super().__init__()
        self.global_step = global_step
        self.save_dir = output_dir
        self.output_key = output_key
        self.config = config
        logger.info(f"Saving scores to {self.save_dir}")
        # One dict per scored utterance, filled by on_predict_batch_end.
        self.scores = []

    def _get_filename(self) -> Path:
        """Return the path of the scores file, creating its parent directory."""
        path = self.save_dir / f"scores-{self.global_step}.psv"
        # The output directory may be nested and may not exist yet.
        path.parent.mkdir(parents=True, exist_ok=True)
        return path

    def sort_scores(self):
        """Sort rows in place: highest total loss first, ties broken by
        ascending trigram coverage score."""
        self.scores.sort(key=lambda x: (-x["total"], x["trigram_coverage_score"]))

    def on_predict_epoch_end(
        self,
        _trainer,
        model,
    ):
        """Sort the collected scores and write them to the PSV file."""
        self.sort_scores()
        fieldnames = [
            "basename",
            "speaker",
            "language",
            "total",
            "trigram_coverage_score",
            "duration",
            "spec",
            "postnet",
            "attn_ctc",
            "attn_bin",
            "raw_text",
            "phone_coverage_score",
        ]
        # The loss dict may carry keys that are not in the fixed column list
        # (e.g. optional variance losses). DictWriter raises ValueError on
        # unknown keys, so append them as extra columns instead of crashing
        # or silently dropping them.
        known = set(fieldnames)
        for score in self.scores:
            for key in score:
                if key not in known:
                    known.add(key)
                    fieldnames.append(key)
        # newline="" is required by the csv module to avoid mangled line
        # endings; utf-8 keeps non-ASCII raw text writable on every platform.
        with open(self._get_filename(), "w", encoding="utf-8", newline="") as f:
            writer = DictWriter(f, fieldnames=fieldnames, delimiter="|")
            writer.writeheader()
            writer.writerows(self.scores)

    def on_predict_batch_end(  # pyright: ignore [reportIncompatibleMethodOverride]
        self,
        _trainer,
        model,
        outputs: dict[str, torch.Tensor | None],
        batch: dict[str, Any],
        _batch_idx: int,
        _dataloader_idx: int = 0,
    ):
        """Compute the losses for this batch and record them with its metadata."""
        with torch.no_grad():
            losses = model.loss(outputs, batch, model.current_epoch)
        score = {k: float(v) for k, v in losses.items()}
        # Only index 0 is read: the batch is expected to hold one utterance.
        score["basename"] = batch["basename"][0]
        score["speaker"] = batch["speaker"][0]
        score["language"] = batch["language"][0]
        score["raw_text"] = batch["raw_text"][0]
        score["phone_coverage_score"] = batch["phone_coverage_score"][0]
        score["trigram_coverage_score"] = batch["trigram_coverage_score"][0]
        self.scores.append(score)


class PredictionWritingSpecCallback(PredictionWritingCallbackBase):
"""
This callback runs inference on a provided text-to-spec model and saves the resulting Mel spectrograms to disk as pytorch files. These can be used to fine-tune an EveryVoice spec-to-wav model.
Expand Down

0 comments on commit 417b616

Please sign in to comment.