5
5
use ndarray:: linalg:: general_mat_mul;
6
6
use ndarray:: linalg:: kron;
7
7
use ndarray:: prelude:: * ;
8
+ #[ cfg( feature = "approx" ) ]
9
+ use ndarray:: Order ;
8
10
use ndarray:: { rcarr1, rcarr2} ;
9
11
use ndarray:: { Data , LinalgScalar } ;
10
12
use ndarray:: { Ix , Ixs } ;
11
- use num_traits :: Zero ;
13
+ use ndarray_gen :: array_builder :: ArrayBuilder ;
12
14
13
15
use approx:: assert_abs_diff_eq;
14
16
use defmac:: defmac;
17
+ use num_traits:: Num ;
18
+ use num_traits:: Zero ;
15
19
16
20
fn test_oper ( op : & str , a : & [ f32 ] , b : & [ f32 ] , c : & [ f32 ] )
17
21
{
@@ -271,31 +275,20 @@ fn product()
271
275
}
272
276
}
273
277
274
- fn range_mat ( m : Ix , n : Ix ) -> Array2 < f32 >
278
+ fn range_mat < A : Num + Copy > ( m : Ix , n : Ix ) -> Array2 < A >
275
279
{
276
- Array :: linspace ( 0. , ( m * n) as f32 - 1. , m * n)
277
- . into_shape_with_order ( ( m, n) )
278
- . unwrap ( )
279
- }
280
-
281
- fn range_mat64 ( m : Ix , n : Ix ) -> Array2 < f64 >
282
- {
283
- Array :: linspace ( 0. , ( m * n) as f64 - 1. , m * n)
284
- . into_shape_with_order ( ( m, n) )
285
- . unwrap ( )
280
+ ArrayBuilder :: new ( ( m, n) ) . build ( )
286
281
}
287
282
288
283
#[ cfg( feature = "approx" ) ]
289
284
fn range1_mat64 ( m : Ix ) -> Array1 < f64 >
290
285
{
291
- Array :: linspace ( 0. , m as f64 - 1. , m )
286
+ ArrayBuilder :: new ( m ) . build ( )
292
287
}
293
288
294
289
fn range_i32 ( m : Ix , n : Ix ) -> Array2 < i32 >
295
290
{
296
- Array :: from_iter ( 0 ..( m * n) as i32 )
297
- . into_shape_with_order ( ( m, n) )
298
- . unwrap ( )
291
+ ArrayBuilder :: new ( ( m, n) ) . build ( )
299
292
}
300
293
301
294
// simple, slow, correct (hopefully) mat mul
@@ -332,8 +325,8 @@ where
332
325
fn mat_mul ( )
333
326
{
334
327
let ( m, n, k) = ( 8 , 8 , 8 ) ;
335
- let a = range_mat ( m, n) ;
336
- let b = range_mat ( n, k) ;
328
+ let a = range_mat :: < f32 > ( m, n) ;
329
+ let b = range_mat :: < f32 > ( n, k) ;
337
330
let mut b = b / 4. ;
338
331
{
339
332
let mut c = b. column_mut ( 0 ) ;
@@ -351,8 +344,8 @@ fn mat_mul()
351
344
assert_eq ! ( ab, af. dot( & bf) ) ;
352
345
353
346
let ( m, n, k) = ( 10 , 5 , 11 ) ;
354
- let a = range_mat ( m, n) ;
355
- let b = range_mat ( n, k) ;
347
+ let a = range_mat :: < f32 > ( m, n) ;
348
+ let b = range_mat :: < f32 > ( n, k) ;
356
349
let mut b = b / 4. ;
357
350
{
358
351
let mut c = b. column_mut ( 0 ) ;
@@ -370,8 +363,8 @@ fn mat_mul()
370
363
assert_eq ! ( ab, af. dot( & bf) ) ;
371
364
372
365
let ( m, n, k) = ( 10 , 8 , 1 ) ;
373
- let a = range_mat ( m, n) ;
374
- let b = range_mat ( n, k) ;
366
+ let a = range_mat :: < f32 > ( m, n) ;
367
+ let b = range_mat :: < f32 > ( n, k) ;
375
368
let mut b = b / 4. ;
376
369
{
377
370
let mut c = b. column_mut ( 0 ) ;
@@ -395,8 +388,8 @@ fn mat_mul()
395
388
fn mat_mul_order ( )
396
389
{
397
390
let ( m, n, k) = ( 8 , 8 , 8 ) ;
398
- let a = range_mat ( m, n) ;
399
- let b = range_mat ( n, k) ;
391
+ let a = range_mat :: < f32 > ( m, n) ;
392
+ let b = range_mat :: < f32 > ( n, k) ;
400
393
let mut af = Array :: zeros ( a. dim ( ) . f ( ) ) ;
401
394
let mut bf = Array :: zeros ( b. dim ( ) . f ( ) ) ;
402
395
af. assign ( & a) ;
@@ -415,8 +408,8 @@ fn mat_mul_order()
415
408
fn mat_mul_shape_mismatch ( )
416
409
{
417
410
let ( m, k, k2, n) = ( 8 , 8 , 9 , 8 ) ;
418
- let a = range_mat ( m, k) ;
419
- let b = range_mat ( k2, n) ;
411
+ let a = range_mat :: < f32 > ( m, k) ;
412
+ let b = range_mat :: < f32 > ( k2, n) ;
420
413
a. dot ( & b) ;
421
414
}
422
415
@@ -426,9 +419,9 @@ fn mat_mul_shape_mismatch()
426
419
fn mat_mul_shape_mismatch_2 ( )
427
420
{
428
421
let ( m, k, k2, n) = ( 8 , 8 , 8 , 8 ) ;
429
- let a = range_mat ( m, k) ;
430
- let b = range_mat ( k2, n) ;
431
- let mut c = range_mat ( m, n + 1 ) ;
422
+ let a = range_mat :: < f32 > ( m, k) ;
423
+ let b = range_mat :: < f32 > ( k2, n) ;
424
+ let mut c = range_mat :: < f32 > ( m, n + 1 ) ;
432
425
general_mat_mul ( 1. , & a, & b, 1. , & mut c) ;
433
426
}
434
427
@@ -438,7 +431,7 @@ fn mat_mul_shape_mismatch_2()
438
431
fn mat_mul_broadcast ( )
439
432
{
440
433
let ( m, n, k) = ( 16 , 16 , 16 ) ;
441
- let a = range_mat ( m, n) ;
434
+ let a = range_mat :: < f32 > ( m, n) ;
442
435
let x1 = 1. ;
443
436
let x = Array :: from ( vec ! [ x1] ) ;
444
437
let b0 = x. broadcast ( ( n, k) ) . unwrap ( ) ;
@@ -458,8 +451,8 @@ fn mat_mul_broadcast()
458
451
fn mat_mul_rev ( )
459
452
{
460
453
let ( m, n, k) = ( 16 , 16 , 16 ) ;
461
- let a = range_mat ( m, n) ;
462
- let b = range_mat ( n, k) ;
454
+ let a = range_mat :: < f32 > ( m, n) ;
455
+ let b = range_mat :: < f32 > ( n, k) ;
463
456
let mut rev = Array :: zeros ( b. dim ( ) ) ;
464
457
let mut rev = rev. slice_mut ( s ! [ ..; -1 , ..] ) ;
465
458
rev. assign ( & b) ;
@@ -488,8 +481,8 @@ fn mat_mut_zero_len()
488
481
}
489
482
}
490
483
} ) ;
491
- mat_mul_zero_len ! ( range_mat) ;
492
- mat_mul_zero_len ! ( range_mat64 ) ;
484
+ mat_mul_zero_len ! ( range_mat:: < f32 > ) ;
485
+ mat_mul_zero_len ! ( range_mat :: < f64 > ) ;
493
486
mat_mul_zero_len ! ( range_i32) ;
494
487
}
495
488
@@ -528,9 +521,9 @@ fn scaled_add_2()
528
521
for & s1 in & [ 1 , 2 , -1 , -2 ] {
529
522
for & s2 in & [ 1 , 2 , -1 , -2 ] {
530
523
for & ( m, k, n, q) in & sizes {
531
- let mut a = range_mat64 ( m, k) ;
524
+ let mut a = range_mat :: < f64 > ( m, k) ;
532
525
let mut answer = a. clone ( ) ;
533
- let c = range_mat64 ( n, q) ;
526
+ let c = range_mat :: < f64 > ( n, q) ;
534
527
535
528
{
536
529
let mut av = a. slice_mut ( s ! [ ..; s1, ..; s2] ) ;
@@ -570,7 +563,7 @@ fn scaled_add_3()
570
563
for & s1 in & [ 1 , 2 , -1 , -2 ] {
571
564
for & s2 in & [ 1 , 2 , -1 , -2 ] {
572
565
for & ( m, k, n, q) in & sizes {
573
- let mut a = range_mat64 ( m, k) ;
566
+ let mut a = range_mat :: < f64 > ( m, k) ;
574
567
let mut answer = a. clone ( ) ;
575
568
let cdim = if n == 1 { vec ! [ q] } else { vec ! [ n, q] } ;
576
569
let cslice: Vec < SliceInfoElem > = if n == 1 {
@@ -582,7 +575,7 @@ fn scaled_add_3()
582
575
]
583
576
} ;
584
577
585
- let c = range_mat64 ( n, q) . into_shape_with_order ( cdim) . unwrap ( ) ;
578
+ let c = range_mat :: < f64 > ( n, q) . into_shape_with_order ( cdim) . unwrap ( ) ;
586
579
587
580
{
588
581
let mut av = a. slice_mut ( s ! [ ..; s1, ..; s2] ) ;
@@ -619,9 +612,9 @@ fn gen_mat_mul()
619
612
for & s1 in & [ 1 , 2 , -1 , -2 ] {
620
613
for & s2 in & [ 1 , 2 , -1 , -2 ] {
621
614
for & ( m, k, n) in & sizes {
622
- let a = range_mat64 ( m, k) ;
623
- let b = range_mat64 ( k, n) ;
624
- let mut c = range_mat64 ( m, n) ;
615
+ let a = range_mat :: < f64 > ( m, k) ;
616
+ let b = range_mat :: < f64 > ( k, n) ;
617
+ let mut c = range_mat :: < f64 > ( m, n) ;
625
618
let mut answer = c. clone ( ) ;
626
619
627
620
{
@@ -645,11 +638,11 @@ fn gen_mat_mul()
645
638
#[ test]
646
639
fn gemm_64_1_f ( )
647
640
{
648
- let a = range_mat64 ( 64 , 64 ) . reversed_axes ( ) ;
641
+ let a = range_mat :: < f64 > ( 64 , 64 ) . reversed_axes ( ) ;
649
642
let ( m, n) = a. dim ( ) ;
650
643
// m x n times n x 1 == m x 1
651
- let x = range_mat64 ( n, 1 ) ;
652
- let mut y = range_mat64 ( m, 1 ) ;
644
+ let x = range_mat :: < f64 > ( n, 1 ) ;
645
+ let mut y = range_mat :: < f64 > ( m, 1 ) ;
653
646
let answer = reference_mat_mul ( & a, & x) + & y;
654
647
general_mat_mul ( 1.0 , & a, & x, 1.0 , & mut y) ;
655
648
approx:: assert_relative_eq!( y, answer, epsilon = 1e-12 , max_relative = 1e-7 ) ;
@@ -728,11 +721,8 @@ fn gen_mat_vec_mul()
728
721
for & s1 in & [ 1 , 2 , -1 , -2 ] {
729
722
for & s2 in & [ 1 , 2 , -1 , -2 ] {
730
723
for & ( m, k) in & sizes {
731
- for & rev in & [ false , true ] {
732
- let mut a = range_mat64 ( m, k) ;
733
- if rev {
734
- a = a. reversed_axes ( ) ;
735
- }
724
+ for order in [ Order :: C , Order :: F ] {
725
+ let a = ArrayBuilder :: new ( ( m, k) ) . memory_order ( order) . build ( ) ;
736
726
let ( m, k) = a. dim ( ) ;
737
727
let b = range1_mat64 ( k) ;
738
728
let mut c = range1_mat64 ( m) ;
@@ -794,11 +784,8 @@ fn vec_mat_mul()
794
784
for & s1 in & [ 1 , 2 , -1 , -2 ] {
795
785
for & s2 in & [ 1 , 2 , -1 , -2 ] {
796
786
for & ( m, n) in & sizes {
797
- for & rev in & [ false , true ] {
798
- let mut b = range_mat64 ( m, n) ;
799
- if rev {
800
- b = b. reversed_axes ( ) ;
801
- }
787
+ for order in [ Order :: C , Order :: F ] {
788
+ let b = ArrayBuilder :: new ( ( m, n) ) . memory_order ( order) . build ( ) ;
802
789
let ( m, n) = b. dim ( ) ;
803
790
let a = range1_mat64 ( m) ;
804
791
let mut c = range1_mat64 ( n) ;
0 commit comments