diff --git a/CLAUDE.md b/CLAUDE.md index 292f755..648bdaf 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -20,7 +20,8 @@ Audio (WAV/MP3) -> Mel Spectrogram -> Encoder (transformer) -> Decoder (cross-at ## Key Exports -- `transcribe(file, model, language)` - Main transcription function +- `transcribe(file, model, language, timestamps, word_timestamps)` - Main transcription function +- `whisper_pipeline(model)` - Load model once, call `$transcribe()` repeatedly - `load_whisper_model(model, device, dtype)` - Load model weights - `audio_to_mel(file, n_mels)` - Convert audio to mel spectrogram - `whisper_tokenizer()` - Get BPE tokenizer @@ -32,7 +33,15 @@ library(whisper) # Transcribe audio result <- transcribe("audio.wav", model = "tiny") -print(result$text) +result$text + +# Segment timestamps (uses Whisper's built-in timestamp tokens) +result <- transcribe("audio.wav", timestamps = TRUE) +result$segments # data.frame(start, end, text) + +# Word-level timestamps (cross-attention DTW alignment) +result <- transcribe("audio.wav", word_timestamps = TRUE) +result$words # data.frame(word, start, end) ``` ## Development @@ -56,13 +65,14 @@ Uses safetensors format from HuggingFace: ## File Structure -- `R/transcribe.R` - Main API +- `R/transcribe.R` - Main API, greedy decode, timestamp logit rules +- `R/alignment.R` - DTW alignment, word timestamp computation - `R/audio.R` - Audio to mel spectrogram -- `R/encoder.R` - Encoder transformer +- `R/encoder.R` - Encoder transformer (with `need_weights` dual-path attention) - `R/decoder.R` - Decoder with cross-attention - `R/model.R` - Full model + weight loading - `R/tokenizer.R` - Whisper BPE tokenizer -- `R/config.R` - Model configurations +- `R/config.R` - Model configurations + alignment heads - `R/download.R` - HuggingFace model download - `R/devices.R` - Device/dtype management @@ -75,10 +85,12 @@ Uses safetensors format from HuggingFace: - Transcription and translation (any language to English) - All model sizes: tiny, base, small, medium, large-v3 - CPU and CUDA support +- Segment-level timestamps (Whisper timestamp tokens with logit suppression) +- Word-level timestamps (cross-attention DTW alignment) - Pre-computed mel filterbank from official Whisper - HuggingFace model downloads via `hfhub` - KV cache for efficient incremental decoding -- Long audio support (automatic chunking) +- Long audio support (automatic chunking with time offsets) ### R torch notes @@ -88,12 +100,9 @@ Uses safetensors format from HuggingFace: ### Known Limitations -- UTF-8 encoding issues with some non-ASCII characters in output - Translation quality varies by model size (larger models work better) - No beam search (greedy decoding only) ### Potential Improvements - Beam search decoding -- Word-level timestamps (requires cross-attention analysis) -- Fix UTF-8 byte decoding in tokenizer diff --git a/R/alignment.R b/R/alignment.R new file mode 100644 index 0000000..815f9f1 --- /dev/null +++ b/R/alignment.R @@ -0,0 +1,306 @@ +#' Word-Level Timestamp Alignment +#' +#' DTW-based alignment of tokens to audio frames using cross-attention weights. + +#' Compute Word-Level Timestamps +#' +#' Use cross-attention weights and DTW alignment to assign timestamps +#' to individual words. 
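+#' Attention from the model's designated alignment heads is averaged per
+#' decode step, median-filtered along the time axis, converted to a cost
+#' matrix via -log(attn + eps), and aligned with DTW (0.02 s per audio frame).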
+#'
+#' @param tokens Integer vector of generated token IDs
+#' @param cross_attn_weights List of cross-attention weight tensors per decode step
+#' @param tokenizer Whisper tokenizer
+#' @param config Model configuration
+#' @param time_offset Time offset in seconds (for chunked audio)
+#' @param sample_begin Index where content tokens start in generated
+#' @return Data frame with word, start, end columns
+compute_word_timestamps <- function(
+  tokens,
+  cross_attn_weights,
+  tokenizer,
+  config,
+  time_offset = 0,
+  sample_begin = 4L
+) {
+  if (length(cross_attn_weights) == 0) {
+    return(data.frame(word = character(0), start = numeric(0), end = numeric(0)))
+  }
+
+  special <- whisper_special_tokens(config$model_name)
+
+  # Content tokens only (after initial prompt tokens)
+  content_tokens <- tokens[seq_len(length(tokens)) > sample_begin]
+
+  # Filter out timestamp tokens for word alignment
+  text_mask <- content_tokens < special$timestamp_begin
+  if (sum(text_mask) == 0) {
+    return(data.frame(word = character(0), start = numeric(0), end = numeric(0)))
+  }
+
+  # Get alignment heads for this model
+  alignment_heads <- config$alignment_heads
+  if (is.null(alignment_heads)) {
+    # Fallback: use all heads from last half of layers
+    n_layer <- config$n_text_layer
+    n_head <- config$n_text_head
+    half <- n_layer %/% 2L
+    layers <- seq(half, n_layer - 1L)
+    heads <- seq(0L, n_head - 1L)
+    alignment_heads <- as.matrix(expand.grid(layer = layers, head = heads))
+  }
+
+  # Build attention matrix: average over alignment heads and decode steps
+  # Each element of cross_attn_weights is a list of per-layer tensors
+  # Each tensor has shape (batch, n_head, 1, n_audio_ctx)
+  n_steps <- length(cross_attn_weights)
+  n_audio_ctx <- config$n_audio_ctx
+
+  # Stack attention from alignment heads across all steps
+  # Result: (n_steps, n_audio_ctx) averaged over alignment heads
+  attn_matrix <- matrix(0, nrow = n_steps, ncol = n_audio_ctx)
+
+  for (step in seq_len(n_steps)) {
+    step_weights <- cross_attn_weights[[step]]
+    n_heads_used <- 0
+
+    for (h in seq_len(nrow(alignment_heads))) {
+      layer_idx <- alignment_heads[h, 1] + 1L  # 0-indexed to 1-indexed
+      head_idx <- alignment_heads[h, 2] + 1L
+
+      if (layer_idx <= length(step_weights) && !is.null(step_weights[[layer_idx]])) {
+        # step_weights[[layer_idx]] is (batch, n_head, seq_len, src_len)
+        w <- step_weights[[layer_idx]]
+        # Extract specific head, last query position
+        head_attn <- as.array(w[1, head_idx, w$size(3), ]$cpu())
+        attn_matrix[step, ] <- attn_matrix[step, ] + head_attn
+        n_heads_used <- n_heads_used + 1L
+      }
+    }
+
+    if (n_heads_used > 0) {
+      attn_matrix[step, ] <- attn_matrix[step, ] / n_heads_used
+    }
+  }
+
+  # Determine audio frame range from timestamp tokens (if present):
+  # the last timestamp token caps the attention matrix. The token's offset
+  # from timestamp_begin is already a frame count (one 0.02 s frame per step),
+  # so use it directly; converting to seconds and back through floating-point
+  # division can truncate a frame.
+  max_frame <- n_audio_ctx
+  for (j in rev(seq_along(content_tokens))) {
+    if (content_tokens[j] >= special$timestamp_begin) {
+      ts_frames <- as.integer(content_tokens[j] - special$timestamp_begin)
+      max_frame <- min(n_audio_ctx, max(1L, ts_frames))
+      break
+    }
+  }
+
+  # Keep only text token rows (not timestamp tokens)
+  text_indices <- which(text_mask)
+  if (length(text_indices) == 0) {
+    return(data.frame(word = character(0), start = numeric(0), end = numeric(0)))
+  }
+  text_attn <- attn_matrix[text_indices, 1:max_frame, drop = FALSE]
+
+  # Apply median filter along time axis for smoothing
+  text_attn <- apply(text_attn, 1, function(row) medfilt1(row, 7L))
+  text_attn <- t(text_attn)
+
+  # Convert 
to cost matrix for DTW: -log(attn + eps) + cost <- -log(text_attn + 1e-10) + + # Run DTW alignment + path <- dtw_align(cost) + + # Map path to per-token frame ranges + text_token_ids <- content_tokens[text_indices] + n_text <- length(text_token_ids) + token_frames <- vector("list", n_text) + for (k in seq_len(n_text)) { + token_frames[[k]] <- integer(0) + } + + for (p in seq_len(nrow(path))) { + tok_idx <- path[p, 1] + frame_idx <- path[p, 2] + token_frames[[tok_idx]] <- c(token_frames[[tok_idx]], frame_idx) + } + + # Convert frame indices to timestamps + # Each audio frame = 2 mel frames (due to conv stride 2) + # Each mel frame = WHISPER_HOP_LENGTH / WHISPER_SAMPLE_RATE seconds + seconds_per_frame <- 0.02 # 1500 frames = 30 seconds + + token_starts <- numeric(n_text) + token_ends <- numeric(n_text) + for (k in seq_len(n_text)) { + frames <- token_frames[[k]] + if (length(frames) > 0) { + token_starts[k] <- (min(frames) - 1) * seconds_per_frame + time_offset + token_ends[k] <- max(frames) * seconds_per_frame + time_offset + } else if (k > 1) { + # Inherit from previous token + token_starts[k] <- token_ends[k - 1] + token_ends[k] <- token_starts[k] + } else { + token_starts[k] <- time_offset + token_ends[k] <- time_offset + } + } + + # Group subword tokens into words + group_into_words(text_token_ids, token_starts, token_ends, tokenizer) +} + +#' Group Subword Tokens into Words +#' +#' Merge BPE subword tokens into whole words with timestamps. +#' +#' @param token_ids Integer vector of text token IDs +#' @param starts Numeric vector of token start times +#' @param ends Numeric vector of token end times +#' @param tokenizer Whisper tokenizer +#' @return Data frame with word, start, end columns +group_into_words <- function( + token_ids, + starts, + ends, + tokenizer +) { + if (length(token_ids) == 0) { + return(data.frame(word = character(0), start = numeric(0), end = numeric(0))) + } + + # Decode each token individually + token_texts <- vapply(token_ids, function(id) tokenizer$decode(id), character(1)) + + # Group by word boundaries (space at start of token = new word) + words <- list() + current_word <- "" + current_start <- starts[1] + current_end <- ends[1] + + for (i in seq_along(token_texts)) { + text <- token_texts[i] + is_new_word <- grepl("^\\s", text) || i == 1L + + if (is_new_word && nchar(trimws(current_word)) > 0 && i > 1L) { + # Save previous word + words <- c(words, list(data.frame( + word = trimws(current_word), + start = current_start, + end = current_end, + stringsAsFactors = FALSE + ))) + current_word <- text + current_start <- starts[i] + current_end <- ends[i] + } else { + current_word <- paste0(current_word, text) + current_end <- ends[i] + } + } + + # Save last word + if (nchar(trimws(current_word)) > 0) { + words <- c(words, list(data.frame( + word = trimws(current_word), + start = current_start, + end = current_end, + stringsAsFactors = FALSE + ))) + } + + if (length(words) == 0) { + return(data.frame(word = character(0), start = numeric(0), end = numeric(0))) + } + + do.call(rbind, words) +} + +#' DTW Alignment +#' +#' Standard dynamic time warping on a cost matrix. 
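+#' The accumulated cost follows D[i, j] = cost[i, j] +
+#' min(D[i-1, j], D[i, j-1], D[i-1, j-1]); the path is recovered by
+#' backtracking from (n, m).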
+#'
+#' @param cost Numeric matrix (n_tokens x n_frames)
+#' @return Integer matrix with 2 columns (token_idx, frame_idx), 1-indexed
+dtw_align <- function(cost) {
+  n <- nrow(cost)
+  m <- ncol(cost)
+
+  # Accumulated cost matrix
+  D <- matrix(Inf, nrow = n, ncol = m)
+  D[1, 1] <- cost[1, 1]
+
+  # First row: can only come from the left
+  # (guard the 2:m idiom, which counts backwards when m == 1)
+  if (m >= 2L) {
+    for (j in 2:m) {
+      D[1, j] <- D[1, j - 1] + cost[1, j]
+    }
+  }
+
+  # First column: can only come from above
+  if (n >= 2L) {
+    for (i in 2:n) {
+      D[i, 1] <- D[i - 1, 1] + cost[i, 1]
+    }
+  }
+
+  # Fill rest
+  if (n >= 2L && m >= 2L) {
+    for (i in 2:n) {
+      for (j in 2:m) {
+        D[i, j] <- cost[i, j] + min(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1])
+      }
+    }
+  }
+
+  # Backtrack to find optimal path
+  path <- matrix(0L, nrow = n + m, ncol = 2)
+  k <- 1L
+  i <- n
+  j <- m
+  path[k, ] <- c(i, j)
+
+  while (i > 1 || j > 1) {
+    k <- k + 1L
+    if (i == 1) {
+      j <- j - 1L
+    } else if (j == 1) {
+      i <- i - 1L
+    } else {
+      candidates <- c(D[i - 1, j - 1], D[i - 1, j], D[i, j - 1])
+      step <- which.min(candidates)
+      if (step == 1L) {
+        i <- i - 1L
+        j <- j - 1L
+      } else if (step == 2L) {
+        i <- i - 1L
+      } else {
+        j <- j - 1L
+      }
+    }
+    path[k, ] <- c(i, j)
+  }
+
+  # Reverse path (was built backwards)
+  path <- path[k:1, , drop = FALSE]
+  path
+}
+
+#' 1D Median Filter
+#'
+#' Apply a sliding median filter to a numeric vector.
+#' Window edges are truncated rather than padded, so the output has the
+#' same length as the input.
+#'
+#' @param x Numeric vector
+#' @param width Filter width (must be odd)
+#' @return Filtered numeric vector of same length
+medfilt1 <- function(x, width = 7L) {
+  n <- length(x)
+  if (n == 0) return(x)
+
+  # Ensure odd width
+  if (width %% 2L == 0L) width <- width + 1L
+  half <- width %/% 2L
+
+  result <- numeric(n)
+  for (i in seq_len(n)) {
+    lo <- max(1L, i - half)
+    hi <- min(n, i + half)
+    result[i] <- median(x[lo:hi])
+  }
+  result
+}
diff --git a/R/config.R b/R/config.R
index 2e75030..4c45fe7 100644
--- a/R/config.R
+++ b/R/config.R
@@ -30,7 +30,11 @@ whisper_config <- function(model = "tiny") {
       n_text_state = 384L,
       n_text_head = 6L,
       n_text_layer = 4L,
-      hf_repo = "openai/whisper-tiny"
+      hf_repo = "openai/whisper-tiny",
+      # (layer, head) pairs for cross-attention alignment (0-indexed)
+      alignment_heads = matrix(c(
+        1, 0, 2, 0, 2, 5, 3, 0, 3, 1, 3, 2, 3, 3, 3, 4
+      ), ncol = 2, byrow = TRUE)
     ),
     base = list(
       n_mels = 80L,
@@ -43,7 +47,10 @@ whisper_config <- function(model = "tiny") {
       n_text_state = 512L,
       n_text_head = 8L,
       n_text_layer = 6L,
-      hf_repo = "openai/whisper-base"
+      hf_repo = "openai/whisper-base",
+      alignment_heads = matrix(c(
+        3, 1, 4, 2, 4, 3, 4, 7, 5, 1, 5, 2, 5, 4, 5, 6
+      ), ncol = 2, byrow = TRUE)
     ),
     small = list(
       n_mels = 80L,
@@ -56,7 +63,11 @@ whisper_config <- function(model = "tiny") {
       n_text_state = 768L,
       n_text_head = 12L,
       n_text_layer = 12L,
-      hf_repo = "openai/whisper-small"
+      hf_repo = "openai/whisper-small",
+      alignment_heads = matrix(c(
+        6, 6, 7, 0, 7, 3, 7, 8, 8, 2, 8, 5, 8, 7, 9, 0, 9, 4, 9, 8,
+        9, 10, 10, 0, 10, 1, 10, 2, 10, 3, 10, 6, 10, 11, 11, 2, 11, 4
+      ), ncol = 2, byrow = TRUE)
     ),
     medium = list(
       n_mels = 80L,
@@ -69,7 +80,10 @@ whisper_config <- function(model = "tiny") {
       n_text_state = 1024L,
       n_text_head = 16L,
       n_text_layer = 24L,
-      hf_repo = "openai/whisper-medium"
+      hf_repo = "openai/whisper-medium",
+      alignment_heads = matrix(c(
+        13, 15, 15, 4, 15, 15, 16, 1, 20, 0, 23, 4
+      ), ncol = 2, byrow = TRUE)
     ),
     `large-v3` = list(
       n_mels = 128L,
@@ -82,7 +96,10 @@ whisper_config <- function(model = "tiny") {
       n_text_state = 1280L,
       n_text_head = 20L,
       n_text_layer = 32L,
-      hf_repo = "openai/whisper-large-v3"
+      hf_repo = "openai/whisper-large-v3",
+      alignment_heads = 
matrix(c( + 9, 19, 11, 2, 11, 4, 11, 17, 22, 7, 22, 11, 22, 17, 23, 2, 23, 15 + ), ncol = 2, byrow = TRUE) ) ) diff --git a/R/decoder.R b/R/decoder.R index 035a27f..f129c17 100644 --- a/R/decoder.R +++ b/R/decoder.R @@ -36,7 +36,8 @@ whisper_decoder_layer <- torch::nn_module( x, xa, mask = NULL, - kv_cache = NULL + kv_cache = NULL, + need_weights = FALSE ) { # x: decoder input (batch, seq_len, n_state) # xa: encoder output (batch, src_len, n_state) @@ -51,23 +52,25 @@ whisper_decoder_layer <- torch::nn_module( cross_kv_cache <- kv_cache$cross } - # Self-attention with causal mask + # Self-attention with causal mask (never need weights for self-attn) attn_result <- self$attn(self$attn_ln(x), mask = mask, kv_cache = self_kv_cache) x <- x + attn_result$output new_self_kv <- attn_result$kv_cache - # Cross-attention to encoder output - cross_result <- self$cross_attn(self$cross_attn_ln(x), xa = xa, kv_cache = cross_kv_cache) + # Cross-attention to encoder output (need_weights for DTW alignment) + cross_result <- self$cross_attn(self$cross_attn_ln(x), xa = xa, + kv_cache = cross_kv_cache, need_weights = need_weights) x <- x + cross_result$output new_cross_kv <- cross_result$kv_cache # FFN x <- x + self$mlp(self$mlp_ln(x)) - # Return output and updated caches + # Return output, updated caches, and optionally cross-attention weights list( output = x, - kv_cache = list(self = new_self_kv, cross = new_cross_kv) + kv_cache = list(self = new_self_kv, cross = new_cross_kv), + cross_attn_weights = cross_result$attn_weights ) } ) @@ -126,7 +129,8 @@ whisper_decoder <- torch::nn_module( forward = function( x, xa, - kv_cache = NULL + kv_cache = NULL, + need_weights = FALSE ) { # x: token ids (batch, seq_len) # xa: encoder output (batch, src_len, n_state) @@ -168,6 +172,7 @@ whisper_decoder <- torch::nn_module( } new_kv_cache <- vector("list", self$n_layer) + cross_attn_weights <- if (need_weights) vector("list", self$n_layer) else NULL # Transformer layers for (i in seq_len(self$n_layer)) { @@ -176,16 +181,21 @@ whisper_decoder <- torch::nn_module( } else { layer_cache <- NULL } - result <- self$blocks[[i]](x, xa, mask = mask, kv_cache = layer_cache) + result <- self$blocks[[i]](x, xa, mask = mask, kv_cache = layer_cache, + need_weights = need_weights) x <- result$output new_kv_cache[[i]] <- result$kv_cache + if (need_weights) { + cross_attn_weights[[i]] <- result$cross_attn_weights + } } # Final layer norm x <- self$ln(x) - # Return hidden states and updated KV cache - list(hidden_states = x, kv_cache = new_kv_cache) + # Return hidden states, updated KV cache, and optionally cross-attention weights + list(hidden_states = x, kv_cache = new_kv_cache, + cross_attn_weights = cross_attn_weights) }, get_logits = function(hidden_states) { diff --git a/R/encoder.R b/R/encoder.R index 1aef998..284356a 100644 --- a/R/encoder.R +++ b/R/encoder.R @@ -30,7 +30,8 @@ whisper_attention <- torch::nn_module( x, xa = NULL, mask = NULL, - kv_cache = NULL + kv_cache = NULL, + need_weights = FALSE ) { # x: (batch, seq_len, n_state) # xa: optional cross-attention input (batch, src_len, n_state) @@ -70,9 +71,24 @@ whisper_attention <- torch::nn_module( } } - # Scaled dot-product attention (dispatches to FlashAttention on GPU) - attn_output <- torch:::torch_scaled_dot_product_attention( - q, k, v, is_causal = !is.null(mask)) + attn_weights <- NULL + + if (need_weights) { + # Manual attention to capture weights for DTW alignment + # q: (batch, n_head, seq_len, head_dim) + # k: (batch, n_head, src_len, head_dim) + scale <- 
sqrt(self$head_dim) + attn_scores <- torch::torch_matmul(q, k$transpose(3L, 4L)) / scale + if (!is.null(mask)) { + attn_scores <- attn_scores + mask + } + attn_weights <- torch::nnf_softmax(attn_scores, dim = -1L) + attn_output <- torch::torch_matmul(attn_weights, v) + } else { + # Scaled dot-product attention (dispatches to FlashAttention on GPU) + attn_output <- torch:::torch_scaled_dot_product_attention( + q, k, v, is_causal = !is.null(mask)) + } # Reshape back: (batch, n_head, seq_len, head_dim) -> (batch, seq_len, n_state) attn_output <- attn_output$transpose(2L, 3L)$contiguous() @@ -81,8 +97,9 @@ whisper_attention <- torch::nn_module( # Output projection output <- self$out(attn_output) - # Return output and new KV cache (reshaped k, v for efficient caching) - list(output = output, kv_cache = list(k = k, v = v)) + # Return output, KV cache, and optionally attention weights + list(output = output, kv_cache = list(k = k, v = v), + attn_weights = attn_weights) }, reshape_for_attention = function( diff --git a/R/model.R b/R/model.R index 8fecdd6..3b639ce 100644 --- a/R/model.R +++ b/R/model.R @@ -48,16 +48,24 @@ whisper_model <- torch::nn_module( decode = function( tokens, encoder_output, - kv_cache = NULL + kv_cache = NULL, + need_weights = FALSE ) { # Just decode (with pre-computed encoder output) - decoder_result <- self$decoder(tokens, encoder_output, kv_cache = kv_cache) + decoder_result <- self$decoder(tokens, encoder_output, kv_cache = kv_cache, + need_weights = need_weights) logits <- self$decoder$get_logits(decoder_result$hidden_states) - list( + result <- list( logits = logits, kv_cache = decoder_result$kv_cache ) + + if (need_weights) { + result$cross_attn_weights <- decoder_result$cross_attn_weights + } + + result } ) diff --git a/R/tokenizer.R b/R/tokenizer.R index 406c24e..e74ca34 100644 --- a/R/tokenizer.R +++ b/R/tokenizer.R @@ -214,20 +214,56 @@ tokenizer_decode <- function( text } +#' Build Reverse Byte Decoder +#' +#' Inverts the GPT-2 byte-to-unicode mapping used by byte_to_token(). +#' Cached after first call. +#' +#' @return Named character vector mapping unicode codepoint (as string) to +#' raw byte value +build_byte_decoder <- function() { + if (!is.null(.tokenizer_cache$byte_decoder)) { + return(.tokenizer_cache$byte_decoder) + } + decoder <- integer(256) + names(decoder) <- character(256) + for (b in 0:255) { + cp <- utf8ToInt(byte_to_token(b)) + names(decoder)[b + 1L] <- as.character(cp) + decoder[b + 1L] <- b + } + .tokenizer_cache$byte_decoder <- decoder + decoder +} + +# Module-level cache for byte decoder +.tokenizer_cache <- new.env(parent = emptyenv()) + #' Decode BPE Bytes Back to Text #' +#' Reverses the GPT-2 byte-level encoding, converting unicode tokens +#' back to raw UTF-8 bytes. 
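+#' For example, the GPT-2 byte map stores a space byte (0x20) as the
+#' codepoint U+0120 ("Ġ"), which this function turns back into a plain space.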
+#' #' @param text Text with BPE byte tokens -#' @return Decoded text +#' @return Decoded UTF-8 text decode_bpe_bytes <- function(text) { - # Replace special space token - text <- gsub("\u0120", " ", text, fixed = TRUE) - - # Handle other byte-level encodings - # This is a simplified version - full implementation would - - # reverse the byte_to_token mapping + if (nchar(text) == 0) return(text) + + decoder <- build_byte_decoder() + codepoints <- utf8ToInt(text) + bytes <- raw(length(codepoints)) + + for (i in seq_along(codepoints)) { + cp_str <- as.character(codepoints[i]) + idx <- match(cp_str, names(decoder)) + if (!is.na(idx)) { + bytes[i] <- as.raw(decoder[idx]) + } else { + bytes[i] <- charToRaw("?") + } + } - text + rawToChar(bytes) } #' Ensure Tokenizer Files are Downloaded diff --git a/R/transcribe.R b/R/transcribe.R index a881a62..6ebc579 100644 --- a/R/transcribe.R +++ b/R/transcribe.R @@ -51,9 +51,12 @@ whisper_pipeline <- function( file, language = "en", task = "transcribe", + timestamps = FALSE, + word_timestamps = FALSE, verbose = TRUE ) { pipeline_transcribe(pipe, file, language = language, task = task, + timestamps = timestamps, word_timestamps = word_timestamps, verbose = verbose) } @@ -74,6 +77,8 @@ print.whisper_pipeline <- function(x, ...) { #' @param file Path to audio file. #' @param language Language code. #' @param task Task type. +#' @param timestamps Return segment-level timestamps. +#' @param word_timestamps Return word-level timestamps. #' @param verbose Print progress. #' @return List with text, language, and metadata. #' @keywords internal @@ -82,20 +87,27 @@ pipeline_transcribe <- function( file, language = "en", task = "transcribe", + timestamps = FALSE, + word_timestamps = FALSE, verbose = TRUE ) { if (!file.exists(file)) stop("Audio file not found: ", file) + # word_timestamps implies timestamps + if (word_timestamps) timestamps <- TRUE + duration <- audio_duration(file) if (verbose) message("Audio duration: ", round(duration, 1), "s") if (duration <= WHISPER_CHUNK_LENGTH) { result <- transcribe_chunk(file, pipe$model, pipe$tokenizer, pipe$config, - language = language, task = task, + language = language, task = task, timestamps = timestamps, + word_timestamps = word_timestamps, device = pipe$device, dtype = pipe$dtype, verbose = verbose) } else { result <- transcribe_long(file, pipe$model, pipe$tokenizer, pipe$config, - language = language, task = task, + language = language, task = task, timestamps = timestamps, + word_timestamps = word_timestamps, device = pipe$device, dtype = pipe$dtype, verbose = verbose) } @@ -116,10 +128,15 @@ pipeline_transcribe <- function( #' @param model Model name: "tiny", "base", "small", "medium", "large-v3" #' @param language Language code (e.g., "en", "es"). NULL for auto-detection. #' @param task "transcribe" or "translate" (translate to English) +#' @param timestamps If TRUE, return segment-level timestamps +#' @param word_timestamps If TRUE, return word-level timestamps (implies timestamps) #' @param device Device: "auto", "cpu", "cuda" #' @param dtype Data type: "auto", "float16", "float32" #' @param verbose Print progress messages -#' @return List with text, language, and metadata +#' @return List with text, language, and metadata. When \code{timestamps=TRUE}, +#' includes \code{segments} data.frame with start, end, text columns. When +#' \code{word_timestamps=TRUE}, includes \code{words} data.frame with word, +#' start, end columns. 
#' @export #' @examples #' \donttest{ @@ -129,6 +146,10 @@ pipeline_transcribe <- function( #' result <- transcribe(audio_file, model = "tiny") #' result$text #' +#' # With timestamps +#' result <- transcribe(audio_file, model = "tiny", timestamps = TRUE) +#' result$segments +#' #' # Translate Spanish audio to English #' spanish_file <- system.file("audio", "allende.mp3", package = "whisper") #' result <- transcribe(spanish_file, model = "tiny", @@ -141,13 +162,17 @@ transcribe <- function( model = "tiny", language = "en", task = "transcribe", + timestamps = FALSE, + word_timestamps = FALSE, device = "auto", dtype = "auto", verbose = TRUE ) { pipe <- whisper_pipeline(model, device = device, dtype = dtype, download = TRUE, verbose = verbose) - pipe$transcribe(file, language = language, task = task, verbose = verbose) + pipe$transcribe(file, language = language, task = task, + timestamps = timestamps, word_timestamps = word_timestamps, + verbose = verbose) } #' Transcribe Single Chunk @@ -169,6 +194,9 @@ transcribe_chunk <- function( config, language = "en", task = "transcribe", + timestamps = FALSE, + word_timestamps = FALSE, + time_offset = 0, device, dtype, verbose = TRUE @@ -178,7 +206,8 @@ transcribe_chunk <- function( mel <- audio_to_mel(file, n_mels = config$n_mels, device = device, dtype = dtype) # Get initial decoder tokens (use model name for correct special token IDs) - initial_tokens <- get_initial_tokens(language, task, model = config$model_name) + initial_tokens <- get_initial_tokens(language, task, + model = config$model_name, timestamps = timestamps) tokens <- torch::torch_tensor(matrix(initial_tokens, nrow = 1), dtype = torch::torch_long(), device = device) @@ -191,21 +220,44 @@ transcribe_chunk <- function( # Decode with greedy search if (verbose) message("Decoding...") - generated <- greedy_decode(model, encoder_output, tokens, tokenizer, + decode_result <- greedy_decode(model, encoder_output, tokens, tokenizer, max_length = config$n_text_ctx, + timestamps = timestamps, word_timestamps = word_timestamps, device = device) - # Decode tokens to text - text <- tokenizer$decode(generated) - - # Clean up text - text <- clean_text(text) + # greedy_decode returns list when timestamps/word_timestamps, integer vector otherwise + if (is.list(decode_result)) { + generated <- decode_result$tokens + cross_attn_weights <- decode_result$cross_attn_weights + } else { + generated <- decode_result + cross_attn_weights <- NULL + } # Build result - list( - text = text, - language = language - ) + if (timestamps) { + # Extract segments from timestamp tokens + segments <- extract_segments(generated, tokenizer, time_offset = time_offset) + text <- paste(segments$text, collapse = " ") + text <- clean_text(text) + result <- list(text = text, language = language, segments = segments) + } else { + text <- tokenizer$decode(generated) + text <- clean_text(text) + result <- list(text = text, language = language) + } + + # Word-level timestamps via cross-attention DTW + if (word_timestamps && !is.null(cross_attn_weights)) { + special <- whisper_special_tokens(config$model_name) + sample_begin <- length(initial_tokens) + words <- compute_word_timestamps(generated, cross_attn_weights, + tokenizer, config, time_offset = time_offset, + sample_begin = sample_begin) + result$words <- words + } + + result } #' Greedy Decoding @@ -215,22 +267,32 @@ transcribe_chunk <- function( #' @param initial_tokens Initial token tensor #' @param tokenizer Tokenizer #' @param max_length Maximum output length +#' @param 
timestamps Whether to allow timestamp tokens +#' @param word_timestamps Whether to collect cross-attention weights #' @param device Device -#' @return Integer vector of generated tokens +#' @return Integer vector of generated tokens, or list with tokens and +#' cross_attn_weights when word_timestamps is TRUE greedy_decode <- function( model, encoder_output, initial_tokens, tokenizer, max_length = 448L, + timestamps = FALSE, + word_timestamps = FALSE, device ) { # Use model-specific special tokens special <- whisper_special_tokens(tokenizer$model) generated <- as.integer(as.array(initial_tokens$cpu())) + sample_begin <- length(generated) kv_cache <- NULL tokens <- initial_tokens + need_weights <- word_timestamps + + # Collect cross-attention weights for word timestamps + all_cross_attn <- if (word_timestamps) list() else NULL torch::with_no_grad({ for (i in seq_len(max_length)) { @@ -238,13 +300,20 @@ greedy_decode <- function( if (length(generated) >= max_length) break # Get next token logits - result <- model$decode(tokens, encoder_output, kv_cache = kv_cache) + result <- model$decode(tokens, encoder_output, kv_cache = kv_cache, + need_weights = need_weights) logits <- result$logits kv_cache <- result$kv_cache - # Get last position logits (R uses 1-based indexing, not negative indexing like Python) - seq_len <- logits$size(2) - next_logits <- logits[, seq_len,]# (batch, vocab) + # Get last position logits (R uses 1-based indexing) + seq_len_val <- logits$size(2) + next_logits <- logits[, seq_len_val,]# (batch, vocab) + + # Apply timestamp logit rules when timestamps are enabled + if (timestamps) { + next_logits <- apply_timestamp_rules(next_logits, generated, + special, sample_begin) + } # Greedy: take argmax (subtract 1 because R torch argmax returns 1-indexed) next_token <- next_logits$argmax(dim = - 1L) @@ -258,6 +327,11 @@ greedy_decode <- function( # Append token generated <- c(generated, next_token_id) + # Collect cross-attention weights for this step + if (word_timestamps && !is.null(result$cross_attn_weights)) { + all_cross_attn <- c(all_cross_attn, list(result$cross_attn_weights)) + } + # Prepare next input (decoder expects 0-indexed token IDs, adds 1 internally) tokens <- torch::torch_tensor(matrix(next_token_id, nrow = 1L), dtype = torch::torch_long(), @@ -265,7 +339,133 @@ greedy_decode <- function( } }) - generated + if (word_timestamps) { + list(tokens = generated, cross_attn_weights = all_cross_attn) + } else if (timestamps) { + list(tokens = generated, cross_attn_weights = NULL) + } else { + generated + } +} + +#' Apply Timestamp Token Rules +#' +#' Enforce Whisper timestamp generation constraints on logits. 
+#'
+#' @param logits Logit tensor (1, vocab) or (vocab)
+#' @param generated Integer vector of tokens generated so far
+#' @param special Special token IDs
+#' @param sample_begin Index where content tokens start in generated
+#' @return Modified logits tensor
+apply_timestamp_rules <- function(
+  logits,
+  generated,
+  special,
+  sample_begin
+) {
+  # Content tokens are those generated after the initial prompt tokens
+  content_tokens <- generated[seq_len(length(generated)) > sample_begin]
+  ts_begin <- special$timestamp_begin
+  # Max timestamp: 30.00s = 1500 steps of 0.02s
+  max_ts <- ts_begin + 1500L
+
+  # Logits may be 1D (vocab) or 2D (batch, vocab)
+  is_2d <- logits$dim() == 2L
+  n_vocab <- if (is_2d) logits$size(2) else logits$size(1)
+
+  # Suppress token IDs from..to (inclusive); token ID t lives at R index t + 1
+  suppress_ids <- function(lg, from, to) {
+    lo <- from + 1L
+    hi <- min(to + 1L, n_vocab)
+    if (lo > hi) return(lg)
+    if (is_2d) lg[, lo:hi] <- -Inf else lg[lo:hi] <- -Inf
+    lg
+  }
+
+  # Rule 1: the first content token must be the first timestamp (<|0.00|>)
+  if (length(content_tokens) == 0) {
+    logits <- suppress_ids(logits, 0L, ts_begin - 1L)        # all text tokens
+    logits <- suppress_ids(logits, ts_begin + 1L, n_vocab - 1L)  # all but <|0.00|>
+    return(logits)
+  }
+
+  n <- length(content_tokens)
+  last_was_ts <- content_tokens[n] >= ts_begin
+  penultimate_was_ts <- n < 2L || content_tokens[n - 1L] >= ts_begin
+
+  # Rule 2: after a timestamp pair, the next token must be text (EOT allowed)
+  # Rule 3: after a lone timestamp, the next token must be its pair or EOT
+  if (last_was_ts) {
+    if (penultimate_was_ts) {
+      logits <- suppress_ids(logits, ts_begin, n_vocab - 1L)
+    } else {
+      logits <- suppress_ids(logits, 0L, special$eot - 1L)
+      logits <- suppress_ids(logits, special$eot + 1L, ts_begin - 1L)
+    }
+  }
+
+  # Rule 4: no backwards timestamps; a lone timestamp may repeat itself to
+  # close the pair, otherwise the next timestamp must be strictly larger
+  emitted_ts <- content_tokens[content_tokens >= ts_begin]
+  if (length(emitted_ts) > 0) {
+    last_ts <- emitted_ts[length(emitted_ts)]
+    floor_id <- if (last_was_ts && !penultimate_was_ts) last_ts else last_ts + 1L
+    logits <- suppress_ids(logits, ts_begin, floor_id - 1L)
+  }
+
+  # Rule 5: cap timestamps at 30.00s
+  logits <- suppress_ids(logits, max_ts + 1L, n_vocab - 1L)
+
+  logits
+}
 
 #' Transcribe Long Audio
@@ -289,51 +489,68 @@ transcribe_long <- function(
   config,
   language,
   task,
+  timestamps = FALSE,
+  word_timestamps = FALSE,
   device,
   dtype,
   verbose
 ) {
   # 
Split into chunks - chunks <- split_audio(file, chunk_length = 30, overlap = 1) + chunk_length <- 30 + overlap <- 1 + hop_seconds <- chunk_length - overlap + chunks <- split_audio(file, chunk_length = chunk_length, overlap = overlap) if (verbose) message("Processing ", length(chunks), " chunks...") all_text <- character(length(chunks)) + all_segments <- if (timestamps) list() else NULL + all_words <- if (word_timestamps) list() else NULL for (i in seq_along(chunks)) { if (verbose) message(" Chunk ", i, "/", length(chunks)) + time_offset <- (i - 1) * hop_seconds - # Convert chunk to mel - mel <- audio_to_mel(chunks[[i]], n_mels = config$n_mels, - device = device, dtype = dtype) + # Transcribe chunk with time offset + chunk_result <- transcribe_chunk(chunks[[i]], model, tokenizer, config, + language = language, task = task, timestamps = timestamps, + word_timestamps = word_timestamps, time_offset = time_offset, + device = device, dtype = dtype, verbose = FALSE) - # Get initial tokens (use model name for correct special token IDs) - initial_tokens <- get_initial_tokens(language, task, model = config$model_name) - tokens <- torch::torch_tensor(matrix(initial_tokens, nrow = 1), - dtype = torch::torch_long(), - device = device) + all_text[i] <- chunk_result$text - # Encode - torch::with_no_grad({ - encoder_output <- model$encode(mel) - }) - - # Decode - generated <- greedy_decode(model, encoder_output, tokens, tokenizer, - max_length = config$n_text_ctx, - device = device) + if (timestamps && !is.null(chunk_result$segments) && nrow(chunk_result$segments) > 0) { + all_segments <- c(all_segments, list(chunk_result$segments)) + } - # Decode to text - text <- tokenizer$decode(generated) - text <- clean_text(text) - all_text[i] <- text + if (word_timestamps && !is.null(chunk_result$words) && nrow(chunk_result$words) > 0) { + all_words <- c(all_words, list(chunk_result$words)) + } } # Combine results - list( + result <- list( text = paste(all_text, collapse = " "), language = language ) + + if (timestamps) { + result$segments <- if (length(all_segments) > 0) { + do.call(rbind, all_segments) + } else { + data.frame(start = numeric(0), end = numeric(0), text = character(0)) + } + } + + if (word_timestamps) { + result$words <- if (length(all_words) > 0) { + do.call(rbind, all_words) + } else { + data.frame(word = character(0), start = numeric(0), end = numeric(0)) + } + } + + result } #' Clean Transcribed Text diff --git a/README.md b/README.md index ef39e8d..43c15db 100644 --- a/README.md +++ b/README.md @@ -70,15 +70,45 @@ result <- transcribe(allende, language = "es") result <- transcribe(allende, task = "translate", language = "es", model = "small") ``` +## Timestamps + +```r +# Segment-level timestamps +result <- transcribe("audio.wav", timestamps = TRUE) +result$segments +#> start end text +#> 1 0.00 7.44 Ask not what your country... + +# Word-level timestamps (via cross-attention DTW alignment) +result <- transcribe("audio.wav", word_timestamps = TRUE) +result$words +#> word start end +#> 1 Ask 0.00 0.54 +#> 2 not 0.54 1.16 +#> 3 what 1.16 2.46 +#> ... 
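+
+# The timestamp outputs are plain data.frames, so they compose with base R;
+# e.g. subtitle-style cues (illustrative, not a package helper):
+cues <- sprintf("[%.2f --> %.2f] %s",
+                result$segments$start, result$segments$end, result$segments$text)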
+``` + +Both work with the pipeline API for repeated transcription: + +```r +pipe <- whisper_pipeline("tiny") +result <- pipe$transcribe("audio.wav", word_timestamps = TRUE) +result$words +``` + ## Models -| Model | Parameters | Size | English WER | -|-------|------------|------|-------------| -| tiny | 39M | 151 MB | ~9% | -| base | 74M | 290 MB | ~7% | -| small | 244M | 967 MB | ~5% | -| medium | 769M | 3.0 GB | ~4% | -| large-v3 | 1550M | 6.2 GB | ~3% | +| Model | Parameters | Disk (fp32) | English WER | Peak VRAM (CUDA fp16) | Speed* | +|-------|------------|-------------|-------------|----------------------|--------| +| tiny | 39M | 151 MB | ~9% | 564 MiB | 5.5s | +| base | 74M | 290 MB | ~7% | 734 MiB | 1.9s | +| small | 244M | 967 MB | ~5% | 1,454 MiB | 3.6s | +| medium | 769M | 3.0 GB | ~4% | 3,580 MiB | 8.6s | +| large-v3 | 1550M | 6.2 GB | ~3% | 3,892 MiB | 16.7s | + +*Speed measured on RTX 5060 Ti transcribing a 17s audio clip with `word_timestamps = TRUE`. +Peak VRAM includes ~364 MiB torch CUDA context overhead. Models are downloaded from HuggingFace and cached in `~/.cache/huggingface/` unless otherwise specified. diff --git a/inst/tinytest/test_alignment.R b/inst/tinytest/test_alignment.R new file mode 100644 index 0000000..19bdddf --- /dev/null +++ b/inst/tinytest/test_alignment.R @@ -0,0 +1,80 @@ +# Tests for DTW alignment and word timestamps + +# Skip if torch not fully installed +if (!requireNamespace("torch", quietly = TRUE) || + !torch::torch_is_installed()) { + exit_file("torch not fully installed") +} + +# Test medfilt1 +# median(1,5)=3, median(1,5,1)=1, median(5,1,1)=1, median(1,1,1)=1, median(1,1)=1 +expect_equal(whisper:::medfilt1(c(1, 5, 1, 1, 1), 3L), c(3, 1, 1, 1, 1)) +# median(3,1)=2, median(3,1,4)=3, median(1,4,1)=1, median(4,1,5)=4, median(1,5)=3 +expect_equal(whisper:::medfilt1(c(3, 1, 4, 1, 5), 3L), c(2, 3, 1, 4, 3)) +expect_equal(whisper:::medfilt1(numeric(0), 3L), numeric(0)) +expect_equal(whisper:::medfilt1(c(42), 3L), c(42)) + +# Test dtw_align with known cost matrix +# Simple diagonal cost: cheapest path should follow diagonal +cost <- matrix(10, nrow = 3, ncol = 3) +diag(cost) <- 0 +path <- whisper:::dtw_align(cost) + +# Path should visit all 3 token positions +expect_true(all(1:3 %in% path[, 1])) +# Path should be monotonically increasing +expect_true(all(diff(path[, 1]) >= 0)) +expect_true(all(diff(path[, 2]) >= 0)) +# Path starts at (1,1) and ends at (n,m) +expect_equal(path[1, ], c(1L, 1L)) +expect_equal(path[nrow(path), ], c(3L, 3L)) + +# Test dtw_align with rectangular cost matrix +cost2 <- matrix(1, nrow = 2, ncol = 5) +cost2[1, 2] <- 0 # Token 1 maps to frame 2 +cost2[2, 4] <- 0 # Token 2 maps to frame 4 +path2 <- whisper:::dtw_align(cost2) +expect_equal(path2[1, ], c(1L, 1L)) +expect_equal(path2[nrow(path2), ], c(2L, 5L)) + +# Test group_into_words +if (model_exists("tiny")) { + tok <- whisper_tokenizer("tiny") + + # Simple case: known token IDs for space-delimited words + hello_ids <- tok$encode(" hello") + world_ids <- tok$encode(" world") + all_ids <- c(hello_ids, world_ids) + + starts <- seq(0, by = 0.5, length.out = length(all_ids)) + ends <- starts + 0.5 + + words <- whisper:::group_into_words(all_ids, starts, ends, tok) + expect_true(is.data.frame(words)) + expect_true(nrow(words) >= 2) + expect_true("word" %in% names(words)) + expect_true("start" %in% names(words)) + expect_true("end" %in% names(words)) +} + +# Integration test with word timestamps (requires model + audio) +if (at_home() && model_exists("tiny")) { + audio_file <- 
system.file("audio", "jfk.mp3", package = "whisper") + if (file.exists(audio_file)) { + result <- transcribe(audio_file, model = "tiny", word_timestamps = TRUE, + verbose = FALSE) + expect_true("words" %in% names(result)) + expect_true(is.data.frame(result$words)) + expect_true(nrow(result$words) > 0) + expect_true("word" %in% names(result$words)) + expect_true("start" %in% names(result$words)) + expect_true("end" %in% names(result$words)) + # Word timestamps should be monotonically non-decreasing + if (nrow(result$words) > 1) { + starts <- result$words$start + expect_true(all(diff(starts) >= -0.01)) # allow tiny float imprecision + } + # word_timestamps implies segments are also present + expect_true("segments" %in% names(result)) + } +} diff --git a/inst/tinytest/test_tokenizer.R b/inst/tinytest/test_tokenizer.R index 0cf73fd..0434bb7 100644 --- a/inst/tinytest/test_tokenizer.R +++ b/inst/tinytest/test_tokenizer.R @@ -26,3 +26,28 @@ expect_equal(whisper:::decode_timestamp(50364L), 0) expect_equal(whisper:::decode_timestamp(50365L), 0.02) expect_equal(whisper:::decode_timestamp(50414L), 1.0) # 50 * 0.02 +# Test byte_to_token / decode_bpe_bytes round-trip +# ASCII +ascii_text <- "hello world" +ascii_bpe <- paste(sapply(charToRaw(ascii_text), function(b) { + whisper:::byte_to_token(as.integer(b)) +}), collapse = "") +expect_equal(whisper:::decode_bpe_bytes(ascii_bpe), ascii_text) + +# Non-ASCII (accented Latin) +utf8_text <- "caf\u00e9" +utf8_bpe <- paste(sapply(charToRaw(utf8_text), function(b) { + whisper:::byte_to_token(as.integer(b)) +}), collapse = "") +expect_equal(whisper:::decode_bpe_bytes(utf8_bpe), utf8_text) + +# CJK characters +cjk_text <- "\u4e16\u754c" +cjk_bpe <- paste(sapply(charToRaw(cjk_text), function(b) { + whisper:::byte_to_token(as.integer(b)) +}), collapse = "") +expect_equal(whisper:::decode_bpe_bytes(cjk_bpe), cjk_text) + +# Empty string +expect_equal(whisper:::decode_bpe_bytes(""), "") + diff --git a/inst/tinytest/test_transcribe.R b/inst/tinytest/test_transcribe.R index f899b2c..4c9d146 100644 --- a/inst/tinytest/test_transcribe.R +++ b/inst/tinytest/test_transcribe.R @@ -16,11 +16,75 @@ expect_true("tiny" %in% models) expect_true("small" %in% models) expect_true("large-v3" %in% models) -# Integration test (requires model download - skip if not available) -if (FALSE) { - # This would be enabled for manual testing - result <- transcribe("test.wav", model = "tiny") - expect_true("text" %in% names(result)) - expect_true("language" %in% names(result)) +# Test extract_segments with synthetic token sequences +# Use tiny model token IDs (timestamp_begin = 50364) +special <- whisper:::whisper_special_tokens("tiny") +ts_begin <- special$timestamp_begin + +# Build a synthetic token sequence: <|0.00|> hello world <|2.50|> <|2.50|> foo <|5.00|> +# Token IDs: ts_begin + 0, some text tokens, ts_begin + 125, ts_begin + 125, text, ts_begin + 250 +if (model_exists("tiny")) { + tok <- whisper_tokenizer("tiny") + + # Encode "hello" to get real token IDs + hello_ids <- tok$encode("hello") + foo_ids <- tok$encode("foo") + + synthetic_tokens <- c( + special$sot, special$lang_en, special$transcribe, # prompt tokens + ts_begin, # <|0.00|> + hello_ids, # "hello" + ts_begin + 125L, # <|2.50|> + ts_begin + 125L, # <|2.50|> + foo_ids, # "foo" + ts_begin + 250L # <|5.00|> + ) + + segments <- whisper:::extract_segments(synthetic_tokens, tok, time_offset = 0) + expect_true(is.data.frame(segments)) + expect_true(nrow(segments) >= 2) + expect_equal(segments$start[1], 0) + 
expect_equal(segments$end[1], 2.5) + expect_equal(segments$start[2], 2.5) + expect_equal(segments$end[2], 5.0) + + # Test with time_offset + segments2 <- whisper:::extract_segments(synthetic_tokens, tok, time_offset = 10) + expect_equal(segments2$start[1], 10) + expect_equal(segments2$end[1], 12.5) +} + +# Test apply_timestamp_rules with synthetic logits +n_vocab <- special$timestamp_begin + 1501L # full vocab +logits_base <- torch::torch_zeros(1, n_vocab) + +# Rule 1: First content token must be <|0.00|> +logits1 <- logits_base$clone() +logits1[1, ts_begin + 1] <- 5.0 # <|0.00|> should be allowed (1-indexed) +logits1[1, 100] <- 10.0 # text token should be suppressed +result1 <- whisper:::apply_timestamp_rules(logits1, integer(4), special, 4L) +# Text tokens should be suppressed (-Inf) +expect_true(as.numeric(result1[1, 100]$item()) == -Inf) +# <|0.00|> should be allowed +expect_true(as.numeric(result1[1, ts_begin + 1]$item()) > -Inf) + +# Integration test with timestamps (requires model + audio) +if (at_home() && model_exists("tiny")) { + audio_file <- system.file("audio", "jfk.mp3", package = "whisper") + if (file.exists(audio_file)) { + result <- transcribe(audio_file, model = "tiny", timestamps = TRUE, + verbose = FALSE) + expect_true("segments" %in% names(result)) + expect_true(is.data.frame(result$segments)) + expect_true(nrow(result$segments) > 0) + expect_true("start" %in% names(result$segments)) + expect_true("end" %in% names(result$segments)) + expect_true("text" %in% names(result$segments)) + # Segments should have monotonic start times + if (nrow(result$segments) > 1) { + starts <- result$segments$start + expect_true(all(diff(starts) >= 0)) + } + } } diff --git a/man/apply_timestamp_rules.Rd b/man/apply_timestamp_rules.Rd new file mode 100644 index 0000000..5a1e643 --- /dev/null +++ b/man/apply_timestamp_rules.Rd @@ -0,0 +1,22 @@ +% tinyrox says don't edit this manually, but it can't stop you! +\name{apply_timestamp_rules} +\alias{apply_timestamp_rules} +\title{Apply Timestamp Token Rules} +\usage{ +apply_timestamp_rules(logits, generated, special, sample_begin) +} +\arguments{ +\item{logits}{Logit tensor (1, vocab) or (vocab)} + +\item{generated}{Integer vector of tokens generated so far} + +\item{special}{Special token IDs} + +\item{sample_begin}{Index where content tokens start in generated} +} +\value{ +Modified logits tensor +} +\description{ +Enforce Whisper timestamp generation constraints on logits. +} diff --git a/man/compute_word_timestamps.Rd b/man/compute_word_timestamps.Rd new file mode 100644 index 0000000..17198ff --- /dev/null +++ b/man/compute_word_timestamps.Rd @@ -0,0 +1,36 @@ +% tinyrox says don't edit this manually, but it can't stop you! +\name{compute_word_timestamps} +\alias{compute_word_timestamps} +\title{Word-Level Timestamp Alignment} +\usage{ +compute_word_timestamps( + tokens, + cross_attn_weights, + tokenizer, + config, + time_offset = 0, + sample_begin = 4L +) +} +\arguments{ +\item{tokens}{Integer vector of generated token IDs} + +\item{cross_attn_weights}{List of cross-attention weight tensors per decode step} + +\item{tokenizer}{Whisper tokenizer} + +\item{config}{Model configuration} + +\item{time_offset}{Time offset in seconds (for chunked audio)} + +\item{sample_begin}{Index where content tokens start in generated} +} +\value{ +Data frame with word, start, end columns +} +\description{ +DTW-based alignment of tokens to audio frames using cross-attention weights. 
+Attention from the model's designated alignment heads is averaged per
+decode step, median-filtered along the time axis, converted to a cost
+matrix via -log(attn + eps), and aligned with DTW (0.02 s per audio frame).
+}
diff --git a/man/dtw_align.Rd b/man/dtw_align.Rd
new file mode 100644
index 0000000..d94134d
--- /dev/null
+++ b/man/dtw_align.Rd
@@ -0,0 +1,19 @@
+% tinyrox says don't edit this manually, but it can't stop you!
+\name{dtw_align}
+\alias{dtw_align}
+\title{DTW Alignment}
+\usage{
+dtw_align(cost)
+}
+\arguments{
+\item{cost}{Numeric matrix (n_tokens x n_frames)}
+}
+\value{
+Integer matrix with 2 columns (token_idx, frame_idx), 1-indexed
+}
+\description{
+Standard dynamic time warping on a cost matrix.
+The accumulated cost follows D[i, j] = cost[i, j] +
+min(D[i-1, j], D[i, j-1], D[i-1, j-1]); the path is recovered by
+backtracking from (n, m).
+}
diff --git a/man/greedy_decode.Rd b/man/greedy_decode.Rd
index 52411e1..8f18b70 100644
--- a/man/greedy_decode.Rd
+++ b/man/greedy_decode.Rd
@@ -9,6 +9,8 @@ greedy_decode(
   initial_tokens,
   tokenizer,
   max_length = 448L,
+  timestamps = FALSE,
+  word_timestamps = FALSE,
   device
 )
 }
@@ -23,10 +25,15 @@ greedy_decode(
 
 \item{max_length}{Maximum output length}
 
+\item{timestamps}{Whether to allow timestamp tokens}
+
+\item{word_timestamps}{Whether to collect cross-attention weights}
+
 \item{device}{Device}
 }
 \value{
-Integer vector of generated tokens
+Integer vector of generated tokens, or list with tokens and
+  cross_attn_weights when word_timestamps is TRUE
 }
 \description{
 Greedy Decoding
 }
diff --git a/man/group_into_words.Rd b/man/group_into_words.Rd
new file mode 100644
index 0000000..babdfd3
--- /dev/null
+++ b/man/group_into_words.Rd
@@ -0,0 +1,22 @@
+% tinyrox says don't edit this manually, but it can't stop you!
+\name{group_into_words}
+\alias{group_into_words}
+\title{Group Subword Tokens into Words}
+\usage{
+group_into_words(token_ids, starts, ends, tokenizer)
+}
+\arguments{
+\item{token_ids}{Integer vector of text token IDs}
+
+\item{starts}{Numeric vector of token start times}
+
+\item{ends}{Numeric vector of token end times}
+
+\item{tokenizer}{Whisper tokenizer}
+}
+\value{
+Data frame with word, start, end columns
+}
+\description{
+Merge BPE subword tokens into whole words with timestamps.
+}
diff --git a/man/medfilt1.Rd b/man/medfilt1.Rd
new file mode 100644
index 0000000..255b284
--- /dev/null
+++ b/man/medfilt1.Rd
@@ -0,0 +1,20 @@
+% tinyrox says don't edit this manually, but it can't stop you!
+\name{medfilt1}
+\alias{medfilt1}
+\title{1D Median Filter}
+\usage{
+medfilt1(x, width = 7L)
+}
+\arguments{
+\item{x}{Numeric vector}
+
+\item{width}{Filter width (must be odd)}
+}
+\value{
+Filtered numeric vector of same length
+}
+\description{
+Apply a sliding median filter to a numeric vector.
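+Window edges are truncated rather than padded, so the output has the
+same length as the input.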
+} diff --git a/man/pipeline_transcribe.Rd b/man/pipeline_transcribe.Rd index 9509d8d..d80942e 100644 --- a/man/pipeline_transcribe.Rd +++ b/man/pipeline_transcribe.Rd @@ -8,6 +8,8 @@ pipeline_transcribe( file, language = "en", task = "transcribe", + timestamps = FALSE, + word_timestamps = FALSE, verbose = TRUE ) } @@ -20,6 +22,10 @@ pipeline_transcribe( \item{task}{Task type.} +\item{timestamps}{Return segment-level timestamps.} + +\item{word_timestamps}{Return word-level timestamps.} + \item{verbose}{Print progress.} } \value{ diff --git a/man/transcribe.Rd b/man/transcribe.Rd index c7cbf7f..2926c36 100644 --- a/man/transcribe.Rd +++ b/man/transcribe.Rd @@ -8,6 +8,8 @@ transcribe( model = "tiny", language = "en", task = "transcribe", + timestamps = FALSE, + word_timestamps = FALSE, device = "auto", dtype = "auto", verbose = TRUE @@ -22,6 +24,10 @@ transcribe( \item{task}{"transcribe" or "translate" (translate to English)} +\item{timestamps}{If TRUE, return segment-level timestamps} + +\item{word_timestamps}{If TRUE, return word-level timestamps (implies timestamps)} + \item{device}{Device: "auto", "cpu", "cuda"} \item{dtype}{Data type: "auto", "float16", "float32"} @@ -29,7 +35,10 @@ transcribe( \item{verbose}{Print progress messages} } \value{ -List with text, language, and metadata +List with text, language, and metadata. When \code{timestamps=TRUE}, + includes \code{segments} data.frame with start, end, text columns. When + \code{word_timestamps=TRUE}, includes \code{words} data.frame with word, + start, end columns. } \description{ Transcribe speech from an audio file using Whisper. @@ -44,6 +53,10 @@ if (model_exists("tiny")) { result <- transcribe(audio_file, model = "tiny") result$text + # With timestamps + result <- transcribe(audio_file, model = "tiny", timestamps = TRUE) + result$segments + # Translate Spanish audio to English spanish_file <- system.file("audio", "allende.mp3", package = "whisper") result <- transcribe(spanish_file, model = "tiny", diff --git a/man/transcribe_chunk.Rd b/man/transcribe_chunk.Rd index 29bd652..259560c 100644 --- a/man/transcribe_chunk.Rd +++ b/man/transcribe_chunk.Rd @@ -10,6 +10,9 @@ transcribe_chunk( config, language = "en", task = "transcribe", + timestamps = FALSE, + word_timestamps = FALSE, + time_offset = 0, device, dtype, verbose = TRUE diff --git a/man/transcribe_long.Rd b/man/transcribe_long.Rd index 170c0b1..1592516 100644 --- a/man/transcribe_long.Rd +++ b/man/transcribe_long.Rd @@ -10,6 +10,8 @@ transcribe_long( config, language, task, + timestamps = FALSE, + word_timestamps = FALSE, device, dtype, verbose