|
| 1 | +#' Word-Level Timestamp Alignment |
| 2 | +#' |
| 3 | +#' DTW-based alignment of tokens to audio frames using cross-attention weights. |
| 4 | + |
| 5 | +#' Compute Word-Level Timestamps |
| 6 | +#' |
| 7 | +#' Use cross-attention weights and DTW alignment to assign timestamps |
| 8 | +#' to individual words. |
| 9 | +#' |
| 10 | +#' @param tokens Integer vector of generated token IDs |
| 11 | +#' @param cross_attn_weights List of cross-attention weight tensors per decode step |
| 12 | +#' @param tokenizer Whisper tokenizer |
| 13 | +#' @param config Model configuration |
| 14 | +#' @param time_offset Time offset in seconds (for chunked audio) |
| 15 | +#' @param sample_begin Index where content tokens start in generated |
| 16 | +#' @return Data frame with word, start, end columns |
| 17 | +compute_word_timestamps <- function( |
| 18 | + tokens, |
| 19 | + cross_attn_weights, |
| 20 | + tokenizer, |
| 21 | + config, |
| 22 | + time_offset = 0, |
| 23 | + sample_begin = 4L |
| 24 | +) { |
| 25 | + if (length(cross_attn_weights) == 0) { |
| 26 | + return(data.frame(word = character(0), start = numeric(0), end = numeric(0))) |
| 27 | + } |
| 28 | + |
| 29 | + special <- whisper_special_tokens(config$model_name) |
| 30 | + |
| 31 | + # Content tokens only (after initial prompt tokens) |
| 32 | + content_tokens <- tokens[seq_len(length(tokens)) > sample_begin] |
| 33 | + |
| 34 | + # Filter out timestamp tokens for word alignment |
| 35 | + text_mask <- content_tokens < special$timestamp_begin |
| 36 | + if (sum(text_mask) == 0) { |
| 37 | + return(data.frame(word = character(0), start = numeric(0), end = numeric(0))) |
| 38 | + } |
| 39 | + |
| 40 | + # Get alignment heads for this model |
| 41 | + alignment_heads <- config$alignment_heads |
| 42 | + if (is.null(alignment_heads)) { |
| 43 | + # Fallback: use all heads from last half of layers |
| 44 | + n_layer <- config$n_text_layer |
| 45 | + n_head <- config$n_text_head |
| 46 | + half <- n_layer %/% 2L |
| 47 | + layers <- seq(half, n_layer - 1L) |
| 48 | + heads <- seq(0L, n_head - 1L) |
| 49 | + alignment_heads <- as.matrix(expand.grid(layer = layers, head = heads)) |
| 50 | + } |
| 51 | + |
| 52 | + # Build attention matrix: average over alignment heads and decode steps |
| 53 | + # Each element of cross_attn_weights is a list of per-layer tensors |
| 54 | + # Each tensor has shape (batch, n_head, 1, n_audio_ctx) |
| 55 | + n_steps <- length(cross_attn_weights) |
| 56 | + n_audio_ctx <- config$n_audio_ctx |
| 57 | + |
| 58 | + # Stack attention from alignment heads across all steps |
| 59 | + # Result: (n_steps, n_audio_ctx) averaged over alignment heads |
| 60 | + attn_matrix <- matrix(0, nrow = n_steps, ncol = n_audio_ctx) |
| 61 | + |
| 62 | + for (step in seq_len(n_steps)) { |
| 63 | + step_weights <- cross_attn_weights[[step]] |
| 64 | + n_heads_used <- 0 |
| 65 | + |
| 66 | + for (h in seq_len(nrow(alignment_heads))) { |
| 67 | + layer_idx <- alignment_heads[h, 1] + 1L # 0-indexed to 1-indexed |
| 68 | + head_idx <- alignment_heads[h, 2] + 1L |
| 69 | + |
| 70 | + if (layer_idx <= length(step_weights) && !is.null(step_weights[[layer_idx]])) { |
| 71 | + # step_weights[[layer_idx]] is (batch, n_head, seq_len, src_len) |
| 72 | + w <- step_weights[[layer_idx]] |
| 73 | + # Extract specific head, last query position |
| 74 | + head_attn <- as.array(w[1, head_idx, w$size(3), ]$cpu()) |
| 75 | + attn_matrix[step, ] <- attn_matrix[step, ] + head_attn |
| 76 | + n_heads_used <- n_heads_used + 1L |
| 77 | + } |
| 78 | + } |
| 79 | + |
| 80 | + if (n_heads_used > 0) { |
| 81 | + attn_matrix[step, ] <- attn_matrix[step, ] / n_heads_used |
| 82 | + } |
| 83 | + } |
| 84 | + |
| 85 | + # Determine audio frame range from timestamp tokens (if present) |
| 86 | + # Find the last timestamp token to cap the attention matrix |
| 87 | + max_frame <- n_audio_ctx |
| 88 | + for (j in rev(seq_along(content_tokens))) { |
| 89 | + if (content_tokens[j] >= special$timestamp_begin) { |
| 90 | + ts_seconds <- (content_tokens[j] - special$timestamp_begin) * 0.02 |
| 91 | + max_frame <- min(n_audio_ctx, max(1L, as.integer(ts_seconds / 0.02))) |
| 92 | + break |
| 93 | + } |
| 94 | + } |
| 95 | + |
| 96 | + # Keep only text token rows (not timestamp tokens) |
| 97 | + text_indices <- which(text_mask) |
| 98 | + if (length(text_indices) == 0) { |
| 99 | + return(data.frame(word = character(0), start = numeric(0), end = numeric(0))) |
| 100 | + } |
| 101 | + text_attn <- attn_matrix[text_indices, 1:max_frame, drop = FALSE] |
| 102 | + |
| 103 | + # Apply median filter along time axis for smoothing |
| 104 | + text_attn <- apply(text_attn, 1, function(row) medfilt1(row, 7L)) |
| 105 | + text_attn <- t(text_attn) |
| 106 | + |
| 107 | + # Convert to cost matrix for DTW: -log(attn + eps) |
| 108 | + cost <- -log(text_attn + 1e-10) |
| 109 | + |
| 110 | + # Run DTW alignment |
| 111 | + path <- dtw_align(cost) |
| 112 | + |
| 113 | + # Map path to per-token frame ranges |
| 114 | + text_token_ids <- content_tokens[text_indices] |
| 115 | + n_text <- length(text_token_ids) |
| 116 | + token_frames <- vector("list", n_text) |
| 117 | + for (k in seq_len(n_text)) { |
| 118 | + token_frames[[k]] <- integer(0) |
| 119 | + } |
| 120 | + |
| 121 | + for (p in seq_len(nrow(path))) { |
| 122 | + tok_idx <- path[p, 1] |
| 123 | + frame_idx <- path[p, 2] |
| 124 | + token_frames[[tok_idx]] <- c(token_frames[[tok_idx]], frame_idx) |
| 125 | + } |
| 126 | + |
| 127 | + # Convert frame indices to timestamps |
| 128 | + # Each audio frame = 2 mel frames (due to conv stride 2) |
| 129 | + # Each mel frame = WHISPER_HOP_LENGTH / WHISPER_SAMPLE_RATE seconds |
| 130 | + seconds_per_frame <- 0.02 # 1500 frames = 30 seconds |
| 131 | + |
| 132 | + token_starts <- numeric(n_text) |
| 133 | + token_ends <- numeric(n_text) |
| 134 | + for (k in seq_len(n_text)) { |
| 135 | + frames <- token_frames[[k]] |
| 136 | + if (length(frames) > 0) { |
| 137 | + token_starts[k] <- (min(frames) - 1) * seconds_per_frame + time_offset |
| 138 | + token_ends[k] <- max(frames) * seconds_per_frame + time_offset |
| 139 | + } else if (k > 1) { |
| 140 | + # Inherit from previous token |
| 141 | + token_starts[k] <- token_ends[k - 1] |
| 142 | + token_ends[k] <- token_starts[k] |
| 143 | + } else { |
| 144 | + token_starts[k] <- time_offset |
| 145 | + token_ends[k] <- time_offset |
| 146 | + } |
| 147 | + } |
| 148 | + |
| 149 | + # Group subword tokens into words |
| 150 | + group_into_words(text_token_ids, token_starts, token_ends, tokenizer) |
| 151 | +} |
| 152 | + |
| 153 | +#' Group Subword Tokens into Words |
| 154 | +#' |
| 155 | +#' Merge BPE subword tokens into whole words with timestamps. |
| 156 | +#' |
| 157 | +#' @param token_ids Integer vector of text token IDs |
| 158 | +#' @param starts Numeric vector of token start times |
| 159 | +#' @param ends Numeric vector of token end times |
| 160 | +#' @param tokenizer Whisper tokenizer |
| 161 | +#' @return Data frame with word, start, end columns |
| 162 | +group_into_words <- function( |
| 163 | + token_ids, |
| 164 | + starts, |
| 165 | + ends, |
| 166 | + tokenizer |
| 167 | +) { |
| 168 | + if (length(token_ids) == 0) { |
| 169 | + return(data.frame(word = character(0), start = numeric(0), end = numeric(0))) |
| 170 | + } |
| 171 | + |
| 172 | + # Decode each token individually |
| 173 | + token_texts <- vapply(token_ids, function(id) tokenizer$decode(id), character(1)) |
| 174 | + |
| 175 | + # Group by word boundaries (space at start of token = new word) |
| 176 | + words <- list() |
| 177 | + current_word <- "" |
| 178 | + current_start <- starts[1] |
| 179 | + current_end <- ends[1] |
| 180 | + |
| 181 | + for (i in seq_along(token_texts)) { |
| 182 | + text <- token_texts[i] |
| 183 | + is_new_word <- grepl("^\\s", text) || i == 1L |
| 184 | + |
| 185 | + if (is_new_word && nchar(trimws(current_word)) > 0 && i > 1L) { |
| 186 | + # Save previous word |
| 187 | + words <- c(words, list(data.frame( |
| 188 | + word = trimws(current_word), |
| 189 | + start = current_start, |
| 190 | + end = current_end, |
| 191 | + stringsAsFactors = FALSE |
| 192 | + ))) |
| 193 | + current_word <- text |
| 194 | + current_start <- starts[i] |
| 195 | + current_end <- ends[i] |
| 196 | + } else { |
| 197 | + current_word <- paste0(current_word, text) |
| 198 | + current_end <- ends[i] |
| 199 | + } |
| 200 | + } |
| 201 | + |
| 202 | + # Save last word |
| 203 | + if (nchar(trimws(current_word)) > 0) { |
| 204 | + words <- c(words, list(data.frame( |
| 205 | + word = trimws(current_word), |
| 206 | + start = current_start, |
| 207 | + end = current_end, |
| 208 | + stringsAsFactors = FALSE |
| 209 | + ))) |
| 210 | + } |
| 211 | + |
| 212 | + if (length(words) == 0) { |
| 213 | + return(data.frame(word = character(0), start = numeric(0), end = numeric(0))) |
| 214 | + } |
| 215 | + |
| 216 | + do.call(rbind, words) |
| 217 | +} |
| 218 | + |
| 219 | +#' DTW Alignment |
| 220 | +#' |
| 221 | +#' Standard dynamic time warping on a cost matrix. |
| 222 | +#' |
| 223 | +#' @param cost Numeric matrix (n_tokens x n_frames) |
| 224 | +#' @return Integer matrix with 2 columns (token_idx, frame_idx), 1-indexed |
| 225 | +dtw_align <- function(cost) { |
| 226 | + n <- nrow(cost) |
| 227 | + m <- ncol(cost) |
| 228 | + |
| 229 | + # Accumulated cost matrix |
| 230 | + D <- matrix(Inf, nrow = n, ncol = m) |
| 231 | + D[1, 1] <- cost[1, 1] |
| 232 | + |
| 233 | + # First row: can only come from the left |
| 234 | + for (j in 2:m) { |
| 235 | + D[1, j] <- D[1, j - 1] + cost[1, j] |
| 236 | + } |
| 237 | + |
| 238 | + # First column: can only come from above |
| 239 | + for (i in 2:n) { |
| 240 | + D[i, 1] <- D[i - 1, 1] + cost[i, 1] |
| 241 | + } |
| 242 | + |
| 243 | + # Fill rest |
| 244 | + for (i in 2:n) { |
| 245 | + for (j in 2:m) { |
| 246 | + D[i, j] <- cost[i, j] + min(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1]) |
| 247 | + } |
| 248 | + } |
| 249 | + |
| 250 | + # Backtrack to find optimal path |
| 251 | + path <- matrix(0L, nrow = n + m, ncol = 2) |
| 252 | + k <- 1L |
| 253 | + i <- n |
| 254 | + j <- m |
| 255 | + path[k, ] <- c(i, j) |
| 256 | + |
| 257 | + |
| 258 | + while (i > 1 || j > 1) { |
| 259 | + k <- k + 1L |
| 260 | + if (i == 1) { |
| 261 | + j <- j - 1L |
| 262 | + } else if (j == 1) { |
| 263 | + i <- i - 1L |
| 264 | + } else { |
| 265 | + candidates <- c(D[i - 1, j - 1], D[i - 1, j], D[i, j - 1]) |
| 266 | + step <- which.min(candidates) |
| 267 | + if (step == 1L) { |
| 268 | + i <- i - 1L |
| 269 | + j <- j - 1L |
| 270 | + } else if (step == 2L) { |
| 271 | + i <- i - 1L |
| 272 | + } else { |
| 273 | + j <- j - 1L |
| 274 | + } |
| 275 | + } |
| 276 | + path[k, ] <- c(i, j) |
| 277 | + } |
| 278 | + |
| 279 | + # Reverse path (was built backwards) |
| 280 | + path <- path[k:1, , drop = FALSE] |
| 281 | + path |
| 282 | +} |
| 283 | + |
| 284 | +#' 1D Median Filter |
| 285 | +#' |
| 286 | +#' Apply a sliding median filter to a numeric vector. |
| 287 | +#' |
| 288 | +#' @param x Numeric vector |
| 289 | +#' @param width Filter width (must be odd) |
| 290 | +#' @return Filtered numeric vector of same length |
| 291 | +medfilt1 <- function(x, width = 7L) { |
| 292 | + n <- length(x) |
| 293 | + if (n == 0) return(x) |
| 294 | + |
| 295 | + # Ensure odd width |
| 296 | + if (width %% 2L == 0L) width <- width + 1L |
| 297 | + half <- width %/% 2L |
| 298 | + |
| 299 | + result <- numeric(n) |
| 300 | + for (i in seq_len(n)) { |
| 301 | + lo <- max(1L, i - half) |
| 302 | + hi <- min(n, i + half) |
| 303 | + result[i] <- median(x[lo:hi]) |
| 304 | + } |
| 305 | + result |
| 306 | +} |
0 commit comments