@@ -5,137 +5,21 @@ extern crate num_traits;
55
66use ndarray:: * ;
77use ndarray_linalg:: * ;
8- use num_traits:: { One , Zero } ;
9-
10- /// Returns the matrix with the specified `row` and `col` removed.
11- fn matrix_minor < A , S > ( a : ArrayBase < S , Ix2 > , ( row, col) : ( usize , usize ) ) -> Array2 < A >
12- where
13- A : Scalar ,
14- S : Data < Elem = A > ,
15- {
16- let mut select_rows = ( 0 ..a. rows ( ) ) . collect :: < Vec < _ > > ( ) ;
17- select_rows. remove ( row) ;
18- let mut select_cols = ( 0 ..a. cols ( ) ) . collect :: < Vec < _ > > ( ) ;
19- select_cols. remove ( col) ;
20- a. select ( Axis ( 0 ) , & select_rows) . select (
21- Axis ( 1 ) ,
22- & select_cols,
23- )
24- }
25-
26- /// Computes the determinant of matrix `a`.
27- ///
28- /// Note: This implementation is written to be clearly correct so that it's
29- /// useful for verification, but it's very inefficient.
30- fn det_naive < A , S > ( a : ArrayBase < S , Ix2 > ) -> A
31- where
32- A : Scalar ,
33- S : Data < Elem = A > ,
34- {
35- assert_eq ! ( a. rows( ) , a. cols( ) ) ;
36- match a. cols ( ) {
37- 0 => A :: one ( ) ,
38- 1 => a[ ( 0 , 0 ) ] ,
39- cols => {
40- ( 0 ..cols)
41- . map ( |col| {
42- let sign = if col % 2 == 0 { A :: one ( ) } else { -A :: one ( ) } ;
43- sign * a[ ( 0 , col) ] * det_naive ( matrix_minor ( a. view ( ) , ( 0 , col) ) )
44- } )
45- . fold ( A :: zero ( ) , |sum, subdet| sum + subdet)
46- }
47- }
48- }
49-
50- #[ test]
51- fn det_empty ( ) {
52- macro_rules! det_empty {
53- ( $elem: ty) => {
54- let a: Array2 <$elem> = Array2 :: zeros( ( 0 , 0 ) ) ;
55- assert_eq!( a. factorize( ) . unwrap( ) . det( ) . unwrap( ) , One :: one( ) ) ;
56- assert_eq!( a. factorize( ) . unwrap( ) . det_into( ) . unwrap( ) , One :: one( ) ) ;
57- assert_eq!( a. det( ) . unwrap( ) , One :: one( ) ) ;
58- assert_eq!( a. det_into( ) . unwrap( ) , One :: one( ) ) ;
59- }
60- }
61- det_empty ! ( f64 ) ;
62- det_empty ! ( f32 ) ;
63- det_empty ! ( c64) ;
64- det_empty ! ( c32) ;
65- }
66-
67- #[ test]
68- fn det_zero ( ) {
69- macro_rules! det_zero {
70- ( $elem: ty) => {
71- let a: Array2 <$elem> = Array2 :: zeros( ( 1 , 1 ) ) ;
72- assert_eq!( a. det( ) . unwrap( ) , Zero :: zero( ) ) ;
73- assert_eq!( a. det_into( ) . unwrap( ) , Zero :: zero( ) ) ;
74- }
75- }
76- det_zero ! ( f64 ) ;
77- det_zero ! ( f32 ) ;
78- det_zero ! ( c64) ;
79- det_zero ! ( c32) ;
80- }
81-
82- #[ test]
83- fn det_zero_nonsquare ( ) {
84- macro_rules! det_zero_nonsquare {
85- ( $elem: ty, $shape: expr) => {
86- let a: Array2 <$elem> = Array2 :: zeros( $shape) ;
87- assert!( a. det( ) . is_err( ) ) ;
88- assert!( a. det_into( ) . is_err( ) ) ;
89- }
90- }
91- for & shape in & [ ( 1 , 2 ) . into_shape ( ) , ( 1 , 2 ) . f ( ) ] {
92- det_zero_nonsquare ! ( f64 , shape) ;
93- det_zero_nonsquare ! ( f32 , shape) ;
94- det_zero_nonsquare ! ( c64, shape) ;
95- det_zero_nonsquare ! ( c32, shape) ;
96- }
97- }
988
999#[ test]
100- fn det ( ) {
101- macro_rules! det {
102- ( $elem: ty, $shape: expr, $rtol: expr) => {
103- let a: Array2 <$elem> = random( $shape) ;
104- println!( "a = \n {:?}" , a) ;
105- let det = det_naive( a. view( ) ) ;
106- assert_rclose!( a. factorize( ) . unwrap( ) . det( ) . unwrap( ) , det, $rtol) ;
107- assert_rclose!( a. factorize( ) . unwrap( ) . det_into( ) . unwrap( ) , det, $rtol) ;
108- assert_rclose!( a. det( ) . unwrap( ) , det, $rtol) ;
109- assert_rclose!( a. det_into( ) . unwrap( ) , det, $rtol) ;
110- }
111- }
112- for rows in 1 ..5 {
113- for & shape in & [ ( rows, rows) . into_shape ( ) , ( rows, rows) . f ( ) ] {
114- det ! ( f64 , shape, 1e-9 ) ;
115- det ! ( f32 , shape, 1e-4 ) ;
116- det ! ( c64, shape, 1e-9 ) ;
117- det ! ( c32, shape, 1e-4 ) ;
118- }
119- }
10+ fn solve_random ( ) {
11+ let a: Array2 < f64 > = random ( ( 3 , 3 ) ) ;
12+ let x: Array1 < f64 > = random ( 3 ) ;
13+ let b = a. dot ( & x) ;
14+ let y = a. solve_into ( b) . unwrap ( ) ;
15+ assert_close_l2 ! ( & x, & y, 1e-7 ) ;
12016}
12117
12218#[ test]
123- fn det_nonsquare ( ) {
124- macro_rules! det_nonsquare {
125- ( $elem: ty, $shape: expr) => {
126- let a: Array2 <$elem> = random( $shape) ;
127- assert!( a. factorize( ) . unwrap( ) . det( ) . is_err( ) ) ;
128- assert!( a. factorize( ) . unwrap( ) . det_into( ) . is_err( ) ) ;
129- assert!( a. det( ) . is_err( ) ) ;
130- assert!( a. det_into( ) . is_err( ) ) ;
131- }
132- }
133- for & dims in & [ ( 1 , 0 ) , ( 1 , 2 ) , ( 2 , 1 ) , ( 2 , 3 ) ] {
134- for & shape in & [ dims. clone ( ) . into_shape ( ) , dims. clone ( ) . f ( ) ] {
135- det_nonsquare ! ( f64 , shape) ;
136- det_nonsquare ! ( f32 , shape) ;
137- det_nonsquare ! ( c64, shape) ;
138- det_nonsquare ! ( c32, shape) ;
139- }
140- }
19+ fn solve_random_t ( ) {
20+ let a: Array2 < f64 > = random ( ( 3 , 3 ) . f ( ) ) ;
21+ let x: Array1 < f64 > = random ( 3 ) ;
22+ let b = a. dot ( & x) ;
23+ let y = a. solve_into ( b) . unwrap ( ) ;
24+ assert_close_l2 ! ( & x, & y, 1e-7 ) ;
14125}
0 commit comments