Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added examples/__init__.py
Empty file.
97 changes: 82 additions & 15 deletions examples/cmd/run.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
import os, sys

if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

now_dir = os.getcwd()
sys.path.append(now_dir)

import argparse
from typing import Optional, List
import argparse
import os
import sys

import numpy as np
import torch

import ChatTTS

from tools.audio import pcm_arr_to_mp3_view
from tools.logger import get_logger
from tools.audio import pcm_arr_to_mp3_view
from tools.normalizer.en import normalizer_en_nemo_text
from tools.normalizer.zh import normalizer_zh_tn

if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

now_dir = os.getcwd()
sys.path.append(now_dir)

logger = get_logger("Command")

Expand All @@ -27,12 +29,60 @@ def save_mp3_file(wav, index):
logger.info(f"Audio saved to {mp3_filename}")


def main(texts: List[str], spk: Optional[str] = None, stream=False):
def ndarray_to_tensor(audio: np.ndarray) -> torch.Tensor:
    """Convert a NumPy audio array into a float32 torch tensor.

    Args:
        audio: Array of samples. It is cast to ``np.float32`` before being
            handed to torch.

    Returns:
        A tensor built from the cast array. A 1-D input gains a leading
        dimension of size 1; inputs of any other rank keep their shape.
    """
    samples_f32 = audio.astype(np.float32)
    tensor = torch.from_numpy(samples_f32)
    # A 1-D array has no channel axis yet, so prepend one.
    return tensor.unsqueeze(0) if tensor.dim() == 1 else tensor


def _register_normalizer(chat, lang, factory, package):
    """Build one normalizer via ``factory`` and register it under ``lang``.

    A ``ValueError`` is logged as an error. Any other ``Exception`` is
    reported as the optional ``package`` being unavailable, with an install
    hint. Failures never propagate, so registration stays best-effort.
    """
    try:
        # The factory is called inside the try so that failures while
        # constructing the normalizer are handled here as well.
        chat.normalizer.register(lang, factory())
    except ValueError as e:
        logger.error(e)
    except Exception:
        # Deliberately not BaseException: KeyboardInterrupt and SystemExit
        # must propagate instead of being misreported as a missing package.
        logger.warning(f"Package {package} not found!")
        logger.warning(
            f"Run: conda install -c conda-forge pynini=2.1.5 && pip install {package}",
        )


def load_normalizer(chat: ChatTTS.Chat):
    """Try to register the English and Chinese text normalizers on ``chat``.

    Each language is attempted independently, so a failure for one does not
    prevent the other from being registered. Problems are logged rather than
    raised.

    Args:
        chat: Object whose ``normalizer.register(lang, normalizer)`` is called
            for the ``"en"`` and ``"zh"`` languages.
    """
    _register_normalizer(chat, "en", normalizer_en_nemo_text, "nemo_text_processing")
    _register_normalizer(chat, "zh", normalizer_zh_tn, "WeTextProcessing")


def main(texts: List[str],
spk: Optional[str] = None,
stream: bool = False,
source: str = "local",
custom_path: str = "",
):
logger.info("Text input: %s", str(texts))

chat = ChatTTS.Chat(get_logger("ChatTTS"))
logger.info("Initializing ChatTTS...")
if chat.load():
load_normalizer(chat)

is_load = False
if os.path.isdir(custom_path) and source == "custom":
is_load = chat.load(compile=True,
source="custom",
custom_path=custom_path)
else:
is_load = chat.load(compile=True, source=source)

if is_load:
logger.info("Models loaded successfully.")
else:
logger.error("Models load failed.")
Expand Down Expand Up @@ -69,10 +119,14 @@ def main(texts: List[str], spk: Optional[str] = None, stream=False):


if __name__ == "__main__":
r"""
python -m examples.cmd.run \
--source custom --custom_path ../../models/2Noise/ChatTTS 你好喲 ":)"
"""
logger.info("Starting ChatTTS commandline demo...")
parser = argparse.ArgumentParser(
description="ChatTTS Command",
usage='[--spk xxx] [--stream] "Your text 1." " Your text 2."',
usage='[--spk xxx] [--stream] [--source ***] [--custom_path XXX] "Your text 1." " Your text 2."',
)
parser.add_argument(
"--spk",
Expand All @@ -85,12 +139,25 @@ def main(texts: List[str], spk: Optional[str] = None, stream=False):
help="Use stream mode",
action="store_true",
)
parser.add_argument(
"--source",
help="source form [ huggingface(hf download), local(ckpt save to asset dir), custom(define) ]",
type=str,
default="local",
)
parser.add_argument(
"--custom_path",
help="custom defined model path(include asset ckpt dir)",
type=str,
default="",
)
parser.add_argument(
"texts",
help="Original text",
default=["YOUR TEXT HERE"],
nargs=argparse.REMAINDER,
)
args = parser.parse_args()
main(args.texts, args.spk, args.stream)
logger.info(args)
main(args.texts, args.spk, args.stream, args.source, args.custom_path)
logger.info("ChatTTS process finished.")