diff --git a/R/tokenizer.R b/R/tokenizer.R
index e74ca34..c454525 100644
--- a/R/tokenizer.R
+++ b/R/tokenizer.R
@@ -263,7 +263,15 @@ decode_bpe_bytes <- function(text) {
     }
   }
 
-  rawToChar(bytes)
+  # Write raw bytes to a temporary file and read them back as UTF-8,
+  # replacing any invalid multibyte sequences
+  tmp <- tempfile()
+  on.exit(unlink(tmp), add = TRUE)
+  writeBin(bytes, tmp)
+  out <- readLines(tmp, warn = FALSE, encoding = "UTF-8")
+  out <- paste(out, collapse = "\n")
+  # Strip any remaining invalid bytes
+  iconv(out, from = "UTF-8", to = "UTF-8", sub = "")
 }
 
 #' Ensure Tokenizer Files are Downloaded
diff --git a/R/transcribe.R b/R/transcribe.R
index 431b182..5c0855e 100644
--- a/R/transcribe.R
+++ b/R/transcribe.R
@@ -271,68 +271,156 @@ transcribe_chunk <- function(
   dtype,
   verbose = TRUE
 ) {
-  # Convert audio to mel spectrogram
+  # Convert audio to mel spectrogram (full 30s window)
   if (verbose) message("Processing audio...")
-  mel <- audio_to_mel(file, n_mels = config$n_mels, device = device, dtype = dtype)
+  full_mel <- audio_to_mel(file, n_mels = config$n_mels, device = device, dtype = dtype)
+  n_frames <- full_mel$size(3)  # 3000 for 30s
 
   # Beam search needs timestamps internally for proper termination
-  # (without timestamps, the model generates repetitive tokens)
   user_timestamps <- timestamps
   internal_timestamps <- timestamps || beam_size > 1L
 
-  # Get initial decoder tokens (use model name for correct special token IDs)
-  initial_tokens <- get_initial_tokens(language, task,
-    model = config$model_name, timestamps = internal_timestamps)
-  tokens <- torch::torch_tensor(matrix(initial_tokens, nrow = 1),
-    dtype = torch::torch_long(),
-    device = device)
+  special <- whisper_special_tokens(config$model_name)
 
-  # Encode audio
-  if (verbose) message("Encoding audio...")
-  torch::with_no_grad({
+  # Seek loop: decode repeatedly, advancing through the mel spectrogram
+  seek <- 0L  # current frame position
+  all_generated <- integer(0)
+  all_cross_attn <- if (word_timestamps) list() else NULL
+  all_segments <- list()
+  seek_iter <- 0L
+
+  while (seek < n_frames) {
+    seek_iter <- seek_iter + 1L
+    if (seek_iter > 50L) break  # safety limit
+
+    # Slice mel from seek position, pad to full width
+    remaining <- n_frames - seek
+    if (remaining < 1L) break
+
+    if (seek == 0L) {
+      mel <- full_mel
+    } else {
+      # Take the frames after seek and pad back to n_frames width
+      mel_slice <- full_mel[, , (seek + 1L):n_frames]
+      pad_width <- n_frames - mel_slice$size(3)
+      if (pad_width > 0L) {
+        mel <- torch::nnf_pad(mel_slice, c(0L, pad_width), value = 0)
+      } else {
+        mel <- mel_slice
+      }
+    }
+
+    # Compute seek time for this iteration
+    seek_time <- seek * 0.01  # frames to seconds (10ms per frame)
+
+    # Get initial decoder tokens
+    initial_tokens <- get_initial_tokens(language, task,
+      model = config$model_name, timestamps = internal_timestamps)
+    tokens <- torch::torch_tensor(matrix(initial_tokens, nrow = 1),
+      dtype = torch::torch_long(), device = device)
+
+    # Encode audio
+    torch::with_no_grad({
     encoder_output <- model$encode(mel)
   })
 
-  # Decode
-  if (verbose) message("Decoding...")
-  decode_result <- decode_with_fallback(model, encoder_output, tokens,
-    tokenizer, temperatures = temperatures, beam_size = beam_size,
-    best_of = best_of, max_length = config$n_text_ctx,
-    timestamps = internal_timestamps, word_timestamps = word_timestamps,
-    compression_ratio_threshold = compression_ratio_threshold,
-    logprob_threshold = logprob_threshold,
-    length_penalty = length_penalty, patience = patience,
-    device = device)
+    # Decode
+    decode_result <- decode_with_fallback(model, encoder_output, tokens,
+      tokenizer, temperatures = temperatures, beam_size = beam_size,
+      best_of = best_of, max_length = config$n_text_ctx,
+      timestamps = internal_timestamps, word_timestamps = word_timestamps,
+      compression_ratio_threshold = compression_ratio_threshold,
+      logprob_threshold = logprob_threshold,
+      length_penalty = length_penalty, patience = patience,
+      device = device)
+
+    generated <- decode_result$tokens
+    all_generated <- c(all_generated, generated)
+
+    # Find the last timestamp token to determine where to seek next
+    last_ts_frame <- 0L
+    for (tok in generated) {
+      if (tok >= special$timestamp_begin) {
+        ts_seconds <- (tok - special$timestamp_begin) * 0.02
+        ts_frame <- as.integer(ts_seconds * 100)  # seconds to frames (10ms)
+        if (ts_frame > last_ts_frame) last_ts_frame <- ts_frame
+      }
+    }
+
+    # Extract segments with proper time offset
+    if (user_timestamps) {
+      segments <- extract_segments(generated, tokenizer,
+        time_offset = time_offset + seek_time)
+      if (nrow(segments) > 0) {
+        all_segments <- c(all_segments, list(segments))
+      }
+    }
+
+    # Collect cross-attention weights with seek offset
+    if (word_timestamps && !is.null(decode_result$cross_attn_weights)) {
+      all_cross_attn <- c(all_cross_attn, list(list(
+        weights = decode_result$cross_attn_weights,
+        tokens = generated,
+        initial_tokens = initial_tokens,
+        seek_time = seek_time
+      )))
+    }
+
+    # Advance seek position
+    if (last_ts_frame > 0L) {
+      seek <- seek + last_ts_frame
+    } else {
+      # No timestamp token was produced, so there is nothing to seek past; stop
+      break
+    }
+
+    # If the last timestamp covered nearly the full remaining audio, stop
+    if (last_ts_frame >= remaining - 100L) break
+  }
 
-  generated <- decode_result$tokens
-  cross_attn_weights <- decode_result$cross_attn_weights
+  if (verbose && seek_iter > 1L) {
+    message(" Seek loop: ", seek_iter, " iterations")
+  }
 
   # Build result
   if (user_timestamps) {
-    # Extract segments from timestamp tokens
-    segments <- extract_segments(generated, tokenizer, time_offset = time_offset)
-    text <- paste(segments$text, collapse = " ")
-    text <- clean_text(text)
+    if (length(all_segments) > 0) {
+      segments <- do.call(rbind, all_segments)
+      text <- paste(segments$text, collapse = " ")
+      text <- clean_text(text)
+    } else {
+      segments <- data.frame(start = numeric(0), end = numeric(0),
+        text = character(0))
+      text <- ""
+    }
     result <- list(text = text, language = language, segments = segments)
   } else {
-    # Strip timestamp tokens if used internally for beam search
+    # For non-timestamp mode, combine all generated tokens
+    # (strip timestamp tokens if used internally)
     if (internal_timestamps) {
-      special <- whisper_special_tokens(config$model_name)
-      generated <- generated[generated < special$timestamp_begin]
+      all_generated <- all_generated[all_generated < special$timestamp_begin]
     }
-    text <- tokenizer$decode(generated)
+    text <- tokenizer$decode(all_generated)
     text <- clean_text(text)
     result <- list(text = text, language = language)
   }
 
-  # Word-level timestamps via cross-attention DTW
-  if (word_timestamps && !is.null(cross_attn_weights)) {
-    special <- whisper_special_tokens(config$model_name)
-    sample_begin <- length(initial_tokens)
-    words <- compute_word_timestamps(generated, cross_attn_weights,
-      tokenizer, config, time_offset = time_offset,
-      sample_begin = sample_begin)
-    result$words <- words
+  # Word-level timestamps via cross-attention DTW (per seek iteration)
+  if (word_timestamps && length(all_cross_attn) > 0) {
+    all_words <- list()
+    for (ca in all_cross_attn) {
+      sample_begin <- length(ca$initial_tokens)
+      words <- compute_word_timestamps(ca$tokens, ca$weights,
+        tokenizer, config,
+        time_offset = time_offset + ca$seek_time,
+        sample_begin = sample_begin)
+      if (!is.null(words) && nrow(words) > 0) {
+        all_words <- c(all_words, list(words))
+      }
+    }
+    if (length(all_words) > 0) {
+      result$words <- do.call(rbind, all_words)
+    }
   }
 
   result
@@ -636,6 +723,9 @@ transcribe_long <- function(
     }
   }
 
+  # Get actual audio duration to filter hallucinations from padded chunks
+  audio_dur <- audio_duration(file)
+
   # Combine results
   result <- list(
     text = paste(all_text, collapse = " "),
@@ -643,18 +733,73 @@ transcribe_long <- function(
   )
 
   if (timestamps) {
-    result$segments <- if (length(all_segments) > 0) {
-      do.call(rbind, all_segments)
+    if (length(all_segments) > 0) {
+      combined <- do.call(rbind, all_segments)
+      # Remove segments that start after the actual audio duration
+      combined <- combined[combined$start < audio_dur, , drop = FALSE]
+      # Cap end times to audio duration
+      combined$end <- pmin(combined$end, audio_dur)
+      # Deduplicate overlapping segments at chunk boundaries.
+      # Strategy: when two segments overlap, keep the later one (from the
+      # chunk that has actual audio for that time region) unless it looks
+      # hallucinated (very short text).
+      if (nrow(combined) > 1) {
+        keep <- rep(TRUE, nrow(combined))
+        for (j in 2:nrow(combined)) {
+          if (combined$start[j] < combined$end[j - 1] - 0.1) {
+            # Overlap detected
+            prev_len <- nchar(combined$text[j - 1])
+            curr_len <- nchar(combined$text[j])
+            if (curr_len < 5) {
+              # Later segment is likely hallucination, drop it
+              keep[j] <- FALSE
+            } else if (prev_len < 5) {
+              # Previous segment is likely hallucination, drop it
+              keep[j - 1] <- FALSE
+            } else {
+              # Both substantial: trim previous to end before overlap starts
+              combined$end[j - 1] <- combined$start[j]
+            }
+          }
+        }
+        combined <- combined[keep, , drop = FALSE]
+      }
+      # Remove likely hallucinated segments (very short duration + short text)
+      seg_dur <- combined$end - combined$start
+      combined <- combined[!(seg_dur < 0.5 & nchar(combined$text) < 15), , drop = FALSE]
+      result$segments <- combined
     } else {
-      data.frame(start = numeric(0), end = numeric(0), text = character(0))
+      result$segments <- data.frame(start = numeric(0), end = numeric(0),
+        text = character(0))
     }
   }
 
   if (word_timestamps) {
-    result$words <- if (length(all_words) > 0) {
-      do.call(rbind, all_words)
+    if (length(all_words) > 0) {
+      combined_words <- do.call(rbind, all_words)
+      # Remove words that start after actual audio duration
+      combined_words <- combined_words[combined_words$start < audio_dur, , drop = FALSE]
+      # Cap word end times
+      combined_words$end <- pmin(combined_words$end, audio_dur)
+      # Remove duplicate words from chunk overlap (keep first occurrence)
+      if (nrow(combined_words) > 1) {
+        keep <- rep(TRUE, nrow(combined_words))
+        for (j in 2:nrow(combined_words)) {
+          if (combined_words$start[j] < combined_words$end[j - 1] - 0.05) {
+            keep[j] <- FALSE
+          }
+        }
+        combined_words <- combined_words[keep, , drop = FALSE]
+      }
+      # Filter words to only those within retained segments
+      if (!is.null(result$segments) && nrow(result$segments) > 0) {
+        seg_end <- max(result$segments$end)
+        combined_words <- combined_words[combined_words$start < seg_end, , drop = FALSE]
+      }
+      result$words <- combined_words
     } else {
-      data.frame(word = character(0), start = numeric(0), end = numeric(0))
+      result$words <- data.frame(word = character(0), start = numeric(0),
+        end = numeric(0))
    }
  }
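
Note on the seek arithmetic above: timestamp tokens are spaced 0.02 s apart and mel frames 0.01 s apart, so the loop converts a timestamp token offset to frames by multiplying by 2 ((tok - timestamp_begin) * 0.02 s * 100 frames/s); e.g. timestamp_begin + 250 marks 5.0 s, i.e. frame 500.

The segment-merging logic added to transcribe_long is self-contained enough to exercise outside the package. Below is a minimal base-R sketch of that rule; dedup_segments is a name invented here for illustration (the PR applies the same logic inline), and the thresholds (0.1 s overlap tolerance, 5-character text cutoff, 0.5 s / 15-character hallucination filter) are taken from the diff above.

dedup_segments <- function(segments, audio_dur) {
  # Drop segments starting past the real audio, then cap end times
  segments <- segments[segments$start < audio_dur, , drop = FALSE]
  segments$end <- pmin(segments$end, audio_dur)
  if (nrow(segments) > 1) {
    keep <- rep(TRUE, nrow(segments))
    for (j in 2:nrow(segments)) {
      if (segments$start[j] < segments$end[j - 1] - 0.1) {
        # Overlap: drop whichever side has trivially short text,
        # otherwise trim the earlier segment at the overlap point
        if (nchar(segments$text[j]) < 5) {
          keep[j] <- FALSE
        } else if (nchar(segments$text[j - 1]) < 5) {
          keep[j - 1] <- FALSE
        } else {
          segments$end[j - 1] <- segments$start[j]
        }
      }
    }
    segments <- segments[keep, , drop = FALSE]
  }
  # Drop very short, low-content segments (likely hallucinations)
  seg_dur <- segments$end - segments$start
  segments[!(seg_dur < 0.5 & nchar(segments$text) < 15), , drop = FALSE]
}

segs <- data.frame(
  start = c(0.0, 29.5, 30.2),
  end   = c(29.8, 30.1, 58.0),
  text  = c("first chunk tail", "Ok", "second chunk continues here")
)
dedup_segments(segs, audio_dur = 45)
# Row 2 ("Ok") overlaps row 1 and has fewer than 5 characters, so it is
# dropped as a boundary hallucination; row 3's end time is capped at 45 s.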