Skip to content

Commit d7b89aa

Browse files
committed
reinstantiate all but one cm test
1 parent bd62a36 commit d7b89aa

File tree

2 files changed

+108
-71
lines changed

2 files changed

+108
-71
lines changed

src/dataset/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,7 @@ mod tests {
720720
let acc = dataset
721721
.cross_validate_mt(5, &params, |_pred, _truth| Ok(array![5., 6.]))
722722
.unwrap();
723+
assert_eq!(acc.dim(), (params.len(), dataset.ntargets()));
723724
assert_eq!(acc, array![[5., 6.], [5., 6.]])
724725
}
725726
#[test]

src/metrics_classification.rs

+107-71
Original file line numberDiff line numberDiff line change
@@ -485,106 +485,136 @@ impl<R: Records, R2: Records, T: AsTargets<Elem = bool>, T2: AsTargets<Elem = Pr
485485
}
486486
}
487487

488-
/*
489488
#[cfg(test)]
490489
mod tests {
491-
use super::{BinaryClassification, ToConfusionMatrix};
492-
use super::{DatasetBase, Pr};
493-
use approx::{abs_diff_eq, AbsDiffEq};
494-
use ndarray::{array, Array1, ArrayBase, ArrayView1, Data, Dimension};
490+
use super::{BinaryClassification, ConfusionMatrix, ToConfusionMatrix};
491+
//use crate::dataset::CountedTargets;
492+
use super::{Label, /*DatasetBase,*/ Pr};
493+
use approx::{assert_abs_diff_eq};
494+
use ndarray::{
495+
array, Array1, /* Axis*/ Array2,
496+
/*ArrayBase,*/ ArrayView1, /*, Data, Dimension*/
497+
};
495498
use rand::{distributions::Uniform, rngs::SmallRng, Rng, SeedableRng};
496-
use std::borrow::Borrow;
497-
498-
fn assert_eq_slice<
499-
A: std::fmt::Debug + PartialEq + AbsDiffEq,
500-
S: Data<Elem = A>,
501-
D: Dimension,
502-
>(
503-
a: ArrayBase<S, D>,
504-
b: &[A],
505-
) {
506-
assert_eq_iter(a.iter(), b);
499+
use std::collections::HashMap;
500+
501+
fn get_labels_map<L: Label>(cm: &ConfusionMatrix<L>) -> HashMap<L, usize> {
502+
cm.members
503+
.iter()
504+
.enumerate()
505+
.map(|(index, label)| (label.clone(), index))
506+
.collect()
507507
}
508508

509-
fn assert_eq_iter<'a, A, B>(a: impl IntoIterator<Item = B>, b: impl IntoIterator<Item = &'a A>)
510-
where
511-
A: 'a + std::fmt::Debug + PartialEq + AbsDiffEq,
512-
B: Borrow<A>,
513-
{
514-
let mut a_iter = a.into_iter();
515-
let mut b_iter = b.into_iter();
516-
loop {
517-
match (a_iter.next(), b_iter.next()) {
518-
(None, None) => break,
519-
(Some(a_item), Some(b_item)) => {
520-
abs_diff_eq!(a_item.borrow(), b_item);
521-
}
522-
_ => {
523-
panic!("assert_eq_iters: iterators had different lengths");
524-
}
525-
}
509+
// confusion matrices use hash sets for the labels to pair, so
510+
// the order of the rows of the matrices is not constant.
511+
// We can transform the index->member mapping in `cm.members`
512+
// into a member->index mapping to check each element independently.
513+
fn assert_cm_eq<L: Label>(cm: &ConfusionMatrix<L>, expected: &Array2<f32>, labels: &Array1<L>) {
514+
let map = get_labels_map(cm);
515+
for ((row, column), value) in expected.indexed_iter().map(|((r, c), v)| {
516+
(
517+
(*map.get(&labels[r]).unwrap(), *map.get(&labels[c]).unwrap()),
518+
v,
519+
)
520+
}) {
521+
let cm_value = *cm.matrix.get((row, column)).unwrap();
522+
assert_abs_diff_eq!(cm_value, value);
523+
}
524+
}
525+
526+
fn assert_split_eq<L: Label, C: Fn(&ConfusionMatrix<bool>) -> f32>(
527+
cm: &ConfusionMatrix<L>,
528+
eval: C,
529+
expected: &Array1<f32>,
530+
labels: &Array1<L>,
531+
) {
532+
let map = get_labels_map(cm);
533+
let evals = cm
534+
.split_one_vs_all()
535+
.into_iter()
536+
.map(|x| eval(&x))
537+
.collect::<Vec<_>>();
538+
for (index, value) in expected
539+
.indexed_iter()
540+
.map(|(i, v)| (*map.get(&labels[i]).unwrap(), v))
541+
{
542+
let evals_value = *evals.get(index).unwrap();
543+
assert_abs_diff_eq!(evals_value, value);
526544
}
527545
}
528546

529547
#[test]
530548
fn test_confusion_matrix() {
531-
let predicted = ArrayView1::from(&[0, 1, 0, 1, 0, 1]);
532549
let ground_truth = ArrayView1::from(&[1, 1, 0, 1, 0, 1]);
550+
let predicted = ArrayView1::from(&[0, 1, 0, 1, 0, 1]);
533551

534-
let cm = predicted.confusion_matrix(ground_truth);
552+
let cm = predicted.confusion_matrix(ground_truth).unwrap();
535553

536-
assert_eq_slice(cm.matrix, &[2., 1., 0., 3.]);
554+
let labels = array![0, 1];
555+
let expected = array![[2., 1.], [0., 3.]];
556+
557+
assert_cm_eq(&cm, &expected, &labels);
537558
}
538559

539560
#[test]
540561
fn test_cm_metrices() {
541-
let predicted = Array1::from(vec![0, 1, 0, 1, 0, 1]);
542562
let ground_truth = Array1::from(vec![1, 1, 0, 1, 0, 1]);
563+
let predicted = Array1::from(vec![0, 1, 0, 1, 0, 1]);
564+
565+
let x = predicted.confusion_matrix(ground_truth).unwrap();
543566

544-
let x = predicted.confusion_matrix(ground_truth);
567+
let labels = array![0, 1];
545568

546-
abs_diff_eq!(x.accuracy(), 5.0 / 6.0_f32);
547-
abs_diff_eq!(
569+
assert_abs_diff_eq!(x.accuracy(), 5.0 / 6.0_f32);
570+
assert_abs_diff_eq!(
548571
x.mcc(),
549572
(2. * 3. - 1. * 0.) / (2.0f32 * 3. * 3. * 4.).sqrt() as f32
550573
);
551574

552-
assert_eq_iter(
553-
x.split_one_vs_all().into_iter().map(|x| x.precision()),
554-
&[1.0, 3. / 4.],
575+
assert_split_eq(
576+
&x,
577+
|cm| ConfusionMatrix::precision(cm),
578+
&array![1.0, 3. / 4.],
579+
&labels,
555580
);
556-
assert_eq_iter(
557-
x.split_one_vs_all().into_iter().map(|x| x.recall()),
558-
&[2.0 / 3.0, 1.0],
581+
assert_split_eq(
582+
&x,
583+
|cm| ConfusionMatrix::recall(cm),
584+
&array![2.0 / 3.0, 1.0],
585+
&labels,
559586
);
560-
assert_eq_iter(
561-
x.split_one_vs_all().into_iter().map(|x| x.f1_score()),
562-
&[4.0 / 5.0, 6.0 / 7.0],
587+
assert_split_eq(
588+
&x,
589+
|cm| ConfusionMatrix::f1_score(cm),
590+
&array![4.0 / 5.0, 6.0 / 7.0],
591+
&labels,
563592
);
564593
}
565594

566-
#[test]
595+
/*#[test]
567596
fn test_modification() {
568597
let predicted = array![0, 3, 2, 0, 1, 1, 1, 3, 2, 3];
569598
570-
let ground_truth =
571-
DatasetBase::new((), array![0, 2, 3, 0, 1, 2, 1, 2, 3, 2]).with_labels(&[0, 1, 2]);
599+
let ground_truth : DatasetBase<Array2<f64>, CountedTargets<usize, Array2<usize>>> =
600+
DatasetBase::new(array![[0.,0.]], array![0, 2, 3, 0, 1, 2, 1, 2, 3, 2].insert_axis(Axis(1))).with_labels(&[&[0],&[1],&[2]]]);
572601
573602
// exclude class 3 from evaluation
574-
let cm = predicted.confusion_matrix(&ground_truth);
603+
let cm = predicted.confusion_matrix(&ground_truth).unwrap();
604+
println!("cm {:?}",cm);
575605
576606
assert_eq_slice(cm.matrix, &[2., 0., 0., 0., 2., 1., 0., 0., 0.]);
577607
578608
// weight errors in class 2 more severely and exclude class 1
579609
let ground_truth = ground_truth
580-
.with_weights(vec![1., 2., 1., 1., 1., 2., 1., 2., 1., 2.])
581-
.with_labels(&[0, 2, 3]);
610+
.with_weights(array![1., 2., 1., 1., 1., 2., 1., 2., 1., 2.])
611+
.with_labels(&[&[0], &[2], &[3]]);
582612
583-
let cm = predicted.confusion_matrix(&ground_truth);
613+
let cm = predicted.confusion_matrix(&ground_truth).unwrap();
584614
585615
// the false-positive error for label=2 is twice as severe here
586616
assert_eq_slice(cm.matrix, &[2., 0., 0., 0., 0., 4., 0., 3., 0.]);
587-
}
617+
}*/
588618

589619
#[test]
590620
fn test_roc_curve() {
@@ -602,7 +632,7 @@ mod tests {
602632
(1., 1.),
603633
];
604634

605-
let roc = predicted.roc(&groundtruth);
635+
let roc = predicted.roc(&groundtruth).unwrap();
606636
assert_eq!(roc.get_curve(), result);
607637
}
608638

@@ -619,32 +649,38 @@ mod tests {
619649
.collect::<Vec<_>>();
620650

621651
// ROC Area-Under-Curve should be approximately 0.5
622-
let roc = predicted.roc(&ground_truth);
652+
let roc = predicted.roc(&ground_truth).unwrap();
623653
assert!((roc.area_under_curve() - 0.5) < 0.04);
624654
}
625655

626656
#[test]
627657
fn split_one_vs_all() {
628-
let predicted = array![0, 3, 2, 0, 1, 1, 1, 3, 2, 3];
629658
let ground_truth = array![0, 2, 3, 0, 1, 2, 1, 2, 3, 2];
659+
let predicted = array![0, 3, 2, 0, 1, 1, 1, 3, 2, 3];
630660

631661
// create a confusion matrix
632-
let cm = predicted.confusion_matrix(ground_truth);
662+
let cm = predicted.confusion_matrix(ground_truth).unwrap();
663+
664+
let labels = array![0, 1, 2, 3];
665+
let bin_labels = array![true, false];
666+
let map = get_labels_map(&cm);
633667

634668
// split the four-class confusion matrix into 4 binary confusion matrices
635669
let n_cm = cm.split_one_vs_all();
636670

637-
let result: &[&[f32]] = &[
638-
&[2., 0., 0., 8.], // no misclassification for label=0
639-
&[2., 1., 0., 7.], // one false-positive for label=1
640-
&[0., 2., 4., 4.], // two false-positive and four false-negative for label=2
641-
&[0., 3., 2., 5.], // three false-positive and two false-negative for label=3
671+
let result = &[
672+
array![[2., 0.], [0., 8.]], // no misclassification for label=0
673+
array![[2., 1.], [0., 7.]], // one false-positive for label=1
674+
array![[0., 2.], [4., 4.]], // two false-positive and four false-negative for label=2
675+
array![[0., 3.], [2., 5.]], // three false-positive and two false-negative for label=3
642676
];
643677

644-
// compare to result
645-
n_cm.into_iter()
646-
.zip(result.iter())
647-
.for_each(|(x, r)| assert_eq_slice(x.matrix, r))
678+
for (r, x) in result
679+
.iter()
680+
.zip(labels.iter())
681+
.map(|(r, l)| (r, n_cm.get(*map.get(l).unwrap()).unwrap()))
682+
{
683+
assert_cm_eq(x, r, &bin_labels);
684+
}
648685
}
649686
}
650-
*/

0 commit comments

Comments
 (0)