Skip to content

Commit ac6cefd

Browse files
Add segment-level and word-level timestamps (#1)
* Add segment-level and word-level timestamps to transcription Segment timestamps use Whisper's built-in timestamp tokens (<|0.00|> through <|30.00|>) with logit suppression rules that enforce proper timestamp generation (forward-only, paired, capped at 30s). Word timestamps use cross-attention DTW alignment: during decoding, cross-attention weights are captured from model-specific alignment heads, then dynamic time warping maps each token to audio frames. Subword tokens are merged into words with start/end times. API: transcribe(..., timestamps=TRUE) returns segments data.frame, transcribe(..., word_timestamps=TRUE) returns words data.frame. Both work with single chunks and long audio (automatic time offsets). * Update README and CLAUDE.md with timestamp documentation * Add peak VRAM and speed benchmarks to models table * Fix UTF-8 byte decoding in tokenizer decode_bpe_bytes() was a stub that only handled the space token, causing non-ASCII characters (accented Latin, CJK, etc.) to come out garbled. Now fully reverses the GPT-2 byte-to-unicode mapping with a cached lookup table.
1 parent b4208c3 commit ac6cefd

22 files changed

Lines changed: 1065 additions & 101 deletions

CLAUDE.md

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ Audio (WAV/MP3) -> Mel Spectrogram -> Encoder (transformer) -> Decoder (cross-at
2020

2121
## Key Exports
2222

23-
- `transcribe(file, model, language)` - Main transcription function
23+
- `transcribe(file, model, language, timestamps, word_timestamps)` - Main transcription function
24+
- `whisper_pipeline(model)` - Load model once, call `$transcribe()` repeatedly
2425
- `load_whisper_model(model, device, dtype)` - Load model weights
2526
- `audio_to_mel(file, n_mels)` - Convert audio to mel spectrogram
2627
- `whisper_tokenizer()` - Get BPE tokenizer
@@ -32,7 +33,15 @@ library(whisper)
3233

3334
# Transcribe audio
3435
result <- transcribe("audio.wav", model = "tiny")
35-
print(result$text)
36+
result$text
37+
38+
# Segment timestamps (uses Whisper's built-in timestamp tokens)
39+
result <- transcribe("audio.wav", timestamps = TRUE)
40+
result$segments # data.frame(start, end, text)
41+
42+
# Word-level timestamps (cross-attention DTW alignment)
43+
result <- transcribe("audio.wav", word_timestamps = TRUE)
44+
result$words # data.frame(word, start, end)
3645
```
3746

3847
## Development
@@ -56,13 +65,14 @@ Uses safetensors format from HuggingFace:
5665

5766
## File Structure
5867

59-
- `R/transcribe.R` - Main API
68+
- `R/transcribe.R` - Main API, greedy decode, timestamp logit rules
69+
- `R/alignment.R` - DTW alignment, word timestamp computation
6070
- `R/audio.R` - Audio to mel spectrogram
61-
- `R/encoder.R` - Encoder transformer
71+
- `R/encoder.R` - Encoder transformer (with `need_weights` dual-path attention)
6272
- `R/decoder.R` - Decoder with cross-attention
6373
- `R/model.R` - Full model + weight loading
6474
- `R/tokenizer.R` - Whisper BPE tokenizer
65-
- `R/config.R` - Model configurations
75+
- `R/config.R` - Model configurations + alignment heads
6676
- `R/download.R` - HuggingFace model download
6777
- `R/devices.R` - Device/dtype management
6878

@@ -75,10 +85,12 @@ Uses safetensors format from HuggingFace:
7585
- Transcription and translation (any language to English)
7686
- All model sizes: tiny, base, small, medium, large-v3
7787
- CPU and CUDA support
88+
- Segment-level timestamps (Whisper timestamp tokens with logit suppression)
89+
- Word-level timestamps (cross-attention DTW alignment)
7890
- Pre-computed mel filterbank from official Whisper
7991
- HuggingFace model downloads via `hfhub`
8092
- KV cache for efficient incremental decoding
81-
- Long audio support (automatic chunking)
93+
- Long audio support (automatic chunking with time offsets)
8294

8395
### R torch notes
8496

@@ -88,12 +100,9 @@ Uses safetensors format from HuggingFace:
88100

89101
### Known Limitations
90102

91-
- UTF-8 encoding issues with some non-ASCII characters in output
92103
- Translation quality varies by model size (larger models work better)
93104
- No beam search (greedy decoding only)
94105

95106
### Potential Improvements
96107

97108
- Beam search decoding
98-
- Word-level timestamps (requires cross-attention analysis)
99-
- Fix UTF-8 byte decoding in tokenizer

R/alignment.R

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
#' Word-Level Timestamp Alignment
2+
#'
3+
#' DTW-based alignment of tokens to audio frames using cross-attention weights.
4+
5+
#' Compute Word-Level Timestamps
6+
#'
7+
#' Use cross-attention weights and DTW alignment to assign timestamps
8+
#' to individual words.
9+
#'
10+
#' @param tokens Integer vector of generated token IDs
11+
#' @param cross_attn_weights List of cross-attention weight tensors per decode step
12+
#' @param tokenizer Whisper tokenizer
13+
#' @param config Model configuration
14+
#' @param time_offset Time offset in seconds (for chunked audio)
15+
#' @param sample_begin Index where content tokens start in generated
16+
#' @return Data frame with word, start, end columns
17+
compute_word_timestamps <- function(
18+
tokens,
19+
cross_attn_weights,
20+
tokenizer,
21+
config,
22+
time_offset = 0,
23+
sample_begin = 4L
24+
) {
25+
if (length(cross_attn_weights) == 0) {
26+
return(data.frame(word = character(0), start = numeric(0), end = numeric(0)))
27+
}
28+
29+
special <- whisper_special_tokens(config$model_name)
30+
31+
# Content tokens only (after initial prompt tokens)
32+
content_tokens <- tokens[seq_len(length(tokens)) > sample_begin]
33+
34+
# Filter out timestamp tokens for word alignment
35+
text_mask <- content_tokens < special$timestamp_begin
36+
if (sum(text_mask) == 0) {
37+
return(data.frame(word = character(0), start = numeric(0), end = numeric(0)))
38+
}
39+
40+
# Get alignment heads for this model
41+
alignment_heads <- config$alignment_heads
42+
if (is.null(alignment_heads)) {
43+
# Fallback: use all heads from last half of layers
44+
n_layer <- config$n_text_layer
45+
n_head <- config$n_text_head
46+
half <- n_layer %/% 2L
47+
layers <- seq(half, n_layer - 1L)
48+
heads <- seq(0L, n_head - 1L)
49+
alignment_heads <- as.matrix(expand.grid(layer = layers, head = heads))
50+
}
51+
52+
# Build attention matrix: average over alignment heads and decode steps
53+
# Each element of cross_attn_weights is a list of per-layer tensors
54+
# Each tensor has shape (batch, n_head, 1, n_audio_ctx)
55+
n_steps <- length(cross_attn_weights)
56+
n_audio_ctx <- config$n_audio_ctx
57+
58+
# Stack attention from alignment heads across all steps
59+
# Result: (n_steps, n_audio_ctx) averaged over alignment heads
60+
attn_matrix <- matrix(0, nrow = n_steps, ncol = n_audio_ctx)
61+
62+
for (step in seq_len(n_steps)) {
63+
step_weights <- cross_attn_weights[[step]]
64+
n_heads_used <- 0
65+
66+
for (h in seq_len(nrow(alignment_heads))) {
67+
layer_idx <- alignment_heads[h, 1] + 1L # 0-indexed to 1-indexed
68+
head_idx <- alignment_heads[h, 2] + 1L
69+
70+
if (layer_idx <= length(step_weights) && !is.null(step_weights[[layer_idx]])) {
71+
# step_weights[[layer_idx]] is (batch, n_head, seq_len, src_len)
72+
w <- step_weights[[layer_idx]]
73+
# Extract specific head, last query position
74+
head_attn <- as.array(w[1, head_idx, w$size(3), ]$cpu())
75+
attn_matrix[step, ] <- attn_matrix[step, ] + head_attn
76+
n_heads_used <- n_heads_used + 1L
77+
}
78+
}
79+
80+
if (n_heads_used > 0) {
81+
attn_matrix[step, ] <- attn_matrix[step, ] / n_heads_used
82+
}
83+
}
84+
85+
# Determine audio frame range from timestamp tokens (if present)
86+
# Find the last timestamp token to cap the attention matrix
87+
max_frame <- n_audio_ctx
88+
for (j in rev(seq_along(content_tokens))) {
89+
if (content_tokens[j] >= special$timestamp_begin) {
90+
ts_seconds <- (content_tokens[j] - special$timestamp_begin) * 0.02
91+
max_frame <- min(n_audio_ctx, max(1L, as.integer(ts_seconds / 0.02)))
92+
break
93+
}
94+
}
95+
96+
# Keep only text token rows (not timestamp tokens)
97+
text_indices <- which(text_mask)
98+
if (length(text_indices) == 0) {
99+
return(data.frame(word = character(0), start = numeric(0), end = numeric(0)))
100+
}
101+
text_attn <- attn_matrix[text_indices, 1:max_frame, drop = FALSE]
102+
103+
# Apply median filter along time axis for smoothing
104+
text_attn <- apply(text_attn, 1, function(row) medfilt1(row, 7L))
105+
text_attn <- t(text_attn)
106+
107+
# Convert to cost matrix for DTW: -log(attn + eps)
108+
cost <- -log(text_attn + 1e-10)
109+
110+
# Run DTW alignment
111+
path <- dtw_align(cost)
112+
113+
# Map path to per-token frame ranges
114+
text_token_ids <- content_tokens[text_indices]
115+
n_text <- length(text_token_ids)
116+
token_frames <- vector("list", n_text)
117+
for (k in seq_len(n_text)) {
118+
token_frames[[k]] <- integer(0)
119+
}
120+
121+
for (p in seq_len(nrow(path))) {
122+
tok_idx <- path[p, 1]
123+
frame_idx <- path[p, 2]
124+
token_frames[[tok_idx]] <- c(token_frames[[tok_idx]], frame_idx)
125+
}
126+
127+
# Convert frame indices to timestamps
128+
# Each audio frame = 2 mel frames (due to conv stride 2)
129+
# Each mel frame = WHISPER_HOP_LENGTH / WHISPER_SAMPLE_RATE seconds
130+
seconds_per_frame <- 0.02 # 1500 frames = 30 seconds
131+
132+
token_starts <- numeric(n_text)
133+
token_ends <- numeric(n_text)
134+
for (k in seq_len(n_text)) {
135+
frames <- token_frames[[k]]
136+
if (length(frames) > 0) {
137+
token_starts[k] <- (min(frames) - 1) * seconds_per_frame + time_offset
138+
token_ends[k] <- max(frames) * seconds_per_frame + time_offset
139+
} else if (k > 1) {
140+
# Inherit from previous token
141+
token_starts[k] <- token_ends[k - 1]
142+
token_ends[k] <- token_starts[k]
143+
} else {
144+
token_starts[k] <- time_offset
145+
token_ends[k] <- time_offset
146+
}
147+
}
148+
149+
# Group subword tokens into words
150+
group_into_words(text_token_ids, token_starts, token_ends, tokenizer)
151+
}
152+
153+
#' Group Subword Tokens into Words
154+
#'
155+
#' Merge BPE subword tokens into whole words with timestamps.
156+
#'
157+
#' @param token_ids Integer vector of text token IDs
158+
#' @param starts Numeric vector of token start times
159+
#' @param ends Numeric vector of token end times
160+
#' @param tokenizer Whisper tokenizer
161+
#' @return Data frame with word, start, end columns
162+
group_into_words <- function(
163+
token_ids,
164+
starts,
165+
ends,
166+
tokenizer
167+
) {
168+
if (length(token_ids) == 0) {
169+
return(data.frame(word = character(0), start = numeric(0), end = numeric(0)))
170+
}
171+
172+
# Decode each token individually
173+
token_texts <- vapply(token_ids, function(id) tokenizer$decode(id), character(1))
174+
175+
# Group by word boundaries (space at start of token = new word)
176+
words <- list()
177+
current_word <- ""
178+
current_start <- starts[1]
179+
current_end <- ends[1]
180+
181+
for (i in seq_along(token_texts)) {
182+
text <- token_texts[i]
183+
is_new_word <- grepl("^\\s", text) || i == 1L
184+
185+
if (is_new_word && nchar(trimws(current_word)) > 0 && i > 1L) {
186+
# Save previous word
187+
words <- c(words, list(data.frame(
188+
word = trimws(current_word),
189+
start = current_start,
190+
end = current_end,
191+
stringsAsFactors = FALSE
192+
)))
193+
current_word <- text
194+
current_start <- starts[i]
195+
current_end <- ends[i]
196+
} else {
197+
current_word <- paste0(current_word, text)
198+
current_end <- ends[i]
199+
}
200+
}
201+
202+
# Save last word
203+
if (nchar(trimws(current_word)) > 0) {
204+
words <- c(words, list(data.frame(
205+
word = trimws(current_word),
206+
start = current_start,
207+
end = current_end,
208+
stringsAsFactors = FALSE
209+
)))
210+
}
211+
212+
if (length(words) == 0) {
213+
return(data.frame(word = character(0), start = numeric(0), end = numeric(0)))
214+
}
215+
216+
do.call(rbind, words)
217+
}
218+
219+
#' DTW Alignment
220+
#'
221+
#' Standard dynamic time warping on a cost matrix.
222+
#'
223+
#' @param cost Numeric matrix (n_tokens x n_frames)
224+
#' @return Integer matrix with 2 columns (token_idx, frame_idx), 1-indexed
225+
dtw_align <- function(cost) {
226+
n <- nrow(cost)
227+
m <- ncol(cost)
228+
229+
# Accumulated cost matrix
230+
D <- matrix(Inf, nrow = n, ncol = m)
231+
D[1, 1] <- cost[1, 1]
232+
233+
# First row: can only come from the left
234+
for (j in 2:m) {
235+
D[1, j] <- D[1, j - 1] + cost[1, j]
236+
}
237+
238+
# First column: can only come from above
239+
for (i in 2:n) {
240+
D[i, 1] <- D[i - 1, 1] + cost[i, 1]
241+
}
242+
243+
# Fill rest
244+
for (i in 2:n) {
245+
for (j in 2:m) {
246+
D[i, j] <- cost[i, j] + min(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1])
247+
}
248+
}
249+
250+
# Backtrack to find optimal path
251+
path <- matrix(0L, nrow = n + m, ncol = 2)
252+
k <- 1L
253+
i <- n
254+
j <- m
255+
path[k, ] <- c(i, j)
256+
257+
258+
while (i > 1 || j > 1) {
259+
k <- k + 1L
260+
if (i == 1) {
261+
j <- j - 1L
262+
} else if (j == 1) {
263+
i <- i - 1L
264+
} else {
265+
candidates <- c(D[i - 1, j - 1], D[i - 1, j], D[i, j - 1])
266+
step <- which.min(candidates)
267+
if (step == 1L) {
268+
i <- i - 1L
269+
j <- j - 1L
270+
} else if (step == 2L) {
271+
i <- i - 1L
272+
} else {
273+
j <- j - 1L
274+
}
275+
}
276+
path[k, ] <- c(i, j)
277+
}
278+
279+
# Reverse path (was built backwards)
280+
path <- path[k:1, , drop = FALSE]
281+
path
282+
}
283+
284+
#' 1D Median Filter
285+
#'
286+
#' Apply a sliding median filter to a numeric vector.
287+
#'
288+
#' @param x Numeric vector
289+
#' @param width Filter width (must be odd)
290+
#' @return Filtered numeric vector of same length
291+
medfilt1 <- function(x, width = 7L) {
292+
n <- length(x)
293+
if (n == 0) return(x)
294+
295+
# Ensure odd width
296+
if (width %% 2L == 0L) width <- width + 1L
297+
half <- width %/% 2L
298+
299+
result <- numeric(n)
300+
for (i in seq_len(n)) {
301+
lo <- max(1L, i - half)
302+
hi <- min(n, i + half)
303+
result[i] <- median(x[lo:hi])
304+
}
305+
result
306+
}

0 commit comments

Comments
 (0)