11 changes: 10 additions & 1 deletion R/tokenizer.R
@@ -263,7 +263,16 @@ decode_bpe_bytes <- function(text) {
}
}

rawToChar(bytes)
# Write raw bytes to a temporary file and read back as UTF-8,
# replacing any invalid multibyte sequences
tmp <- tempfile()
on.exit(unlink(tmp), add = TRUE)
writeBin(bytes, tmp)
out <- readLines(tmp, warn = FALSE, encoding = "UTF-8")
out <- paste(out, collapse = "\n")
# Strip any remaining invalid bytes
iconv(out, from = "UTF-8", to = "UTF-8", sub = "")
}
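
For reference, a standalone sketch of the round-trip above, using a hypothetical raw vector rather than package data: the invalid 0xff byte survives writeBin()/readLines() and is stripped by the final iconv() pass.

bytes <- as.raw(c(0x68, 0x69, 0xff, 0x21))            # "hi", invalid 0xff, "!"
tmp <- tempfile()
writeBin(bytes, tmp)
out <- paste(readLines(tmp, warn = FALSE, encoding = "UTF-8"), collapse = "\n")
iconv(out, from = "UTF-8", to = "UTF-8", sub = "")    # "hi!"
unlink(tmp)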

#' Ensure Tokenizer Files are Downloaded
237 changes: 191 additions & 46 deletions R/transcribe.R
@@ -271,68 +271,155 @@ transcribe_chunk <- function(
dtype,
verbose = TRUE
) {
# Convert audio to mel spectrogram
# Convert audio to mel spectrogram (full 30s window)
if (verbose) message("Processing audio...")
mel <- audio_to_mel(file, n_mels = config$n_mels, device = device, dtype = dtype)
full_mel <- audio_to_mel(file, n_mels = config$n_mels, device = device, dtype = dtype)
n_frames <- full_mel$size(3) # 3000 for 30s

# Beam search needs timestamps internally for proper termination
# (without timestamps, the model generates repetitive tokens)
user_timestamps <- timestamps
internal_timestamps <- timestamps || beam_size > 1L

# Get initial decoder tokens (use model name for correct special token IDs)
initial_tokens <- get_initial_tokens(language, task,
model = config$model_name, timestamps = internal_timestamps)
tokens <- torch::torch_tensor(matrix(initial_tokens, nrow = 1),
dtype = torch::torch_long(),
device = device)
special <- whisper_special_tokens(config$model_name)

# Encode audio
if (verbose) message("Encoding audio...")
torch::with_no_grad({
# Seek loop: decode repeatedly, advancing through the mel spectrogram
seek <- 0L # current frame position
all_generated <- integer(0)
all_cross_attn <- if (word_timestamps) list() else NULL
all_segments <- list()
seek_iter <- 0L

while (seek < n_frames) {
seek_iter <- seek_iter + 1L
if (seek_iter > 50L) break # safety limit

# Slice mel from seek position, pad to full width
remaining <- n_frames - seek
if (remaining < 1L) break

if (seek == 0L) {
mel <- full_mel
} else {
# Slice mel[1, :, seek:] and pad to n_frames width
mel_slice <- full_mel[, , (seek + 1L):n_frames]
pad_width <- n_frames - mel_slice$size(3)
if (pad_width > 0L) {
mel <- torch::nnf_pad(mel_slice, c(0L, pad_width), value = 0)
} else {
mel <- mel_slice
}
}

# Compute seek time for this iteration
seek_time <- seek * 0.01 # frames to seconds (10ms per frame)

# Get initial decoder tokens
initial_tokens <- get_initial_tokens(language, task,
model = config$model_name, timestamps = internal_timestamps)
tokens <- torch::torch_tensor(matrix(initial_tokens, nrow = 1),
dtype = torch::torch_long(), device = device)

# Encode audio
torch::with_no_grad({
encoder_output <- model$encode(mel)
})

# Decode
if (verbose) message("Decoding...")
decode_result <- decode_with_fallback(model, encoder_output, tokens,
tokenizer, temperatures = temperatures, beam_size = beam_size,
best_of = best_of, max_length = config$n_text_ctx,
timestamps = internal_timestamps, word_timestamps = word_timestamps,
compression_ratio_threshold = compression_ratio_threshold,
logprob_threshold = logprob_threshold,
length_penalty = length_penalty, patience = patience,
device = device)
# Decode
decode_result <- decode_with_fallback(model, encoder_output, tokens,
tokenizer, temperatures = temperatures, beam_size = beam_size,
best_of = best_of, max_length = config$n_text_ctx,
timestamps = internal_timestamps, word_timestamps = word_timestamps,
compression_ratio_threshold = compression_ratio_threshold,
logprob_threshold = logprob_threshold,
length_penalty = length_penalty, patience = patience,
device = device)

generated <- decode_result$tokens
all_generated <- c(all_generated, generated)

# Find the last timestamp token to determine where to seek next
last_ts_frame <- 0L
for (tok in generated) {
if (tok >= special$timestamp_begin) {
ts_seconds <- (tok - special$timestamp_begin) * 0.02
ts_frame <- as.integer(ts_seconds * 100) # seconds to frames (10ms)
if (ts_frame > last_ts_frame) last_ts_frame <- ts_frame
}
}

# Extract segments with proper time offset
if (user_timestamps) {
segments <- extract_segments(generated, tokenizer,
time_offset = time_offset + seek_time)
if (nrow(segments) > 0) {
all_segments <- c(all_segments, list(segments))
}
}

# Collect cross-attention weights with seek offset
if (word_timestamps && !is.null(decode_result$cross_attn_weights)) {
all_cross_attn <- c(all_cross_attn, list(list(
weights = decode_result$cross_attn_weights,
tokens = generated,
initial_tokens = initial_tokens,
seek_time = seek_time
)))
}

# Advance seek position
if (last_ts_frame > 0L) {
seek <- seek + last_ts_frame
} else {
# No timestamp found: the model produced no timed output, so stop
break
}

# If last timestamp covered nearly the full remaining audio, stop
if (last_ts_frame >= remaining - 100L) break
}
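
# Worked trace of one seek advance (illustrative numbers, not real model
# output): if the last timestamp token decoded in an iteration is
# <|24.00|>, then ts_seconds = 24.0 and last_ts_frame = 2400. On the
# first pass remaining is 3000, so 2400 < 2900 and the loop continues:
# the next iteration slices full_mel[, , 2401:3000], zero-pads it back
# to 3000 frames, and adds seek_time = 24.0 s to every new timestamp.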

generated <- decode_result$tokens
cross_attn_weights <- decode_result$cross_attn_weights
if (verbose && seek_iter > 1L) {
message(" Seek loop: ", seek_iter, " iterations")
}

# Build result
if (user_timestamps) {
# Extract segments from timestamp tokens
segments <- extract_segments(generated, tokenizer, time_offset = time_offset)
text <- paste(segments$text, collapse = " ")
text <- clean_text(text)
if (length(all_segments) > 0) {
segments <- do.call(rbind, all_segments)
text <- paste(segments$text, collapse = " ")
text <- clean_text(text)
} else {
segments <- data.frame(start = numeric(0), end = numeric(0),
text = character(0))
text <- ""
}
result <- list(text = text, language = language, segments = segments)
} else {
# Strip timestamp tokens if used internally for beam search
# For non-timestamp mode, combine all generated tokens
# (strip timestamp tokens if used internally)
if (internal_timestamps) {
special <- whisper_special_tokens(config$model_name)
generated <- generated[generated < special$timestamp_begin]
all_generated <- all_generated[all_generated < special$timestamp_begin]
}
text <- tokenizer$decode(generated)
text <- tokenizer$decode(all_generated)
text <- clean_text(text)
result <- list(text = text, language = language)
}

# Word-level timestamps via cross-attention DTW
if (word_timestamps && !is.null(cross_attn_weights)) {
special <- whisper_special_tokens(config$model_name)
sample_begin <- length(initial_tokens)
words <- compute_word_timestamps(generated, cross_attn_weights,
tokenizer, config, time_offset = time_offset,
sample_begin = sample_begin)
result$words <- words
# Word-level timestamps via cross-attention DTW (per seek iteration)
if (word_timestamps && length(all_cross_attn) > 0) {
all_words <- list()
for (ca in all_cross_attn) {
sample_begin <- length(ca$initial_tokens)
words <- compute_word_timestamps(ca$tokens, ca$weights,
tokenizer, config,
time_offset = time_offset + ca$seek_time,
sample_begin = sample_begin)
if (!is.null(words) && nrow(words) > 0) {
all_words <- c(all_words, list(words))
}
}
if (length(all_words) > 0) {
result$words <- do.call(rbind, all_words)
}
}

result
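
As a standalone check of the timestamp arithmetic in the seek loop: each timestamp token encodes 0.02 s, and mel frames are 10 ms, so seconds convert to frames with a factor of 100. The token values below are hypothetical, and timestamp_begin = 50364 is an assumption (the multilingual Whisper vocabulary); whisper_special_tokens() remains the package's source of truth.

timestamp_begin <- 50364                       # assumed <|0.00|> token id
generated <- c(50414, 400, 500, 50614)         # <|1.00|>, text, text, <|5.00|>
ts_tokens <- generated[generated >= timestamp_begin]
ts_seconds <- (ts_tokens - timestamp_begin) * 0.02   # 1.0, 5.0
max(as.integer(ts_seconds * 100))                    # 500 frames: next seek offset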
@@ -636,25 +723,83 @@ transcribe_long <- function(
}
}

# Get actual audio duration to filter hallucinations from padded chunks
audio_dur <- audio_duration(file)

# Combine results
result <- list(
text = paste(all_text, collapse = " "),
language = language
)

if (timestamps) {
result$segments <- if (length(all_segments) > 0) {
do.call(rbind, all_segments)
if (length(all_segments) > 0) {
combined <- do.call(rbind, all_segments)
# Remove segments that start after the actual audio duration
combined <- combined[combined$start < audio_dur, , drop = FALSE]
# Cap end times to audio duration
combined$end <- pmin(combined$end, audio_dur)
# Deduplicate overlapping segments at chunk boundaries.
# Strategy: when two segments overlap, prefer the later one (from the
# chunk that has actual audio for that time region); if either side
# looks hallucinated (very short text), drop that side instead.
if (nrow(combined) > 1) {
keep <- rep(TRUE, nrow(combined))
for (j in 2:nrow(combined)) {
if (combined$start[j] < combined$end[j - 1] - 0.1) {
# Overlap detected
prev_len <- nchar(combined$text[j - 1])
curr_len <- nchar(combined$text[j])
if (curr_len < 5) {
# Later segment is likely hallucination, drop it
keep[j] <- FALSE
} else if (prev_len < 5) {
# Previous segment is likely hallucination, drop it
keep[j - 1] <- FALSE
} else {
# Both substantial: trim previous to end before overlap starts
combined$end[j - 1] <- combined$start[j]
}
}
}
combined <- combined[keep, , drop = FALSE]
}
# Remove likely hallucinated segments (very short duration + short text)
seg_dur <- combined$end - combined$start
combined <- combined[!(seg_dur < 0.5 & nchar(combined$text) < 15), , drop = FALSE]
result$segments <- combined
} else {
data.frame(start = numeric(0), end = numeric(0), text = character(0))
result$segments <- data.frame(start = numeric(0), end = numeric(0),
text = character(0))
}
}

if (word_timestamps) {
result$words <- if (length(all_words) > 0) {
do.call(rbind, all_words)
if (length(all_words) > 0) {
combined_words <- do.call(rbind, all_words)
# Remove words that start after actual audio duration
combined_words <- combined_words[combined_words$start < audio_dur, , drop = FALSE]
# Cap word end times
combined_words$end <- pmin(combined_words$end, audio_dur)
# Remove duplicate words from chunk overlap (keep first occurrence)
if (nrow(combined_words) > 1) {
keep <- rep(TRUE, nrow(combined_words))
for (j in 2:nrow(combined_words)) {
if (combined_words$start[j] < combined_words$end[j - 1] - 0.05) {
keep[j] <- FALSE
}
}
combined_words <- combined_words[keep, , drop = FALSE]
}
# Filter words to only those within retained segments
if (!is.null(result$segments) && nrow(result$segments) > 0) {
seg_end <- max(result$segments$end)
combined_words <- combined_words[combined_words$start < seg_end, , drop = FALSE]
}
result$words <- combined_words
} else {
data.frame(word = character(0), start = numeric(0), end = numeric(0))
result$words <- data.frame(word = character(0), start = numeric(0),
end = numeric(0))
}
}

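A toy illustration of the boundary handling above, with hypothetical segments rather than package output: the first two segments overlap around t = 29.5 s and both carry substantial text, so the earlier one is trimmed; the third is short in both duration and text, so the hallucination filter removes it.

combined <- data.frame(
  start = c(25.0, 29.5, 35.0),
  end   = c(30.2, 34.0, 35.2),
  text  = c("end of the first chunk", "start of the second chunk", "Thanks"),
  stringsAsFactors = FALSE
)
combined$end[1] <- combined$start[2]     # trim previous segment at the overlap
seg_dur <- combined$end - combined$start
combined[!(seg_dur < 0.5 & nchar(combined$text) < 15), , drop = FALSE]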