@@ -485,106 +485,136 @@ impl<R: Records, R2: Records, T: AsTargets<Elem = bool>, T2: AsTargets<Elem = Pr
485
485
}
486
486
}
487
487
488
- /*
489
488
#[ cfg( test) ]
490
489
mod tests {
491
- use super::{BinaryClassification, ToConfusionMatrix};
492
- use super::{DatasetBase, Pr};
493
- use approx::{abs_diff_eq, AbsDiffEq};
494
- use ndarray::{array, Array1, ArrayBase, ArrayView1, Data, Dimension};
490
+ use super :: { BinaryClassification , ConfusionMatrix , ToConfusionMatrix } ;
491
+ //use crate::dataset::CountedTargets;
492
+ use super :: { Label , /*DatasetBase,*/ Pr } ;
493
+ use approx:: { assert_abs_diff_eq} ;
494
+ use ndarray:: {
495
+ array, Array1 , /* Axis*/ Array2 ,
496
+ /*ArrayBase,*/ ArrayView1 , /*, Data, Dimension*/
497
+ } ;
495
498
use rand:: { distributions:: Uniform , rngs:: SmallRng , Rng , SeedableRng } ;
496
- use std::borrow::Borrow;
497
-
498
- fn assert_eq_slice<
499
- A: std::fmt::Debug + PartialEq + AbsDiffEq,
500
- S: Data<Elem = A>,
501
- D: Dimension,
502
- >(
503
- a: ArrayBase<S, D>,
504
- b: &[A],
505
- ) {
506
- assert_eq_iter(a.iter(), b);
499
+ use std:: collections:: HashMap ;
500
+
501
+ fn get_labels_map < L : Label > ( cm : & ConfusionMatrix < L > ) -> HashMap < L , usize > {
502
+ cm. members
503
+ . iter ( )
504
+ . enumerate ( )
505
+ . map ( |( index, label) | ( label. clone ( ) , index) )
506
+ . collect ( )
507
507
}
508
508
509
- fn assert_eq_iter<'a, A, B>(a: impl IntoIterator<Item = B>, b: impl IntoIterator<Item = &'a A>)
510
- where
511
- A: 'a + std::fmt::Debug + PartialEq + AbsDiffEq,
512
- B: Borrow<A>,
513
- {
514
- let mut a_iter = a.into_iter();
515
- let mut b_iter = b.into_iter();
516
- loop {
517
- match (a_iter.next(), b_iter.next()) {
518
- (None, None) => break,
519
- (Some(a_item), Some(b_item)) => {
520
- abs_diff_eq!(a_item.borrow(), b_item);
521
- }
522
- _ => {
523
- panic!("assert_eq_iters: iterators had different lengths");
524
- }
525
- }
509
+ // confusion matrices use hash sets for the labels to pair so
510
+ // the order of the rows of the matrices is not constant.
511
+ // we can transform the index->member mapping in `cm.members`
512
+ // into a member->index mapping to check each element independently
513
+ fn assert_cm_eq < L : Label > ( cm : & ConfusionMatrix < L > , expected : & Array2 < f32 > , labels : & Array1 < L > ) {
514
+ let map = get_labels_map ( cm) ;
515
+ for ( ( row, column) , value) in expected. indexed_iter ( ) . map ( |( ( r, c) , v) | {
516
+ (
517
+ ( * map. get ( & labels[ r] ) . unwrap ( ) , * map. get ( & labels[ c] ) . unwrap ( ) ) ,
518
+ v,
519
+ )
520
+ } ) {
521
+ let cm_value = * cm. matrix . get ( ( row, column) ) . unwrap ( ) ;
522
+ assert_abs_diff_eq ! ( cm_value, value) ;
523
+ }
524
+ }
525
+
526
+ fn assert_split_eq < L : Label , C : Fn ( & ConfusionMatrix < bool > ) -> f32 > (
527
+ cm : & ConfusionMatrix < L > ,
528
+ eval : C ,
529
+ expected : & Array1 < f32 > ,
530
+ labels : & Array1 < L > ,
531
+ ) {
532
+ let map = get_labels_map ( cm) ;
533
+ let evals = cm
534
+ . split_one_vs_all ( )
535
+ . into_iter ( )
536
+ . map ( |x| eval ( & x) )
537
+ . collect :: < Vec < _ > > ( ) ;
538
+ for ( index, value) in expected
539
+ . indexed_iter ( )
540
+ . map ( |( i, v) | ( * map. get ( & labels[ i] ) . unwrap ( ) , v) )
541
+ {
542
+ let evals_value = * evals. get ( index) . unwrap ( ) ;
543
+ assert_abs_diff_eq ! ( evals_value, value) ;
526
544
}
527
545
}
528
546
529
547
#[ test]
530
548
fn test_confusion_matrix ( ) {
531
- let predicted = ArrayView1::from(&[0, 1, 0, 1, 0, 1]);
532
549
let ground_truth = ArrayView1 :: from ( & [ 1 , 1 , 0 , 1 , 0 , 1 ] ) ;
550
+ let predicted = ArrayView1 :: from ( & [ 0 , 1 , 0 , 1 , 0 , 1 ] ) ;
533
551
534
- let cm = predicted.confusion_matrix(ground_truth);
552
+ let cm = predicted. confusion_matrix ( ground_truth) . unwrap ( ) ;
535
553
536
- assert_eq_slice(cm.matrix, &[2., 1., 0., 3.]);
554
+ let labels = array ! [ 0 , 1 ] ;
555
+ let expected = array ! [ [ 2. , 1. ] , [ 0. , 3. ] ] ;
556
+
557
+ assert_cm_eq ( & cm, & expected, & labels) ;
537
558
}
538
559
539
560
#[ test]
540
561
fn test_cm_metrices ( ) {
541
- let predicted = Array1::from(vec![0, 1, 0, 1, 0, 1]);
542
562
let ground_truth = Array1 :: from ( vec ! [ 1 , 1 , 0 , 1 , 0 , 1 ] ) ;
563
+ let predicted = Array1 :: from ( vec ! [ 0 , 1 , 0 , 1 , 0 , 1 ] ) ;
564
+
565
+ let x = predicted. confusion_matrix ( ground_truth) . unwrap ( ) ;
543
566
544
- let x = predicted.confusion_matrix(ground_truth) ;
567
+ let labels = array ! [ 0 , 1 ] ;
545
568
546
- abs_diff_eq !(x.accuracy(), 5.0 / 6.0_f32);
547
- abs_diff_eq !(
569
+ assert_abs_diff_eq ! ( x. accuracy( ) , 5.0 / 6.0_f32 ) ;
570
+ assert_abs_diff_eq ! (
548
571
x. mcc( ) ,
549
572
( 2. * 3. - 1. * 0. ) / ( 2.0f32 * 3. * 3. * 4. ) . sqrt( ) as f32
550
573
) ;
551
574
552
- assert_eq_iter(
553
- x.split_one_vs_all().into_iter().map(|x| x.precision()),
554
- &[1.0, 3. / 4.],
575
+ assert_split_eq (
576
+ & x,
577
+ |cm| ConfusionMatrix :: precision ( cm) ,
578
+ & array ! [ 1.0 , 3. / 4. ] ,
579
+ & labels,
555
580
) ;
556
- assert_eq_iter(
557
- x.split_one_vs_all().into_iter().map(|x| x.recall()),
558
- &[2.0 / 3.0, 1.0],
581
+ assert_split_eq (
582
+ & x,
583
+ |cm| ConfusionMatrix :: recall ( cm) ,
584
+ & array ! [ 2.0 / 3.0 , 1.0 ] ,
585
+ & labels,
559
586
) ;
560
- assert_eq_iter(
561
- x.split_one_vs_all().into_iter().map(|x| x.f1_score()),
562
- &[4.0 / 5.0, 6.0 / 7.0],
587
+ assert_split_eq (
588
+ & x,
589
+ |cm| ConfusionMatrix :: f1_score ( cm) ,
590
+ & array ! [ 4.0 / 5.0 , 6.0 / 7.0 ] ,
591
+ & labels,
563
592
) ;
564
593
}
565
594
566
- #[test]
595
+ /* #[test]
567
596
fn test_modification() {
568
597
let predicted = array![0, 3, 2, 0, 1, 1, 1, 3, 2, 3];
569
598
570
- let ground_truth =
571
- DatasetBase::new((), array![0, 2, 3, 0, 1, 2, 1, 2, 3, 2]) .with_labels(&[0, 1, 2 ]);
599
+ let ground_truth : DatasetBase<Array2<f64>, CountedTargets<usize, Array2<usize>>> =
600
+ DatasetBase::new(array![[0.,0.]], array![0, 2, 3, 0, 1, 2, 1, 2, 3, 2].insert_axis(Axis(1))) .with_labels(&[&[0],&[1],&[2]] ]);
572
601
573
602
// exclude class 3 from evaluation
574
- let cm = predicted.confusion_matrix(&ground_truth);
603
+ let cm = predicted.confusion_matrix(&ground_truth).unwrap();
604
+ println!("cm {:?}",cm);
575
605
576
606
assert_eq_slice(cm.matrix, &[2., 0., 0., 0., 2., 1., 0., 0., 0.]);
577
607
578
608
// weight errors in class 2 more severe and exclude class 1
579
609
let ground_truth = ground_truth
580
- .with_weights(vec ![1., 2., 1., 1., 1., 2., 1., 2., 1., 2.])
581
- .with_labels(&[0, 2, 3 ]);
610
+ .with_weights(array ![1., 2., 1., 1., 1., 2., 1., 2., 1., 2.])
611
+ .with_labels(&[&[0], &[2], &[3] ]);
582
612
583
- let cm = predicted.confusion_matrix(&ground_truth);
613
+ let cm = predicted.confusion_matrix(&ground_truth).unwrap() ;
584
614
585
615
// the false-positive error for label=2 is twice severe here
586
616
assert_eq_slice(cm.matrix, &[2., 0., 0., 0., 0., 4., 0., 3., 0.]);
587
- }
617
+ }*/
588
618
589
619
#[ test]
590
620
fn test_roc_curve ( ) {
@@ -602,7 +632,7 @@ mod tests {
602
632
( 1. , 1. ) ,
603
633
] ;
604
634
605
- let roc = predicted.roc(&groundtruth);
635
+ let roc = predicted. roc ( & groundtruth) . unwrap ( ) ;
606
636
assert_eq ! ( roc. get_curve( ) , result) ;
607
637
}
608
638
@@ -619,32 +649,38 @@ mod tests {
619
649
. collect :: < Vec < _ > > ( ) ;
620
650
621
651
// ROC Area-Under-Curve should be approximately 0.5
622
- let roc = predicted.roc(&ground_truth);
652
+ let roc = predicted. roc ( & ground_truth) . unwrap ( ) ;
623
653
assert ! ( ( roc. area_under_curve( ) - 0.5 ) < 0.04 ) ;
624
654
}
625
655
626
656
#[ test]
627
657
fn split_one_vs_all ( ) {
628
- let predicted = array![0, 3, 2, 0, 1, 1, 1, 3, 2, 3];
629
658
let ground_truth = array ! [ 0 , 2 , 3 , 0 , 1 , 2 , 1 , 2 , 3 , 2 ] ;
659
+ let predicted = array ! [ 0 , 3 , 2 , 0 , 1 , 1 , 1 , 3 , 2 , 3 ] ;
630
660
631
661
// create a confusion matrix
632
- let cm = predicted.confusion_matrix(ground_truth);
662
+ let cm = predicted. confusion_matrix ( ground_truth) . unwrap ( ) ;
663
+
664
+ let labels = array ! [ 0 , 1 , 2 , 3 ] ;
665
+ let bin_labels = array ! [ true , false ] ;
666
+ let map = get_labels_map ( & cm) ;
633
667
634
668
// split the four-class confusion matrix into 4 binary confusion matrices
635
669
let n_cm = cm. split_one_vs_all ( ) ;
636
670
637
- let result: &[&[f32]] = &[
638
- &[ 2., 0., 0., 8.], // no misclassification for label=0
639
- &[ 2., 1., 0., 7.], // one false-positive for label=1
640
- &[ 0., 2., 4., 4.], // two false-positive and four false-negative for label=2
641
- &[ 0., 3., 2., 5.], // three false-positive and two false-negative for label=3
671
+ let result = & [
672
+ array ! [ [ 2. , 0. ] , [ 0. , 8. ] ] , // no misclassification for label=0
673
+ array ! [ [ 2. , 1. ] , [ 0. , 7. ] ] , // one false-positive for label=1
674
+ array ! [ [ 0. , 2. ] , [ 4. , 4. ] ] , // two false-positive and four false-negative for label=2
675
+ array ! [ [ 0. , 3. ] , [ 2. , 5. ] ] , // three false-positive and two false-negative for label=3
642
676
] ;
643
677
644
- // compare to result
645
- n_cm.into_iter()
646
- .zip(result.iter())
647
- .for_each(|(x, r)| assert_eq_slice(x.matrix, r))
678
+ for ( r, x) in result
679
+ . iter ( )
680
+ . zip ( labels. iter ( ) )
681
+ . map ( |( r, l) | ( r, n_cm. get ( * map. get ( l) . unwrap ( ) ) . unwrap ( ) ) )
682
+ {
683
+ assert_cm_eq ( x, r, & bin_labels) ;
684
+ }
648
685
}
649
686
}
650
- */
0 commit comments