Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ export(load_s3gen)
export(load_s3gen_weights)
export(load_s3tokenizer_weights)
export(models_available)
export(normalize_tts_text)
export(quick_tts)
export(read_audio)
export(read_safetensors)
Expand Down
31 changes: 22 additions & 9 deletions R/t3.R
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ t3_inference <- function (model, cond, text_tokens, max_new_tokens = 1000,
# Track generated tokens (only conditional path for CFG)
generated_ids <- bos_token[1,, drop = FALSE]$clone()
predicted <- list()
eos_found <- FALSE

# Generation loop
for (i in seq_len(max_new_tokens)) {
Expand Down Expand Up @@ -574,6 +575,7 @@ t3_inference <- function (model, cond, text_tokens, max_new_tokens = 1000,
token_id <- as.integer(next_token$cpu()) - 1L
if (token_id == config$stop_speech_token) {
message("EOS detected at step ", i)
eos_found <- TRUE
break
}

Expand All @@ -598,10 +600,11 @@ t3_inference <- function (model, cond, text_tokens, max_new_tokens = 1000,
if (length(predicted) > 0) {
tokens <- torch::torch_cat(predicted, dim = 2)$squeeze(1)
tokens <- tokens$sub(1L) # Convert to 0-indexed token IDs
tokens
} else {
torch::torch_tensor(integer(0), device = device)
tokens <- torch::torch_tensor(integer(0), device = device)
}
attr(tokens, "eos_found") <- eos_found
tokens
}

# ============================================================================
Expand Down Expand Up @@ -747,6 +750,7 @@ t3_inference_traced <- function(model, cond, text_tokens, max_new_tokens = 1000,
# Track generated tokens
generated_ids <- bos_token[1,, drop = FALSE]$clone()
predicted <- list()
eos_found <- FALSE

# === GENERATION LOOP: Use traced transformer ===
# Limit generation to fit in cache
Expand Down Expand Up @@ -808,6 +812,7 @@ t3_inference_traced <- function(model, cond, text_tokens, max_new_tokens = 1000,
token_id <- as.integer(next_token$cpu()) - 1L
if (token_id == config$stop_speech_token) {
message("EOS detected at step ", i)
eos_found <- TRUE
break
}

Expand Down Expand Up @@ -864,10 +869,11 @@ t3_inference_traced <- function(model, cond, text_tokens, max_new_tokens = 1000,
if (length(predicted) > 0L) {
tokens <- torch::torch_cat(predicted, dim = 2L)$squeeze(1L)
tokens <- tokens$sub(1L)
tokens
} else {
torch::torch_tensor(integer(0), device = device)
tokens <- torch::torch_tensor(integer(0), device = device)
}
attr(tokens, "eos_found") <- eos_found
tokens
}

# ============================================================================
Expand Down Expand Up @@ -1036,17 +1042,21 @@ t3_inference_cpp <- function (model, cond, text_tokens, max_new_tokens = 1000,
)

# token_ids is an integer vector of 0-indexed token IDs (from C++)
eos_found <- FALSE
if (length(token_ids) > 0L) {
# Check for EOS and strip it
eos_pos <- match(config$stop_speech_token, token_ids)
if (!is.na(eos_pos)) {
message("EOS detected at step ", eos_pos)
eos_found <- TRUE
token_ids <- token_ids[seq_len(eos_pos - 1L)]
}
torch::torch_tensor(token_ids, dtype = torch::torch_long(), device = device)
tokens <- torch::torch_tensor(token_ids, dtype = torch::torch_long(), device = device)
} else {
torch::torch_tensor(integer(0), device = device)
tokens <- torch::torch_tensor(integer(0), device = device)
}
attr(tokens, "eos_found") <- eos_found
tokens
}

# ============================================================================
Expand Down Expand Up @@ -1605,6 +1615,7 @@ t3_inference_turbo <- function (model, cond, text_tokens, max_new_tokens = 1000,
past_key_values <- output$past_key_values

generated_tokens <- list()
eos_found <- FALSE

# Get first logits from last position
hidden <- output$last_hidden_state[, -1L, , drop = FALSE]
Expand All @@ -1621,6 +1632,7 @@ t3_inference_turbo <- function (model, cond, text_tokens, max_new_tokens = 1000,
token_id <- as.integer(current_token$cpu()) - 1L
if (token_id == config$stop_speech_token) {
message("EOS at step 1")
eos_found <- TRUE
} else {
# Generation loop
for (i in seq_len(max_new_tokens)) {
Expand Down Expand Up @@ -1653,6 +1665,7 @@ t3_inference_turbo <- function (model, cond, text_tokens, max_new_tokens = 1000,
token_id <- as.integer(next_token$cpu()) - 1L
if (token_id == config$stop_speech_token) {
message("EOS at step ", i + 1L)
eos_found <- TRUE
break
}
}
Expand All @@ -1669,11 +1682,11 @@ t3_inference_turbo <- function (model, cond, text_tokens, max_new_tokens = 1000,
if (length(token_vals) > 0L && token_vals[length(token_vals)] == config$stop_speech_token) {
tokens <- tokens[1:(tokens$size(1) - 1L)]
}

tokens
} else {
torch::torch_tensor(integer(0), device = device)
tokens <- torch::torch_tensor(integer(0), device = device)
}
attr(tokens, "eos_found") <- eos_found
tokens
}

#' Sample a token using turbo logit processors
Expand Down
113 changes: 106 additions & 7 deletions R/tts.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,62 @@
# chatterbox - High-level Text-to-Speech API
# Provides simple interface for TTS generation using the Chatterbox engine

# ============================================================================
# Text Normalization
# ============================================================================

#' Normalize text for TTS
#'
#' Lowercases words whose capitalization the Chatterbox model tends to
#' interpret as an emphasis cue, which often causes it to produce only the
#' first word followed by silence. A word is lowercased when it is:
#' \itemize{
#'   \item all caps and longer than one letter (e.g. "ALERT"),
#'   \item capitalized after its first letter (e.g. "rarelY"), or
#'   \item capitalized in the middle of a sentence (e.g. "Rarely").
#' }
#' A word with only its first letter capitalized is left alone when it is
#' the first word of the text or follows a word ending in \code{.},
#' \code{!} or \code{?}. The pronoun "I", including contractions such as
#' "I'm", is also left alone.
#'
#' Words are delimited by whitespace. Every run of whitespace is replaced by
#' a single space and trailing whitespace is dropped.
#'
#' @param text Character scalar. Any other input (non-character, length
#'   other than one, or \code{NA}) is returned unchanged.
#' @return Normalized text.
#' @export
normalize_tts_text <- function (text)
{
  if (!is.character(text) || length(text) != 1L || is.na(text)) {
    return(text)
  }
  # Split on whitespace runs; punctuation stays attached to its word
  parts <- strsplit(text, "(\\s+)", perl = TRUE)[[1]]
  if (length(parts) == 0L) return(text)

  # Track sentence boundary: first word, or word right after .!?
  prev_was_sentence_end <- TRUE
  out <- character(length(parts))
  for (i in seq_along(parts)) {
    word <- parts[i]

    # Leading whitespace yields an empty first token. It is not a word, so
    # it must not clear the sentence-start flag for the word that follows.
    if (!nzchar(word)) next

    # perl = TRUE keeps the A-Z / a-z ranges independent of the locale
    letters_only <- gsub("[^A-Za-z]", "", word, perl = TRUE)
    is_capitalized <- grepl("^[A-Z]", letters_only, perl = TRUE)
    has_internal_caps <- grepl("[A-Z]", substring(letters_only, 2L),
                               perl = TRUE)
    # Longer than 1 letter, so a lone "I" or "A" is not treated as all caps
    is_all_caps <- nchar(letters_only) > 1L &&
      letters_only == toupper(letters_only)

    # Pronoun "I", optionally with a lowercase contraction suffix ("I'm",
    # "I'll"), surrounded by any amount of punctuation
    is_pronoun_i <- grepl("^[^A-Za-z]*I(['\u2019][a-z]+)?[^A-Za-z]*$",
                          word, perl = TRUE)

    # Lowercase if:
    #   - all caps, OR
    #   - internal caps (camelCase / weirdCase), OR
    #   - capitalized mid-sentence (not first word and not after .!?),
    #     except for the pronoun "I"
    should_lower <- is_all_caps || has_internal_caps ||
      (is_capitalized && !prev_was_sentence_end && !is_pronoun_i)
    out[i] <- if (should_lower) tolower(word) else word

    # Update sentence-end tracker for next word
    prev_was_sentence_end <- grepl("[.!?]\\s*$", word)
  }
  paste(out, collapse = " ")
}

# ============================================================================
# Chatterbox TTS Model
# ============================================================================
Expand Down Expand Up @@ -251,17 +307,35 @@ create_voice_embedding <- function (model, audio, sample_rate = NULL, autocast =
#' @param backend Character. Inference backend, either "r" or "cpp". Default "r".
#' @param top_k Integer. Top-k sampling parameter. Default 1000.
#' @param repetition_penalty Numeric. Repetition penalty. Default 1.2.
#' @return List with audio (numeric vector) and sample_rate
#' @param normalize_text Logical. If TRUE (default), pre-process the text with
#' \code{normalize_tts_text()} to reduce model failure modes: all-caps,
#' internally capitalized and mid-sentence capitalized words are lowercased,
#' because the model interprets such capitals as emphasis cues and often
#' produces silent audio for them. Set to FALSE to pass text through unchanged.
#' @return List with elements:
#' \describe{
#' \item{audio}{Numeric vector of audio samples}
#' \item{sample_rate}{Sample rate in Hz}
#' \item{eos_found}{Logical. Whether the model emitted an end-of-speech
#' token (TRUE) or hit the token cap (FALSE). FALSE often indicates
#' garbage output and a need to retry or split the input.}
#' \item{n_tokens}{Number of speech tokens generated}
#' \item{audio_sec}{Audio duration in seconds}
#' }
#' @export
generate <- function (model, text, voice, exaggeration = 0.5, cfg_weight = 0.5,
temperature = 0.8, top_p = 0.9, autocast = NULL,
traced = FALSE, backend = c("r", "cpp"),
top_k = 1000L, repetition_penalty = 1.2)
top_k = 1000L, repetition_penalty = 1.2,
normalize_text = TRUE)
{
if (!is_loaded(model)) {
stop("Model not loaded. Call load_chatterbox() first.")
}

if (isTRUE(normalize_text)) {
text <- normalize_tts_text(text)
}

device <- model$device
is_turbo <- isTRUE(model$turbo)
# Default: autocast on CUDA, off on CPU
Expand Down Expand Up @@ -362,12 +436,23 @@ generate <- function (model, text, voice, exaggeration = 0.5, cfg_weight = 0.5,
}
}

# Capture EOS status before drop_invalid_tokens strips the attribute
eos_found <- isTRUE(attr(speech_tokens, "eos_found"))

# Drop invalid tokens
speech_tokens <- drop_invalid_tokens(speech_tokens)
n_tokens <- as.integer(speech_tokens$size(1L))

if (length(speech_tokens) == 0) {
warning("No valid speech tokens generated")
return(list(audio = numeric(0), sample_rate = S3GEN_SR))
return(list(audio = numeric(0), sample_rate = S3GEN_SR,
eos_found = eos_found, n_tokens = 0L, audio_sec = 0))
}

if (!eos_found) {
warning("Generation hit token cap without emitting end-of-speech ",
"(", n_tokens, " tokens). Output may be garbage; ",
"consider splitting the input or retrying.")
}

# Convert to integer vector and add silence for turbo
Expand Down Expand Up @@ -414,12 +499,16 @@ generate <- function (model, text, voice, exaggeration = 0.5, cfg_weight = 0.5,

# Convert to numeric
audio_samples <- as.numeric(audio$squeeze()$cpu())
audio_sec <- length(audio_samples) / S3GEN_SR

message("Done! Generated ", round(length(audio_samples) / S3GEN_SR, 2), " seconds of audio.")
message("Done! Generated ", round(audio_sec, 2), " seconds of audio.")

list(
audio = audio_samples,
sample_rate = S3GEN_SR
sample_rate = S3GEN_SR,
eos_found = eos_found,
n_tokens = n_tokens,
audio_sec = audio_sec
)
}

Expand All @@ -430,15 +519,25 @@ generate <- function (model, text, voice, exaggeration = 0.5, cfg_weight = 0.5,
#' @param voice Voice embedding or path to reference audio
#' @param output_path Output file path (WAV format)
#' @param ... Additional arguments passed to generate()
#' @return Invisibly returns the output path
#' @return Invisibly returns a list with elements: \code{path},
#' \code{eos_found}, \code{n_tokens}, \code{audio_sec}. When iterating
#' over many texts, collect these into a data.frame to identify which
#' inputs failed (\code{eos_found = FALSE}) and need reprocessing.
#' @export
tts_to_file <- function (model, text, voice, output_path, ...)
{
  # Synthesize the speech, then write the samples to disk at the sample rate
  # reported by generate()
  result <- generate(model, text, voice, ...)
  write_audio(result$audio, result$sample_rate, output_path)

  # Return the path together with the generation diagnostics. Fields that
  # are absent from the generate() result fall back to FALSE / NA.
  invisible(list(
    path = output_path,
    eos_found = isTRUE(result$eos_found),
    n_tokens = result$n_tokens %||% NA_integer_,
    audio_sec = result$audio_sec %||% NA_real_
  ))
}

# Null-coalescing operator: yields `b` when `a` is NULL, otherwise `a`.
# `b` is only evaluated when it is actually needed.
`%||%` <- function (a, b)
{
  if (!is.null(a)) {
    return(a)
  }
  b
}

# ============================================================================
# Streaming TTS (for longer texts)
# ============================================================================
Expand Down
41 changes: 41 additions & 0 deletions inst/tinytest/test_normalize_text.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
library(chatterbox)

# Table of cases: each entry pairs an input with the exact text that
# normalize_tts_text() is expected to return for it.
cases <- list(
  # Sentence-initial caps preserved; mid-sentence caps lowercased
  list(
    input = "Yes, Rarely or never Almost never.",
    expected = "Yes, rarely or never almost never."
  ),
  # Pronoun "I" stays capitalized mid-sentence
  list(
    input = "As I said earlier, homework is a battle.",
    expected = "As I said earlier, homework is a battle."
  ),
  # Caps after sentence boundary stay capitalized
  list(
    input = "Very often. As I said earlier.",
    expected = "Very often. As I said earlier."
  ),
  # Caps after semicolon get lowercased (mid-sentence)
  list(
    input = "My head hurts; My stomach hurts.",
    expected = "My head hurts; my stomach hurts."
  ),
  # All-caps emphasis words get lowercased
  list(
    input = "This is ALERT level.",
    expected = "This is alert level."
  ),
  # Internal caps (camelCase / weirdCase) get lowercased
  list(
    input = "The rarelY pattern.",
    expected = "The rarely pattern."
  ),
  # Empty / non-character inputs pass through
  list(input = "", expected = ""),
  list(input = NA_character_, expected = NA_character_)
)

for (case in cases) {
  expect_equal(normalize_tts_text(case$input), case$expected)
}
19 changes: 17 additions & 2 deletions man/generate.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
\usage{
generate(model, text, voice, exaggeration = 0.5, cfg_weight = 0.5,
temperature = 0.8, top_p = 0.9, autocast = NULL, traced = FALSE,
backend = c("r", "cpp"), top_k = 1000L, repetition_penalty = 1.2)
backend = c("r", "cpp"), top_k = 1000L, repetition_penalty = 1.2,
normalize_text = TRUE)
}
\arguments{
\item{model}{Chatterbox model}
Expand All @@ -31,9 +32,23 @@ generate(model, text, voice, exaggeration = 0.5, cfg_weight = 0.5,
\item{top_k}{Integer. Top-k sampling parameter. Default 1000.}

\item{repetition_penalty}{Numeric. Repetition penalty. Default 1.2.}

\item{normalize_text}{Logical. If TRUE (default), pre-process text to
reduce model failure modes: lowercase words with internal capitals
(which the model interprets as emphasis cues and often produces
silent audio for). Set to FALSE to pass text through unchanged.}
}
\value{
List with audio (numeric vector) and sample_rate
List with elements:
\describe{
\item{audio}{Numeric vector of audio samples}
\item{sample_rate}{Sample rate in Hz}
\item{eos_found}{Logical. Whether the model emitted an end-of-speech
token (TRUE) or hit the token cap (FALSE). FALSE often indicates
garbage output and a need to retry or split the input.}
\item{n_tokens}{Number of speech tokens generated}
\item{audio_sec}{Audio duration in seconds}
}
}
\description{
Generate speech from text
Expand Down
Loading