
Commit 6e6b022

Performed end-to-end testing of the VALL-E recipe (k2-fsa#1818)

* added the missing `visualize` function
* minor fixes

1 parent bdd0f85 · commit 6e6b022

File tree: 5 files changed, +109 −11 lines

egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py (+17 −5)

@@ -516,9 +516,19 @@ def main():
     for idx, part in enumerate(cut_sets):
         if args.audio_extractor:
             if args.audio_extractor == "Encodec":
-                storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}"
+                if split > 1:
+                    storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx}"
+                else:
+                    storage_path = (
+                        f"{args.output_dir}/{args.prefix}_encodec_{partition}"
+                    )
             else:
-                storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}"
+                if split > 1:
+                    storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx}"
+                else:
+                    storage_path = (
+                        f"{args.output_dir}/{args.prefix}_fbank_{partition}"
+                    )
 
             if args.prefix.lower() in [
                 "ljspeech",

@@ -587,9 +597,11 @@ def main():
             ].normalized_text, "normalized_text is None"
 
         # Save each part with an index if split > 1
-        cuts_filename = (
-            f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}"
-        )
+        if split > 1:
+            cuts_filename = f"{prefix}cuts_{partition}.{idx}.{args.suffix}"
+        else:
+            cuts_filename = f"{prefix}cuts_{partition}.{args.suffix}"
+
         part.to_file(f"{args.output_dir}/{cuts_filename}")
         logging.info(f"Saved {cuts_filename}")
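Why the inline conditionals were replaced with explicit branches: when split == 1 the old f-strings interpolate an empty index, leaving a trailing underscore on storage_path and a double dot in the cuts filename. A minimal sketch with toy values (the literals below, e.g. "out" and "wenetspeech4tts_", are illustrative, not taken from the recipe):

```python
# Old style: with split == 1 the empty index leaves a trailing "_" in the
# storage path and a double "." in the cuts filename.
split, idx = 1, 0
prefix, partition, suffix = "wenetspeech4tts_", "train", "jsonl.gz"

old_storage = f"out/{prefix}encodec_{partition}_{idx if split > 1 else ''}"
old_cuts = f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{suffix}"
print(old_storage)  # out/wenetspeech4tts_encodec_train_   <- trailing underscore
print(old_cuts)     # wenetspeech4tts_cuts_train..jsonl.gz <- double dot

# New style from this commit: branch explicitly on split.
if split > 1:
    cuts_filename = f"{prefix}cuts_{partition}.{idx}.{suffix}"
else:
    cuts_filename = f"{prefix}cuts_{partition}.{suffix}"
print(cuts_filename)  # wenetspeech4tts_cuts_train.jsonl.gz
```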
egs/wenetspeech4tts/TTS/valle/infer.py (+1 −1)

@@ -86,7 +86,7 @@ def get_args():
     parser.add_argument(
         "--checkpoint",
         type=str,
-        default="exp/vallf_nano_full/checkpoint-100000.pt",
+        default="./valle/exp/checkpoint-100000.pt",
         help="Path to the saved checkpoint.",
     )
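The new inference default appears to line up with the new --exp-dir default in train.py further below; a small illustrative check (paths copied from the two diffs):

```python
# Both defaults changed in this commit; the inference default now points
# inside the default training experiment directory.
exp_dir = "./valle/exp"                    # new train.py --exp-dir default
ckpt = "./valle/exp/checkpoint-100000.pt"  # new infer.py --checkpoint default
assert ckpt.startswith(exp_dir + "/")
```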

(new file, +2 −0 — the filename header was lost in this view; the contents are a pip requirements file, by its position in the file list most likely egs/wenetspeech4tts/TTS/valle/requirements.txt)

@@ -0,0 +1,2 @@
+phonemizer==3.2.1
+git+https://github.com/facebookresearch/encodec.git
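A quick way to verify the two pinned dependencies after installing them (a sketch; assumes the requirements file above was installed with pip into the current environment):

```python
# Confirm the two dependencies above resolved; both distribution names
# come straight from the requirements file.
from importlib.metadata import version

print(version("phonemizer"))  # expect 3.2.1
print(version("encodec"))     # version provided by the facebookresearch git checkout
```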

egs/wenetspeech4tts/TTS/valle/train.py (+4 −5)

@@ -4,6 +4,7 @@
 #                       Mingshuang Luo)
 # Copyright    2023    (authors: Feiteng Li)
 # Copyright    2024    (authors: Yuekai Zhang)
+# Copyright    2024    Tsinghua University (authors: Zengrui Jin,)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #

@@ -48,10 +49,8 @@
 import argparse
 import copy
 import logging
-import os
 import random
 import warnings
-from contextlib import nullcontext
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union

@@ -216,7 +215,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="exp/valle_dev",
+        default="./valle/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved

@@ -686,9 +685,9 @@ def compute_validation_loss(
     output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}")
     output_dir.mkdir(parents=True, exist_ok=True)
     if isinstance(model, DDP):
-        model.module.visualize(predicts, batch, output_dir=output_dir)
+        model.module.visualize(predicts, batch, tokenizer, output_dir=output_dir)
     else:
-        model.visualize(predicts, batch, output_dir=output_dir)
+        model.visualize(predicts, batch, tokenizer, output_dir=output_dir)
 
     return tot_loss
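The if/else around visualize is the usual pattern for calling a custom method on a model wrapped in DistributedDataParallel, which proxies forward() but not arbitrary methods. A minimal sketch of the same idea factored into a helper (unwrap_model is hypothetical, not part of the recipe):

```python
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def unwrap_model(model: nn.Module) -> nn.Module:
    """Return the underlying module when wrapped in DDP, else the model itself."""
    return model.module if isinstance(model, DDP) else model


# Equivalent to the branch in compute_validation_loss above:
# unwrap_model(model).visualize(predicts, batch, tokenizer, output_dir=output_dir)
```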

egs/wenetspeech4tts/TTS/valle/valle.py (+85 −0)

@@ -19,8 +19,11 @@
 from functools import partial
 from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
 
+import matplotlib.pyplot as plt
+import numpy as np
 import torch
 import torch.nn as nn
+from tokenizer import TextTokenCollater
 from torch import Tensor
 from torch.nn import Linear, Module
 from torch.nn import functional as F

@@ -1658,6 +1661,88 @@ def continual(
         assert len(codes) == 8
         return torch.stack(codes, dim=-1)
 
+    def visualize(
+        self,
+        predicts: Tuple[torch.Tensor],
+        batch: Dict[str, Union[List, torch.Tensor]],
+        tokenizer: TextTokenCollater,
+        output_dir: str,
+        limit: int = 4,
+    ) -> None:
+        audio_features = batch["features"].to("cpu").detach().numpy()
+        audio_features_lens = batch["features_lens"].to("cpu").detach().numpy()
+
+        tokens = batch["tokens"]
+        text_tokens, text_tokens_lens = tokenizer(tokens)
+        assert text_tokens.ndim == 2
+
+        texts = batch["text"]
+        utt_ids = [cut.id for cut in batch["cut"]]
+
+        encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
+        decoder_outputs = predicts[1]
+        if isinstance(decoder_outputs, list):
+            decoder_outputs = decoder_outputs[-1]
+        decoder_outputs = decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
+
+        vmin, vmax = 0, 1024  # Encodec
+        if decoder_outputs.dtype == np.float32:
+            vmin, vmax = -6, 0  # Fbank
+
+        num_figures = 3
+        for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
+            _ = plt.figure(figsize=(14, 8 * num_figures))
+
+            S = text_tokens_lens[b]
+            T = audio_features_lens[b]
+
+            # encoder
+            plt.subplot(num_figures, 1, 1)
+            plt.title(f"Text: {text}")
+            plt.imshow(
+                X=np.transpose(encoder_outputs[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=S - 0.4, linewidth=2, color="r")
+            plt.xlabel("Encoder Output")
+            plt.colorbar()
+
+            # decoder
+            plt.subplot(num_figures, 1, 2)
+            plt.imshow(
+                X=np.transpose(decoder_outputs[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+                vmin=vmin,
+                vmax=vmax,
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=T - 0.4, linewidth=2, color="r")
+            plt.xlabel("Decoder Output")
+            plt.colorbar()
+
+            # target
+            plt.subplot(num_figures, 1, 3)
+            plt.imshow(
+                X=np.transpose(audio_features[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+                vmin=vmin,
+                vmax=vmax,
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=T - 0.4, linewidth=2, color="r")
+            plt.xlabel("Decoder Target")
+            plt.colorbar()
+
+            plt.savefig(f"{output_dir}/{utt_id}.png")
+            plt.close()
+
 
 # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
 def top_k_top_p_filtering(
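The new visualize method draws three stacked heatmaps per utterance (encoder output, decoder output, target features) and marks the valid length with a red vertical line. A self-contained sketch of that plotting pattern on random toy data (the shapes, valid_len, and demo.png are illustrative, not from the recipe):

```python
import matplotlib

matplotlib.use("Agg")  # headless backend, as needed in training jobs
import matplotlib.pyplot as plt
import numpy as np

T, D = 200, 80   # toy frame count and feature dimension
valid_len = 150  # stand-in for features_lens[b]
panels = {
    "Encoder Output": np.random.randn(T, D),
    "Decoder Output": np.random.uniform(-6, 0, (T, D)),  # fbank range from the diff
    "Decoder Target": np.random.uniform(-6, 0, (T, D)),
}

plt.figure(figsize=(14, 8 * len(panels)))
for i, (name, mat) in enumerate(panels.items(), start=1):
    plt.subplot(len(panels), 1, i)
    # time on the x-axis, feature bins on the y-axis, matching the recipe's
    # np.transpose + invert_yaxis combination
    plt.imshow(np.transpose(mat), cmap="jet", aspect="auto", interpolation="nearest")
    plt.gca().invert_yaxis()
    plt.axvline(x=valid_len - 0.4, linewidth=2, color="r")
    plt.xlabel(name)
    plt.colorbar()
plt.savefig("demo.png")
plt.close()
```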
