Skip to content

Commit be22336

Browse files
committed
tests: Refactor to use ArrayBuilder more places
1 parent 0fdeabc commit be22336

File tree

3 files changed

+63
-97
lines changed

3 files changed

+63
-97
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ defmac = "0.2"
4848
quickcheck = { workspace = true }
4949
approx = { workspace = true, default-features = true }
5050
itertools = { workspace = true }
51+
ndarray-gen = { workspace = true }
5152

5253
[features]
5354
default = ["std"]

crates/blas-tests/tests/oper.rs

+21-43
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use defmac::defmac;
1818
use itertools::iproduct;
1919
use num_complex::Complex32;
2020
use num_complex::Complex64;
21+
use num_traits::Num;
2122

2223
#[test]
2324
fn mat_vec_product_1d()
@@ -49,46 +50,29 @@ fn mat_vec_product_1d_inverted_axis()
4950
assert_eq!(a.t().dot(&b), ans);
5051
}
5152

52-
fn range_mat(m: Ix, n: Ix) -> Array2<f32>
53+
fn range_mat<A: Num + Copy>(m: Ix, n: Ix) -> Array2<A>
5354
{
54-
Array::linspace(0., (m * n) as f32 - 1., m * n)
55-
.into_shape_with_order((m, n))
56-
.unwrap()
57-
}
58-
59-
fn range_mat64(m: Ix, n: Ix) -> Array2<f64>
60-
{
61-
Array::linspace(0., (m * n) as f64 - 1., m * n)
62-
.into_shape_with_order((m, n))
63-
.unwrap()
55+
ArrayBuilder::new((m, n)).build()
6456
}
6557

6658
fn range_mat_complex(m: Ix, n: Ix) -> Array2<Complex32>
6759
{
68-
Array::linspace(0., (m * n) as f32 - 1., m * n)
69-
.into_shape_with_order((m, n))
70-
.unwrap()
71-
.map(|&f| Complex32::new(f, 0.))
60+
ArrayBuilder::new((m, n)).build()
7261
}
7362

7463
fn range_mat_complex64(m: Ix, n: Ix) -> Array2<Complex64>
7564
{
76-
Array::linspace(0., (m * n) as f64 - 1., m * n)
77-
.into_shape_with_order((m, n))
78-
.unwrap()
79-
.map(|&f| Complex64::new(f, 0.))
65+
ArrayBuilder::new((m, n)).build()
8066
}
8167

8268
fn range1_mat64(m: Ix) -> Array1<f64>
8369
{
84-
Array::linspace(0., m as f64 - 1., m)
70+
ArrayBuilder::new(m).build()
8571
}
8672

8773
fn range_i32(m: Ix, n: Ix) -> Array2<i32>
8874
{
89-
Array::from_iter(0..(m * n) as i32)
90-
.into_shape_with_order((m, n))
91-
.unwrap()
75+
ArrayBuilder::new((m, n)).build()
9276
}
9377

9478
// simple, slow, correct (hopefully) mat mul
@@ -163,8 +147,8 @@ where
163147
fn mat_mul_order()
164148
{
165149
let (m, n, k) = (50, 50, 50);
166-
let a = range_mat(m, n);
167-
let b = range_mat(n, k);
150+
let a = range_mat::<f32>(m, n);
151+
let b = range_mat::<f32>(n, k);
168152
let mut af = Array::zeros(a.dim().f());
169153
let mut bf = Array::zeros(b.dim().f());
170154
af.assign(&a);
@@ -183,7 +167,7 @@ fn mat_mul_order()
183167
fn mat_mul_broadcast()
184168
{
185169
let (m, n, k) = (16, 16, 16);
186-
let a = range_mat(m, n);
170+
let a = range_mat::<f32>(m, n);
187171
let x1 = 1.;
188172
let x = Array::from(vec![x1]);
189173
let b0 = x.broadcast((n, k)).unwrap();
@@ -203,8 +187,8 @@ fn mat_mul_broadcast()
203187
fn mat_mul_rev()
204188
{
205189
let (m, n, k) = (16, 16, 16);
206-
let a = range_mat(m, n);
207-
let b = range_mat(n, k);
190+
let a = range_mat::<f32>(m, n);
191+
let b = range_mat::<f32>(n, k);
208192
let mut rev = Array::zeros(b.dim());
209193
let mut rev = rev.slice_mut(s![..;-1, ..]);
210194
rev.assign(&b);
@@ -233,8 +217,8 @@ fn mat_mut_zero_len()
233217
}
234218
}
235219
});
236-
mat_mul_zero_len!(range_mat);
237-
mat_mul_zero_len!(range_mat64);
220+
mat_mul_zero_len!(range_mat::<f32>);
221+
mat_mul_zero_len!(range_mat::<f64>);
238222
mat_mul_zero_len!(range_i32);
239223
}
240224

@@ -307,11 +291,11 @@ fn gen_mat_mul()
307291
#[test]
308292
fn gemm_64_1_f()
309293
{
310-
let a = range_mat64(64, 64).reversed_axes();
294+
let a = range_mat::<f64>(64, 64).reversed_axes();
311295
let (m, n) = a.dim();
312296
// m x n times n x 1 == m x 1
313-
let x = range_mat64(n, 1);
314-
let mut y = range_mat64(m, 1);
297+
let x = range_mat::<f64>(n, 1);
298+
let mut y = range_mat::<f64>(m, 1);
315299
let answer = reference_mat_mul(&a, &x) + &y;
316300
general_mat_mul(1.0, &a, &x, 1.0, &mut y);
317301
assert_relative_eq!(y, answer, epsilon = 1e-12, max_relative = 1e-7);
@@ -393,11 +377,8 @@ fn gen_mat_vec_mul()
393377
for &s1 in &[1, 2, -1, -2] {
394378
for &s2 in &[1, 2, -1, -2] {
395379
for &(m, k) in &sizes {
396-
for &rev in &[false, true] {
397-
let mut a = range_mat64(m, k);
398-
if rev {
399-
a = a.reversed_axes();
400-
}
380+
for order in [Order::C, Order::F] {
381+
let a = ArrayBuilder::new((m, k)).memory_order(order).build();
401382
let (m, k) = a.dim();
402383
let b = range1_mat64(k);
403384
let mut c = range1_mat64(m);
@@ -438,11 +419,8 @@ fn vec_mat_mul()
438419
for &s1 in &[1, 2, -1, -2] {
439420
for &s2 in &[1, 2, -1, -2] {
440421
for &(m, n) in &sizes {
441-
for &rev in &[false, true] {
442-
let mut b = range_mat64(m, n);
443-
if rev {
444-
b = b.reversed_axes();
445-
}
422+
for order in [Order::C, Order::F] {
423+
let b = ArrayBuilder::new((m, n)).memory_order(order).build();
446424
let (m, n) = b.dim();
447425
let a = range1_mat64(m);
448426
let mut c = range1_mat64(n);

tests/oper.rs

+41-54
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,17 @@
55
use ndarray::linalg::general_mat_mul;
66
use ndarray::linalg::kron;
77
use ndarray::prelude::*;
8+
#[cfg(feature = "approx")]
9+
use ndarray::Order;
810
use ndarray::{rcarr1, rcarr2};
911
use ndarray::{Data, LinalgScalar};
1012
use ndarray::{Ix, Ixs};
11-
use num_traits::Zero;
13+
use ndarray_gen::array_builder::ArrayBuilder;
1214

1315
use approx::assert_abs_diff_eq;
1416
use defmac::defmac;
17+
use num_traits::Num;
18+
use num_traits::Zero;
1519

1620
fn test_oper(op: &str, a: &[f32], b: &[f32], c: &[f32])
1721
{
@@ -271,31 +275,20 @@ fn product()
271275
}
272276
}
273277

274-
fn range_mat(m: Ix, n: Ix) -> Array2<f32>
278+
fn range_mat<A: Num + Copy>(m: Ix, n: Ix) -> Array2<A>
275279
{
276-
Array::linspace(0., (m * n) as f32 - 1., m * n)
277-
.into_shape_with_order((m, n))
278-
.unwrap()
279-
}
280-
281-
fn range_mat64(m: Ix, n: Ix) -> Array2<f64>
282-
{
283-
Array::linspace(0., (m * n) as f64 - 1., m * n)
284-
.into_shape_with_order((m, n))
285-
.unwrap()
280+
ArrayBuilder::new((m, n)).build()
286281
}
287282

288283
#[cfg(feature = "approx")]
289284
fn range1_mat64(m: Ix) -> Array1<f64>
290285
{
291-
Array::linspace(0., m as f64 - 1., m)
286+
ArrayBuilder::new(m).build()
292287
}
293288

294289
fn range_i32(m: Ix, n: Ix) -> Array2<i32>
295290
{
296-
Array::from_iter(0..(m * n) as i32)
297-
.into_shape_with_order((m, n))
298-
.unwrap()
291+
ArrayBuilder::new((m, n)).build()
299292
}
300293

301294
// simple, slow, correct (hopefully) mat mul
@@ -332,8 +325,8 @@ where
332325
fn mat_mul()
333326
{
334327
let (m, n, k) = (8, 8, 8);
335-
let a = range_mat(m, n);
336-
let b = range_mat(n, k);
328+
let a = range_mat::<f32>(m, n);
329+
let b = range_mat::<f32>(n, k);
337330
let mut b = b / 4.;
338331
{
339332
let mut c = b.column_mut(0);
@@ -351,8 +344,8 @@ fn mat_mul()
351344
assert_eq!(ab, af.dot(&bf));
352345

353346
let (m, n, k) = (10, 5, 11);
354-
let a = range_mat(m, n);
355-
let b = range_mat(n, k);
347+
let a = range_mat::<f32>(m, n);
348+
let b = range_mat::<f32>(n, k);
356349
let mut b = b / 4.;
357350
{
358351
let mut c = b.column_mut(0);
@@ -370,8 +363,8 @@ fn mat_mul()
370363
assert_eq!(ab, af.dot(&bf));
371364

372365
let (m, n, k) = (10, 8, 1);
373-
let a = range_mat(m, n);
374-
let b = range_mat(n, k);
366+
let a = range_mat::<f32>(m, n);
367+
let b = range_mat::<f32>(n, k);
375368
let mut b = b / 4.;
376369
{
377370
let mut c = b.column_mut(0);
@@ -395,8 +388,8 @@ fn mat_mul()
395388
fn mat_mul_order()
396389
{
397390
let (m, n, k) = (8, 8, 8);
398-
let a = range_mat(m, n);
399-
let b = range_mat(n, k);
391+
let a = range_mat::<f32>(m, n);
392+
let b = range_mat::<f32>(n, k);
400393
let mut af = Array::zeros(a.dim().f());
401394
let mut bf = Array::zeros(b.dim().f());
402395
af.assign(&a);
@@ -415,8 +408,8 @@ fn mat_mul_order()
415408
fn mat_mul_shape_mismatch()
416409
{
417410
let (m, k, k2, n) = (8, 8, 9, 8);
418-
let a = range_mat(m, k);
419-
let b = range_mat(k2, n);
411+
let a = range_mat::<f32>(m, k);
412+
let b = range_mat::<f32>(k2, n);
420413
a.dot(&b);
421414
}
422415

@@ -426,9 +419,9 @@ fn mat_mul_shape_mismatch()
426419
fn mat_mul_shape_mismatch_2()
427420
{
428421
let (m, k, k2, n) = (8, 8, 8, 8);
429-
let a = range_mat(m, k);
430-
let b = range_mat(k2, n);
431-
let mut c = range_mat(m, n + 1);
422+
let a = range_mat::<f32>(m, k);
423+
let b = range_mat::<f32>(k2, n);
424+
let mut c = range_mat::<f32>(m, n + 1);
432425
general_mat_mul(1., &a, &b, 1., &mut c);
433426
}
434427

@@ -438,7 +431,7 @@ fn mat_mul_shape_mismatch_2()
438431
fn mat_mul_broadcast()
439432
{
440433
let (m, n, k) = (16, 16, 16);
441-
let a = range_mat(m, n);
434+
let a = range_mat::<f32>(m, n);
442435
let x1 = 1.;
443436
let x = Array::from(vec![x1]);
444437
let b0 = x.broadcast((n, k)).unwrap();
@@ -458,8 +451,8 @@ fn mat_mul_broadcast()
458451
fn mat_mul_rev()
459452
{
460453
let (m, n, k) = (16, 16, 16);
461-
let a = range_mat(m, n);
462-
let b = range_mat(n, k);
454+
let a = range_mat::<f32>(m, n);
455+
let b = range_mat::<f32>(n, k);
463456
let mut rev = Array::zeros(b.dim());
464457
let mut rev = rev.slice_mut(s![..;-1, ..]);
465458
rev.assign(&b);
@@ -488,8 +481,8 @@ fn mat_mut_zero_len()
488481
}
489482
}
490483
});
491-
mat_mul_zero_len!(range_mat);
492-
mat_mul_zero_len!(range_mat64);
484+
mat_mul_zero_len!(range_mat::<f32>);
485+
mat_mul_zero_len!(range_mat::<f64>);
493486
mat_mul_zero_len!(range_i32);
494487
}
495488

@@ -528,9 +521,9 @@ fn scaled_add_2()
528521
for &s1 in &[1, 2, -1, -2] {
529522
for &s2 in &[1, 2, -1, -2] {
530523
for &(m, k, n, q) in &sizes {
531-
let mut a = range_mat64(m, k);
524+
let mut a = range_mat::<f64>(m, k);
532525
let mut answer = a.clone();
533-
let c = range_mat64(n, q);
526+
let c = range_mat::<f64>(n, q);
534527

535528
{
536529
let mut av = a.slice_mut(s![..;s1, ..;s2]);
@@ -570,7 +563,7 @@ fn scaled_add_3()
570563
for &s1 in &[1, 2, -1, -2] {
571564
for &s2 in &[1, 2, -1, -2] {
572565
for &(m, k, n, q) in &sizes {
573-
let mut a = range_mat64(m, k);
566+
let mut a = range_mat::<f64>(m, k);
574567
let mut answer = a.clone();
575568
let cdim = if n == 1 { vec![q] } else { vec![n, q] };
576569
let cslice: Vec<SliceInfoElem> = if n == 1 {
@@ -582,7 +575,7 @@ fn scaled_add_3()
582575
]
583576
};
584577

585-
let c = range_mat64(n, q).into_shape_with_order(cdim).unwrap();
578+
let c = range_mat::<f64>(n, q).into_shape_with_order(cdim).unwrap();
586579

587580
{
588581
let mut av = a.slice_mut(s![..;s1, ..;s2]);
@@ -619,9 +612,9 @@ fn gen_mat_mul()
619612
for &s1 in &[1, 2, -1, -2] {
620613
for &s2 in &[1, 2, -1, -2] {
621614
for &(m, k, n) in &sizes {
622-
let a = range_mat64(m, k);
623-
let b = range_mat64(k, n);
624-
let mut c = range_mat64(m, n);
615+
let a = range_mat::<f64>(m, k);
616+
let b = range_mat::<f64>(k, n);
617+
let mut c = range_mat::<f64>(m, n);
625618
let mut answer = c.clone();
626619

627620
{
@@ -645,11 +638,11 @@ fn gen_mat_mul()
645638
#[test]
646639
fn gemm_64_1_f()
647640
{
648-
let a = range_mat64(64, 64).reversed_axes();
641+
let a = range_mat::<f64>(64, 64).reversed_axes();
649642
let (m, n) = a.dim();
650643
// m x n times n x 1 == m x 1
651-
let x = range_mat64(n, 1);
652-
let mut y = range_mat64(m, 1);
644+
let x = range_mat::<f64>(n, 1);
645+
let mut y = range_mat::<f64>(m, 1);
653646
let answer = reference_mat_mul(&a, &x) + &y;
654647
general_mat_mul(1.0, &a, &x, 1.0, &mut y);
655648
approx::assert_relative_eq!(y, answer, epsilon = 1e-12, max_relative = 1e-7);
@@ -728,11 +721,8 @@ fn gen_mat_vec_mul()
728721
for &s1 in &[1, 2, -1, -2] {
729722
for &s2 in &[1, 2, -1, -2] {
730723
for &(m, k) in &sizes {
731-
for &rev in &[false, true] {
732-
let mut a = range_mat64(m, k);
733-
if rev {
734-
a = a.reversed_axes();
735-
}
724+
for order in [Order::C, Order::F] {
725+
let a = ArrayBuilder::new((m, k)).memory_order(order).build();
736726
let (m, k) = a.dim();
737727
let b = range1_mat64(k);
738728
let mut c = range1_mat64(m);
@@ -794,11 +784,8 @@ fn vec_mat_mul()
794784
for &s1 in &[1, 2, -1, -2] {
795785
for &s2 in &[1, 2, -1, -2] {
796786
for &(m, n) in &sizes {
797-
for &rev in &[false, true] {
798-
let mut b = range_mat64(m, n);
799-
if rev {
800-
b = b.reversed_axes();
801-
}
787+
for order in [Order::C, Order::F] {
788+
let b = ArrayBuilder::new((m, n)).memory_order(order).build();
802789
let (m, n) = b.dim();
803790
let a = range1_mat64(m);
804791
let mut c = range1_mat64(n);

0 commit comments

Comments
 (0)