Skip to content

Commit

Permalink
tests: Refactor to use ArrayBuilder more places
Browse files Browse the repository at this point in the history
  • Loading branch information
bluss committed Aug 7, 2024
1 parent 0fdeabc commit 5c8b9de
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 99 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ defmac = "0.2"
quickcheck = { workspace = true }
approx = { workspace = true, default-features = true }
itertools = { workspace = true }
ndarray-gen = { workspace = true }

[features]
default = ["std"]
Expand Down Expand Up @@ -93,7 +94,7 @@ default-members = [
]

[workspace.dependencies]
ndarray = { version = "0.16", path = "." }
ndarray = { version = "0.16", path = ".", default-features = false }
ndarray-rand = { path = "ndarray-rand" }
ndarray-gen = { path = "crates/ndarray-gen" }

Expand Down
64 changes: 21 additions & 43 deletions crates/blas-tests/tests/oper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use defmac::defmac;
use itertools::iproduct;
use num_complex::Complex32;
use num_complex::Complex64;
use num_traits::Num;

#[test]
fn mat_vec_product_1d()
Expand Down Expand Up @@ -49,46 +50,29 @@ fn mat_vec_product_1d_inverted_axis()
assert_eq!(a.t().dot(&b), ans);
}

fn range_mat(m: Ix, n: Ix) -> Array2<f32>
fn range_mat<A: Num + Copy>(m: Ix, n: Ix) -> Array2<A>
{
Array::linspace(0., (m * n) as f32 - 1., m * n)
.into_shape_with_order((m, n))
.unwrap()
}

fn range_mat64(m: Ix, n: Ix) -> Array2<f64>
{
Array::linspace(0., (m * n) as f64 - 1., m * n)
.into_shape_with_order((m, n))
.unwrap()
ArrayBuilder::new((m, n)).build()
}

fn range_mat_complex(m: Ix, n: Ix) -> Array2<Complex32>
{
Array::linspace(0., (m * n) as f32 - 1., m * n)
.into_shape_with_order((m, n))
.unwrap()
.map(|&f| Complex32::new(f, 0.))
ArrayBuilder::new((m, n)).build()
}

fn range_mat_complex64(m: Ix, n: Ix) -> Array2<Complex64>
{
Array::linspace(0., (m * n) as f64 - 1., m * n)
.into_shape_with_order((m, n))
.unwrap()
.map(|&f| Complex64::new(f, 0.))
ArrayBuilder::new((m, n)).build()
}

fn range1_mat64(m: Ix) -> Array1<f64>
{
Array::linspace(0., m as f64 - 1., m)
ArrayBuilder::new(m).build()
}

fn range_i32(m: Ix, n: Ix) -> Array2<i32>
{
Array::from_iter(0..(m * n) as i32)
.into_shape_with_order((m, n))
.unwrap()
ArrayBuilder::new((m, n)).build()
}

// simple, slow, correct (hopefully) mat mul
Expand Down Expand Up @@ -163,8 +147,8 @@ where
fn mat_mul_order()
{
let (m, n, k) = (50, 50, 50);
let a = range_mat(m, n);
let b = range_mat(n, k);
let a = range_mat::<f32>(m, n);
let b = range_mat::<f32>(n, k);
let mut af = Array::zeros(a.dim().f());
let mut bf = Array::zeros(b.dim().f());
af.assign(&a);
Expand All @@ -183,7 +167,7 @@ fn mat_mul_order()
fn mat_mul_broadcast()
{
let (m, n, k) = (16, 16, 16);
let a = range_mat(m, n);
let a = range_mat::<f32>(m, n);
let x1 = 1.;
let x = Array::from(vec![x1]);
let b0 = x.broadcast((n, k)).unwrap();
Expand All @@ -203,8 +187,8 @@ fn mat_mul_broadcast()
fn mat_mul_rev()
{
let (m, n, k) = (16, 16, 16);
let a = range_mat(m, n);
let b = range_mat(n, k);
let a = range_mat::<f32>(m, n);
let b = range_mat::<f32>(n, k);
let mut rev = Array::zeros(b.dim());
let mut rev = rev.slice_mut(s![..;-1, ..]);
rev.assign(&b);
Expand Down Expand Up @@ -233,8 +217,8 @@ fn mat_mut_zero_len()
}
}
});
mat_mul_zero_len!(range_mat);
mat_mul_zero_len!(range_mat64);
mat_mul_zero_len!(range_mat::<f32>);
mat_mul_zero_len!(range_mat::<f64>);
mat_mul_zero_len!(range_i32);
}

Expand Down Expand Up @@ -307,11 +291,11 @@ fn gen_mat_mul()
#[test]
fn gemm_64_1_f()
{
let a = range_mat64(64, 64).reversed_axes();
let a = range_mat::<f64>(64, 64).reversed_axes();
let (m, n) = a.dim();
// m x n times n x 1 == m x 1
let x = range_mat64(n, 1);
let mut y = range_mat64(m, 1);
let x = range_mat::<f64>(n, 1);
let mut y = range_mat::<f64>(m, 1);
let answer = reference_mat_mul(&a, &x) + &y;
general_mat_mul(1.0, &a, &x, 1.0, &mut y);
assert_relative_eq!(y, answer, epsilon = 1e-12, max_relative = 1e-7);
Expand Down Expand Up @@ -393,11 +377,8 @@ fn gen_mat_vec_mul()
for &s1 in &[1, 2, -1, -2] {
for &s2 in &[1, 2, -1, -2] {
for &(m, k) in &sizes {
for &rev in &[false, true] {
let mut a = range_mat64(m, k);
if rev {
a = a.reversed_axes();
}
for order in [Order::C, Order::F] {
let a = ArrayBuilder::new((m, k)).memory_order(order).build();
let (m, k) = a.dim();
let b = range1_mat64(k);
let mut c = range1_mat64(m);
Expand Down Expand Up @@ -438,11 +419,8 @@ fn vec_mat_mul()
for &s1 in &[1, 2, -1, -2] {
for &s2 in &[1, 2, -1, -2] {
for &(m, n) in &sizes {
for &rev in &[false, true] {
let mut b = range_mat64(m, n);
if rev {
b = b.reversed_axes();
}
for order in [Order::C, Order::F] {
let b = ArrayBuilder::new((m, n)).memory_order(order).build();
let (m, n) = b.dim();
let a = range1_mat64(m);
let mut c = range1_mat64(n);
Expand Down
2 changes: 1 addition & 1 deletion crates/ndarray-gen/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ edition = "2018"
publish = false

[dependencies]
ndarray = { workspace = true }
ndarray = { workspace = true, default-features = false }
num-traits = { workspace = true }
1 change: 1 addition & 0 deletions crates/ndarray-gen/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#![no_std]
// Copyright 2024 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
Expand Down
Loading

0 comments on commit 5c8b9de

Please sign in to comment.