Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 214 additions & 0 deletions crates/burn-train/src/metric/cer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
use super::state::{FormatOptions, NumericMetricState};
use super::{MetricEntry, MetricMetadata};
use crate::metric::{Metric, Numeric};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{Int, Tensor};
use core::marker::PhantomData;

/// Computes the edit distance (Levenshtein distance) between two sequences of integers.
///
/// The edit distance is the minimum number of single-element edits (insertions, deletions,
/// or substitutions) required to turn one sequence into the other. Only two rows of the
/// dynamic programming table are kept, so the extra memory is proportional to
/// `prediction.len()`.
///
/// # Arguments
///
/// * `reference` - The ground-truth sequence.
/// * `prediction` - The sequence compared against the reference.
///
/// # Returns
///
/// The number of edits. This is `0` for equal sequences, and the length of the other
/// sequence when one of the two is empty.
// NOTE(review): during review it was suggested that this helper could be `pub(crate)` only.
// It is kept `pub` here so that existing callers are not affected.
pub fn edit_distance(reference: &[i32], prediction: &[i32]) -> usize {
    // `prev[j]` is the distance between the reference prefix processed so far and the first
    // `j` prediction elements. For the empty prefix this is `j` insertions.
    let mut prev = (0..=prediction.len()).collect::<Vec<_>>();
    let mut curr = vec![0; prediction.len() + 1];

    for (i, &r) in reference.iter().enumerate() {
        // Against an empty prediction, `i + 1` reference elements cost `i + 1` deletions.
        curr[0] = i + 1;
        for (j, &p) in prediction.iter().enumerate() {
            curr[j + 1] = if r == p {
                // Matching elements: carry the diagonal value over unchanged.
                prev[j]
            } else {
                // Cheapest of substitution (diagonal), and the two single-sided edits.
                1 + prev[j].min(prev[j + 1]).min(curr[j])
            };
        }
        // The row just filled becomes the "previous" row of the next iteration.
        core::mem::swap(&mut prev, &mut curr);
    }

    // After the final swap the last computed row lives in `prev`.
    prev[prediction.len()]
}


/// Character error rate (CER) metric.
///
/// The CER is the edit distance (Levenshtein distance) between the predicted and the reference
/// character sequences, divided by the total number of characters in the reference. It is
/// commonly used in tasks such as speech recognition, OCR, or text generation to quantify how
/// closely the predicted output matches the ground truth at a character level.
///
/// The value is reported as a percentage.
#[derive(Default)]
pub struct CharErrorRate<B: Backend> {
// Running numeric state: written by `update`, reset by `clear`, read by `value`.
state: NumericMetricState,
// Token index treated as padding; `None` until `with_pad_token` is called.
pad_token: Option<usize>,
// Zero-sized marker that ties the metric to the backend type `B`.
_b: PhantomData<B>,
}

/// The [character error rate metric](CharErrorRate) input type.
///
/// Both tensors are read as `[batch_size, seq_length]` by the metric.
#[derive(new)]
pub struct CerInput<B: Backend> {
/// The predicted token sequences (as a 2-D tensor of token indices).
pub outputs: Tensor<B, 2, Int>,
/// The target token sequences (as a 2-D tensor of token indices).
pub targets: Tensor<B, 2, Int>,
}

impl<B: Backend> CharErrorRate<B> {
    /// Builds the metric in its default configuration (no padding token set).
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the metric with `index` registered as the padding token.
    pub fn with_pad_token(self, index: usize) -> Self {
        Self {
            pad_token: Some(index),
            ..self
        }
    }
}

/// The [character error rate metric](CerMetric) implementation.
impl<B: Backend> Metric for CharErrorRate<B> {
type Input = CerInput<B>;

fn update(&mut self, input: &CerInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
let outputs = &input.outputs;
let targets = &input.targets;
let [batch_size, seq_len] = targets.dims();

let (output_lengths, target_lengths) = if let Some(pad) = self.pad_token {
// Create boolean masks for non-padding tokens.
let output_mask = outputs.clone().not_equal_elem(pad as i64);
let target_mask = targets.clone().not_equal_elem(pad as i64);

let output_lengths_tensor = output_mask.int().sum_dim(1);
let target_lengths_tensor = target_mask.int().sum_dim(1);

(
output_lengths_tensor.to_data().to_vec::<i64>().unwrap(),
target_lengths_tensor.to_data().to_vec::<i64>().unwrap(),
)
} else {
// If there's no padding, all sequences have the full length.
(vec![seq_len as i64; batch_size], vec![seq_len as i64; batch_size])
};

let outputs_data = outputs.to_data().to_vec::<i64>().unwrap();
let targets_data = targets.to_data().to_vec::<i64>().unwrap();

let total_edit_distance: usize = (0..batch_size)
.into_iter()
.map(|i| {
let start = i * seq_len;

// Get pre-calculated lengths for the current sequence.
let output_len = output_lengths[i] as usize;
let target_len = target_lengths[i] as usize;

let output_seq_slice = &outputs_data[start..(start + output_len)];
let target_seq_slice = &targets_data[start..(start + target_len)];
let output_seq: Vec<i32> = output_seq_slice.iter().map(|&x| x as i32).collect();
let target_seq: Vec<i32> = target_seq_slice.iter().map(|&x| x as i32).collect();

edit_distance(&target_seq, &output_seq)
})
.sum();

let total_target_length = target_lengths.iter().map(|&x| x as f64).sum::<f64>();

let value = if total_target_length > 0.0 {
100.0 * total_edit_distance as f64 / total_target_length
} else {
0.0
};

self.state.update(
value,
batch_size,
FormatOptions::new(self.name()).unit("%").precision(2),
)
Comment on lines +122 to +126
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the state might need to keep track of the errors and total characters (or words for WER)? Otherwise aggregation might be incorrect 🤔 this would require a new state type though

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, that sounds correct. However, the value here includes the errors relative to the total characters, since value=total_edit_distance/total_characters * 100, so why would we need to keep the total characters?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the current batch that's accurate, but when aggregated for an epoch it might be incorrect since this is a numeric state (not all batches have the same composition). Probably out of scope for this PR so no worries 👍

}

fn clear(&mut self) {
self.state.reset();
}

fn name(&self) -> String {
"CER".to_string()
}
}

/// Exposes the [character error rate metric](CharErrorRate) as a numeric metric.
impl<B: Backend> Numeric for CharErrorRate<B> {
    /// Returns the value currently held by the running state.
    fn value(&self) -> f64 {
        let running = &self.state;
        running.value()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Identical predictions and targets must give a CER of 0 %.
    #[test]
    fn test_cer_without_padding() {
        let device = Default::default();
        let mut cer = CharErrorRate::<TestBackend>::new();

        // Two sequences of two tokens each.
        let predictions = Tensor::from_data([[1, 2], [3, 4]], &device);
        let references = Tensor::from_data([[1, 2], [3, 4]], &device);

        cer.update(
            &CerInput::new(predictions, references),
            &MetricMetadata::fake(),
        );

        assert_eq!(0.0, cer.value());
    }

    /// One substitution per sequence: 2 edits over 4 target tokens, i.e. 50 %.
    #[test]
    fn test_cer_without_padding_two_errors() {
        let device = Default::default();
        let mut cer = CharErrorRate::<TestBackend>::new();

        // The second token of each row differs from its target.
        let predictions = Tensor::from_data([[1, 2], [3, 5]], &device);
        let references = Tensor::from_data([[1, 3], [3, 4]], &device);

        cer.update(
            &CerInput::new(predictions, references),
            &MetricMetadata::fake(),
        );

        assert_eq!(50.0, cer.value());
    }

    /// Same data as the two-error case, right-padded with token 9 which must be ignored.
    #[test]
    fn test_cer_with_padding() {
        let device = Default::default();
        let pad = 9_i64;
        let mut cer = CharErrorRate::<TestBackend>::new().with_pad_token(pad as usize);

        // Three columns per row; the last column only holds the pad token.
        let predictions = Tensor::from_data([[1, 2, pad], [3, 5, pad]], &device);
        let references = Tensor::from_data([[1, 3, pad], [3, 4, pad]], &device);

        cer.update(
            &CerInput::new(predictions, references),
            &MetricMetadata::fake(),
        );

        assert_eq!(50.0, cer.value());
    }

    /// After `clear()` the metric no longer reports the previous value.
    #[test]
    fn test_clear_resets_state() {
        let device = Default::default();
        let mut cer = CharErrorRate::<TestBackend>::new();

        // A single sequence with one wrong token.
        let predictions = Tensor::from_data([[1, 2]], &device);
        let references = Tensor::from_data([[1, 3]], &device);

        cer.update(
            &CerInput::new(predictions, references),
            &MetricMetadata::fake(),
        );
        assert!(cer.value() > 0.0);

        cer.clear();
        assert!(cer.value().is_nan());
    }
}
4 changes: 4 additions & 0 deletions crates/burn-train/src/metric/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub use memory_use::*;
mod acc;
mod auroc;
mod base;
mod cer;
mod confusion_stats;
mod fbetascore;
mod hamming;
Expand All @@ -36,10 +37,12 @@ mod loss;
mod precision;
mod recall;
mod top_k_acc;
mod wer;

pub use acc::*;
pub use auroc::*;
pub use base::*;
pub use cer::*;
pub use confusion_stats::ConfusionStatsInput;
pub use fbetascore::*;
pub use hamming::*;
Expand All @@ -49,6 +52,7 @@ pub use loss::*;
pub use precision::*;
pub use recall::*;
pub use top_k_acc::*;
pub use wer::*;

pub(crate) mod classification;
pub(crate) mod processor;
Expand Down
Loading
Loading