Skip to content

Commit

Permalink
Merge pull request #515 from EveryVoiceTTS/dev.ap/text-fixes
Browse files Browse the repository at this point in the history
fix: change shape of filelist list data instead of re-reading it
  • Loading branch information
roedoejet authored Jul 30, 2024
2 parents 00b07a5 + 9ab79eb commit 8fc4099
Show file tree
Hide file tree
Showing 8 changed files with 198 additions and 39 deletions.
23 changes: 19 additions & 4 deletions everyvoice/config/text_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from everyvoice.config.shared_types import ConfigModel
from everyvoice.config.utils import PossiblySerializedCallable
from everyvoice.text.utils import normalize_text_helper
from everyvoice.utils import collapse_whitespace
from everyvoice.utils import collapse_whitespace, strip_text


class Punctuation(BaseModel):
Expand Down Expand Up @@ -62,6 +62,23 @@ def all_except_punctuation(self) -> set[str]:
"""Returns the set containing all characters."""
return set(w for _, v in self if not isinstance(v, Punctuation) for w in v)

@model_validator(mode="after")
def cannot_have_punctuation_in_symbol_set(self) -> "Symbols":
"""You cannot have the same symbol defined in punctuation as elsewhere.
Raises:
ValueError: raised if a symbol from punctuation is found elsewhere
Returns:
Symbols: The validated symbol set
"""
for punctuation in self.punctuation.all:
if punctuation in self.all_except_punctuation:
raise ValueError(
f"Sorry, the symbol '{punctuation}' occurs in both your declared punctuation and in your other symbol set. Please inspect your text configuration and either remove the symbol from the punctuation or other symbol set."
)
return self

@model_validator(mode="after")
def member_must_be_list_of_strings(self) -> "Symbols":
"""Except for `punctuation` & `pad`, all user defined member variables
Expand All @@ -81,9 +98,7 @@ def member_must_be_list_of_strings(self) -> "Symbols":
class TextConfig(ConfigModel):
symbols: Symbols = Field(default_factory=Symbols)
to_replace: Dict[str, str] = {} # Happens before cleaners
cleaners: list[PossiblySerializedCallable] = [
collapse_whitespace,
]
cleaners: list[PossiblySerializedCallable] = [collapse_whitespace, strip_text]

@model_validator(mode="after")
def clean_symbols(self) -> "TextConfig":
Expand Down
12 changes: 6 additions & 6 deletions everyvoice/model/e2e/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@

# The contact information only needs to be registered on the main config
class AlignerConfigNoContact(AlignerConfig):
contact: Optional[ContactInformation] = None
contact: Optional[ContactInformation] = None # type: ignore


class VocoderConfigNoContact(VocoderConfig):
contact: Optional[ContactInformation] = None
contact: Optional[ContactInformation] = None # type: ignore


class FeaturePredictionConfigNoContact(FeaturePredictionConfig):
contact: Optional[ContactInformation] = None
contact: Optional[ContactInformation] = None # type: ignore


class E2ETrainingConfig(BaseTrainingConfig):
Expand All @@ -36,17 +36,17 @@ class E2ETrainingConfig(BaseTrainingConfig):

class EveryVoiceConfig(BaseModelWithContact):
aligner: AlignerConfig | AlignerConfigNoContact = Field(
default_factory=AlignerConfigNoContact
default_factory=AlignerConfigNoContact # type: ignore
)
path_to_aligner_config_file: Optional[FilePath] = None

feature_prediction: FeaturePredictionConfig | FeaturePredictionConfigNoContact = (
Field(default_factory=FeaturePredictionConfigNoContact)
Field(default_factory=FeaturePredictionConfigNoContact) # type: ignore
)
path_to_feature_prediction_config_file: Optional[FilePath] = None

vocoder: VocoderConfig | VocoderConfigNoContact = Field(
default_factory=VocoderConfigNoContact
default_factory=VocoderConfigNoContact # type: ignore
)
path_to_vocoder_config_file: Optional[FilePath] = None

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
cleaners: [everyvoice.utils.lower, everyvoice.utils.collapse_whitespace, everyvoice.utils.nfc_normalize]
symbols:
dataset_0-symbols: [' ', '''', ',', '-', ., C, E, H, K, P, T, a, b, c, d, e, f,
dataset_0-symbols: [' ', C, E, H, K, P, T, a, b, c, d, e, f,
g, h, i, l, m, n, o, p, r, s, t, u, v, w, x, y]
pad: _
silence: [<SIL>]
Expand Down
1 change: 1 addition & 0 deletions everyvoice/tests/data/unit-test-case1.psv
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
columnA|non-standard-basename|non-standard-text|extra
blah|somefile|characters|irrelevant extra
boom|file2|CaSeD NFD: éàê NFC: éàê|blah
floop|file3| let us see if it collapses whitespace|blah
bam|banned_file|ZZZ|has banned symbol (Z)
8 changes: 6 additions & 2 deletions everyvoice/tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def test_token_sequence_to_text(self):

def test_hardcoded_symbols(self):
self.assertEqual(
self.base_text_processor.encode_text("\x80 "),
[0, 1],
self.base_text_processor.encode_text("\x80 \x80"),
[0, 1, 0],
"pad should be Unicode PAD symbol and index 0, whitespace should be index 1",
)

Expand All @@ -65,6 +65,10 @@ def test_cleaners_with_upper(self):
sequence = upper_text_processor.encode_text(text_upper)
self.assertEqual(upper_text_processor.decode_tokens(sequence, ""), text)

def test_no_duplicate_punctuation(self):
with self.assertRaises(ValidationError):
TextConfig(symbols=Symbols(letters=[":"] + list(string.ascii_letters)))

def test_punctuation(self):
text = "hello! How are you? My name's: foo;."
upper_text_processor = TextProcessor(
Expand Down
104 changes: 97 additions & 7 deletions everyvoice/tests/test_wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,80 @@ def test_sample_rate_config(self):
self.assertTrue(step.completed)
self.assertEqual(step.response, 512)

def test_whitespace_always_collapsed(self):
tour = Tour("unit testing", steps=dataset.get_dataset_steps())
filelist = str(self.data_dir / "unit-test-case1.psv")
filelist_step = find_step(SN.filelist_step, tour.steps)
monkey = monkeypatch(filelist_step, "prompt", Say(filelist))
with monkey:
filelist_step.run()

permission_step = find_step(SN.dataset_permission_step, tour.steps)
with patch_menu_prompt(1): # 1 is "yes, I have permission"
permission_step.run()
self.assertTrue(
permission_step.state[SN.dataset_permission_step].startswith("Yes")
)

format_step = find_step(SN.filelist_format_step, tour.steps)
with patch_menu_prompt(0): # 0 is "psv"
format_step.run()
step = format_step.children[0]
with patch_menu_prompt(1): # 1 is "yes"
step.run()

step = format_step.children[1]
with patch_menu_prompt(1): # 1 is second column
step.run()

step = format_step.children[2]
with patch_menu_prompt(1): # 1 is second remaining column, i.e., third column
step.run()

text_representation_step = find_step(
SN.filelist_text_representation_step, tour.steps
)
with patch_menu_prompt(0): # 0 is "characters"
text_representation_step.run()
speaker_step = find_step(SN.data_has_speaker_value_step, tour.steps)
with patch_menu_prompt(0): # 0 is "no"
speaker_step.run()

know_speaker_step = speaker_step.children[0]
with patch_menu_prompt(1): # 1 is "yes"
know_speaker_step.run()

add_speaker_step = know_speaker_step.children[0]
with patch_input("default"):
add_speaker_step.run()

language_step = find_step(SN.data_has_language_value_step, tour.steps)
with patch_menu_prompt(0): # 0 is "no"
language_step.run()

select_lang_step = language_step.children[0]
with capture_stdout(), capture_stderr():
with patch_menu_prompt(15): # some arbitrary language from the list
select_lang_step.run()

wavs_dir_step = find_step(SN.wavs_dir_step, tour.steps)
with monkeypatch(wavs_dir_step, "prompt", Say(str(self.data_dir))):
wavs_dir_step.run()

validate_wavs_step = find_step(SN.validate_wavs_step, tour.steps)
with patch_menu_prompt(1), capture_stdout():
validate_wavs_step.run()

text_processing_step = find_step(SN.text_processing_step, tour.steps)
# 0 is lowercase, 1 is NFC Normalization, select none
with monkeypatch(dataset, "tqdm", lambda seq, desc: seq):
with patch_menu_prompt([]):
text_processing_step.run()
self.assertEqual(
text_processing_step.state["filelist_data_list"][3][2],
"let us see if it collapses whitespace",
)

def test_dataset_subtour(self):
tour = Tour("unit testing", steps=dataset.get_dataset_steps())

Expand Down Expand Up @@ -435,7 +509,7 @@ def test_dataset_subtour(self):
with patch_menu_prompt(1): # 1 is "yes"
step.run()
self.assertEqual(step.state[SN.data_has_header_line_step.value], "yes")
self.assertEqual(len(step.state["filelist_data_list"]), 4)
self.assertEqual(len(step.state["filelist_data_list"]), 5)

step = format_step.children[1]
self.assertIsInstance(step, dataset.HeaderStep)
Expand Down Expand Up @@ -503,7 +577,7 @@ def test_dataset_subtour(self):
with patch_menu_prompt(1), capture_stdout() as out:
validate_wavs_step.run()
self.assertEqual(step.state[SN.validate_wavs_step][:2], "No")
self.assertIn("Warning: 3 wav files were not found", out.getvalue())
self.assertIn("Warning: 4 wav files were not found", out.getvalue())

text_processing_step = find_step(SN.text_processing_step, tour.steps)
# 0 is lowercase, 1 is NFC Normalization, select both
Expand All @@ -513,11 +587,16 @@ def test_dataset_subtour(self):
# print(text_processing_step.state)
self.assertEqual(
text_processing_step.state["filelist_data_list"][2][2],
"cased \t nfd: éàê nfc: éàê", # the "nfd: éàê" bit here is now NFC
"cased nfd: éàê nfc: éàê", # the "nfd: éàê" bit here is now NFC
)

self.assertEqual(
text_processing_step.state["filelist_data_list"][3][2],
"let us see if it collapses whitespace",
)

# Make sure reloading the data as dict stripped the header line
self.assertEqual(len(step.state["filelist_data"]), 3)
self.assertEqual(len(step.state["filelist_data"]), 4)

sox_effects_step = find_step(SN.sox_effects_step, tour.steps)
# 0 is resample to 22050 kHz, 2 is remove silence at start and end
Expand Down Expand Up @@ -547,11 +626,24 @@ def test_dataset_subtour(self):
)

symbol_set_step = find_step(SN.symbol_set_step, tour.steps)
self.assertEqual(len(symbol_set_step.state["filelist_data"]), 3)
self.assertEqual(len(symbol_set_step.state["filelist_data"]), 4)
with capture_stdout(), capture_stderr():
symbol_set_step.run()
self.assertEqual(len(symbol_set_step.state[SN.symbol_set_step.value]), 2)
self.assertIn("t͡s", symbol_set_step.state[SN.symbol_set_step.value]["phones"])
self.assertNotIn(
":", symbol_set_step.state[SN.symbol_set_step.value]["characters"]
)
self.assertNotIn(":", symbol_set_step.state[SN.symbol_set_step.value]["phones"])
# assert that symbols contain no duplicates
self.assertEqual(
len(set(symbol_set_step.state[SN.symbol_set_step.value]["characters"])),
len(symbol_set_step.state[SN.symbol_set_step.value]["characters"]),
)
self.assertEqual(
len(set(symbol_set_step.state[SN.symbol_set_step.value]["phones"])),
len(symbol_set_step.state[SN.symbol_set_step.value]["phones"]),
)

def test_empty_filelist(self):
tour = Tour(
Expand Down Expand Up @@ -1599,8 +1691,6 @@ def setUp(self):
SN.symbol_set_step.value: {
"characters": [
" ",
",",
".",
"A",
"D",
"E",
Expand Down
10 changes: 9 additions & 1 deletion everyvoice/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,14 +356,22 @@ def generic_dict_loader(
generic_csv_filelist_reader = partial(generic_dict_loader, delimiter=",")


def collapse_whitespace(text):
def collapse_whitespace(text: str):
"""
>>> collapse_whitespace(" asdf qwer ")
' asdf qwer '
"""
return re.sub(_whitespace_re, " ", text)


def strip_text(text: str):
"""
>>> strip_text(" asdf qwer ")
'asdf qwer'
"""
return text.strip()


@contextmanager
def tqdm_joblib_context(tqdm_instance):
"""Context manager to make tqdm compatible with joblib.Parallel
Expand Down
Loading

0 comments on commit 8fc4099

Please sign in to comment.