diff --git a/Cargo.toml b/Cargo.toml index 4c34a11bc..50faacf19 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ approx = { workspace = true, optional = true } rayon = { version = "1.10.0", optional = true } # Use via the `blas` crate feature -cblas-sys = { version = "0.1.4", optional = true, default-features = false } +cblas-sys = { workspace = true, optional = true } libc = { version = "0.2.82", optional = true } matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm"] } @@ -47,7 +47,8 @@ rawpointer = { version = "0.2" } defmac = "0.2" quickcheck = { workspace = true } approx = { workspace = true, default-features = true } -itertools = { version = "0.13.0", default-features = false, features = ["use_std"] } +itertools = { workspace = true } +ndarray-gen = { workspace = true } [features] default = ["std"] @@ -73,6 +74,7 @@ matrixmultiply-threading = ["matrixmultiply/threading"] portable-atomic-critical-section = ["portable-atomic/critical-section"] + [target.'cfg(not(target_has_atomic = "ptr"))'.dependencies] portable-atomic = { version = "1.6.0" } portable-atomic-util = { version = "0.2.0", features = [ "alloc" ] } @@ -85,14 +87,16 @@ members = [ default-members = [ ".", "ndarray-rand", + "crates/ndarray-gen", "crates/numeric-tests", "crates/serialization-tests", - # exclude blas-tests that depends on BLAS install + # exclude blas-tests and blas-mock-tests that activate "blas" feature ] [workspace.dependencies] -ndarray = { version = "0.16", path = "." } +ndarray = { version = "0.16", path = ".", default-features = false } ndarray-rand = { path = "ndarray-rand" } +ndarray-gen = { path = "crates/ndarray-gen" } num-integer = { version = "0.1.39", default-features = false } num-traits = { version = "0.2", default-features = false } @@ -101,6 +105,8 @@ approx = { version = "0.5", default-features = false } quickcheck = { version = "1.0", default-features = false } rand = { version = "0.8.0", features = ["small_rng"] } rand_distr = { version = "0.4.0" } +itertools = { version = "0.13.0", default-features = false, features = ["use_std"] } +cblas-sys = { version = "0.1.4", default-features = false } [profile.bench] debug = true diff --git a/crates/blas-mock-tests/Cargo.toml b/crates/blas-mock-tests/Cargo.toml new file mode 100644 index 000000000..a12b78580 --- /dev/null +++ b/crates/blas-mock-tests/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "blas-mock-tests" +version = "0.1.0" +edition = "2018" +publish = false + +[lib] +test = false +doc = false +doctest = false + +[dependencies] +ndarray = { workspace = true, features = ["approx", "blas"] } +ndarray-gen = { workspace = true } +cblas-sys = { workspace = true } + +[dev-dependencies] +itertools = { workspace = true } diff --git a/crates/blas-mock-tests/src/lib.rs b/crates/blas-mock-tests/src/lib.rs new file mode 100644 index 000000000..11fc5975e --- /dev/null +++ b/crates/blas-mock-tests/src/lib.rs @@ -0,0 +1,100 @@ +//! Mock interfaces to BLAS + +use core::cell::RefCell; +use core::ffi::{c_double, c_float, c_int}; +use std::thread_local; + +use cblas_sys::{c_double_complex, c_float_complex, CBLAS_LAYOUT, CBLAS_TRANSPOSE}; + +thread_local! { + /// This counter is incremented every time a gemm function is called + pub static CALL_COUNT: RefCell = RefCell::new(0); +} + +#[rustfmt::skip] +#[no_mangle] +#[allow(unused)] +pub unsafe extern "C" fn cblas_sgemm( + layout: CBLAS_LAYOUT, + transa: CBLAS_TRANSPOSE, + transb: CBLAS_TRANSPOSE, + m: c_int, + n: c_int, + k: c_int, + alpha: c_float, + a: *const c_float, + lda: c_int, + b: *const c_float, + ldb: c_int, + beta: c_float, + c: *mut c_float, + ldc: c_int +) { + CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1); +} + +#[rustfmt::skip] +#[no_mangle] +#[allow(unused)] +pub unsafe extern "C" fn cblas_dgemm( + layout: CBLAS_LAYOUT, + transa: CBLAS_TRANSPOSE, + transb: CBLAS_TRANSPOSE, + m: c_int, + n: c_int, + k: c_int, + alpha: c_double, + a: *const c_double, + lda: c_int, + b: *const c_double, + ldb: c_int, + beta: c_double, + c: *mut c_double, + ldc: c_int +) { + CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1); +} + +#[rustfmt::skip] +#[no_mangle] +#[allow(unused)] +pub unsafe extern "C" fn cblas_cgemm( + layout: CBLAS_LAYOUT, + transa: CBLAS_TRANSPOSE, + transb: CBLAS_TRANSPOSE, + m: c_int, + n: c_int, + k: c_int, + alpha: *const c_float_complex, + a: *const c_float_complex, + lda: c_int, + b: *const c_float_complex, + ldb: c_int, + beta: *const c_float_complex, + c: *mut c_float_complex, + ldc: c_int +) { + CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1); +} + +#[rustfmt::skip] +#[no_mangle] +#[allow(unused)] +pub unsafe extern "C" fn cblas_zgemm( + layout: CBLAS_LAYOUT, + transa: CBLAS_TRANSPOSE, + transb: CBLAS_TRANSPOSE, + m: c_int, + n: c_int, + k: c_int, + alpha: *const c_double_complex, + a: *const c_double_complex, + lda: c_int, + b: *const c_double_complex, + ldb: c_int, + beta: *const c_double_complex, + c: *mut c_double_complex, + ldc: c_int +) { + CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1); +} diff --git a/crates/blas-mock-tests/tests/use-blas.rs b/crates/blas-mock-tests/tests/use-blas.rs new file mode 100644 index 000000000..217508af6 --- /dev/null +++ b/crates/blas-mock-tests/tests/use-blas.rs @@ -0,0 +1,88 @@ +extern crate ndarray; + +use ndarray::prelude::*; + +use blas_mock_tests::CALL_COUNT; +use ndarray::linalg::general_mat_mul; +use ndarray::Order; +use ndarray_gen::array_builder::ArrayBuilder; + +use itertools::iproduct; + +#[test] +fn test_gen_mat_mul_uses_blas() +{ + let alpha = 1.0; + let beta = 0.0; + + let sizes = vec![ + (8, 8, 8), + (10, 10, 10), + (8, 8, 1), + (1, 10, 10), + (10, 1, 10), + (10, 10, 1), + (1, 10, 1), + (10, 1, 1), + (1, 1, 10), + (4, 17, 3), + (17, 3, 22), + (19, 18, 2), + (16, 17, 15), + (15, 16, 17), + (67, 63, 62), + ]; + let strides = &[1, 2, -1, -2]; + let cf_order = [Order::C, Order::F]; + + // test different strides and memory orders + for &(m, k, n) in &sizes { + for (&s1, &s2) in iproduct!(strides, strides) { + for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) { + println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3); + + let a = ArrayBuilder::new((m, k)).memory_order(ord1).build(); + let b = ArrayBuilder::new((k, n)).memory_order(ord2).build(); + let mut c = ArrayBuilder::new((m, n)).memory_order(ord3).build(); + + { + let av; + let bv; + let mut cv; + + if s1 != 1 || s2 != 1 { + av = a.slice(s![..;s1, ..;s2]); + bv = b.slice(s![..;s2, ..;s2]); + cv = c.slice_mut(s![..;s1, ..;s2]); + } else { + // different stride cases for slicing versus not sliced (for axes of + // len=1); so test not sliced here. + av = a.view(); + bv = b.view(); + cv = c.view_mut(); + } + + let pre_count = CALL_COUNT.with(|ctx| *ctx.borrow()); + general_mat_mul(alpha, &av, &bv, beta, &mut cv); + let after_count = CALL_COUNT.with(|ctx| *ctx.borrow()); + let ncalls = after_count - pre_count; + debug_assert!(ncalls <= 1); + + let always_uses_blas = s1 == 1 && s2 == 1; + + if always_uses_blas { + assert_eq!(ncalls, 1, "Contiguous arrays should use blas, orders={:?}", (ord1, ord2, ord3)); + } + + let should_use_blas = av.strides().iter().all(|&s| s > 0) + && bv.strides().iter().all(|&s| s > 0) + && cv.strides().iter().all(|&s| s > 0) + && av.strides().iter().any(|&s| s == 1) + && bv.strides().iter().any(|&s| s == 1) + && cv.strides().iter().any(|&s| s == 1); + assert_eq!(should_use_blas, ncalls > 0); + } + } + } + } +} diff --git a/crates/blas-tests/Cargo.toml b/crates/blas-tests/Cargo.toml index 0dbd9fd12..05a656000 100644 --- a/crates/blas-tests/Cargo.toml +++ b/crates/blas-tests/Cargo.toml @@ -11,7 +11,8 @@ doc = false doctest = false [dependencies] -ndarray = { workspace = true, features = ["approx"] } +ndarray = { workspace = true, features = ["approx", "blas"] } +ndarray-gen = { workspace = true } blas-src = { version = "0.10", optional = true } openblas-src = { version = "0.10", optional = true } @@ -23,6 +24,7 @@ defmac = "0.2" approx = { workspace = true } num-traits = { workspace = true } num-complex = { workspace = true } +itertools = { workspace = true } [features] # Just for making an example and to help testing, , multiple different possible diff --git a/crates/blas-tests/tests/oper.rs b/crates/blas-tests/tests/oper.rs index 3ed81915e..f1e1bc42b 100644 --- a/crates/blas-tests/tests/oper.rs +++ b/crates/blas-tests/tests/oper.rs @@ -9,12 +9,16 @@ use ndarray::prelude::*; use ndarray::linalg::general_mat_mul; use ndarray::linalg::general_mat_vec_mul; +use ndarray::Order; use ndarray::{Data, Ix, LinalgScalar}; +use ndarray_gen::array_builder::ArrayBuilder; use approx::assert_relative_eq; use defmac::defmac; +use itertools::iproduct; use num_complex::Complex32; use num_complex::Complex64; +use num_traits::Num; #[test] fn mat_vec_product_1d() @@ -46,46 +50,29 @@ fn mat_vec_product_1d_inverted_axis() assert_eq!(a.t().dot(&b), ans); } -fn range_mat(m: Ix, n: Ix) -> Array2 +fn range_mat(m: Ix, n: Ix) -> Array2 { - Array::linspace(0., (m * n) as f32 - 1., m * n) - .into_shape_with_order((m, n)) - .unwrap() -} - -fn range_mat64(m: Ix, n: Ix) -> Array2 -{ - Array::linspace(0., (m * n) as f64 - 1., m * n) - .into_shape_with_order((m, n)) - .unwrap() + ArrayBuilder::new((m, n)).build() } fn range_mat_complex(m: Ix, n: Ix) -> Array2 { - Array::linspace(0., (m * n) as f32 - 1., m * n) - .into_shape_with_order((m, n)) - .unwrap() - .map(|&f| Complex32::new(f, 0.)) + ArrayBuilder::new((m, n)).build() } fn range_mat_complex64(m: Ix, n: Ix) -> Array2 { - Array::linspace(0., (m * n) as f64 - 1., m * n) - .into_shape_with_order((m, n)) - .unwrap() - .map(|&f| Complex64::new(f, 0.)) + ArrayBuilder::new((m, n)).build() } fn range1_mat64(m: Ix) -> Array1 { - Array::linspace(0., m as f64 - 1., m) + ArrayBuilder::new(m).build() } fn range_i32(m: Ix, n: Ix) -> Array2 { - Array::from_iter(0..(m * n) as i32) - .into_shape_with_order((m, n)) - .unwrap() + ArrayBuilder::new((m, n)).build() } // simple, slow, correct (hopefully) mat mul @@ -160,8 +147,8 @@ where fn mat_mul_order() { let (m, n, k) = (50, 50, 50); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut af = Array::zeros(a.dim().f()); let mut bf = Array::zeros(b.dim().f()); af.assign(&a); @@ -180,7 +167,7 @@ fn mat_mul_order() fn mat_mul_broadcast() { let (m, n, k) = (16, 16, 16); - let a = range_mat(m, n); + let a = range_mat::(m, n); let x1 = 1.; let x = Array::from(vec![x1]); let b0 = x.broadcast((n, k)).unwrap(); @@ -200,8 +187,8 @@ fn mat_mul_broadcast() fn mat_mul_rev() { let (m, n, k) = (16, 16, 16); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut rev = Array::zeros(b.dim()); let mut rev = rev.slice_mut(s![..;-1, ..]); rev.assign(&b); @@ -230,8 +217,8 @@ fn mat_mut_zero_len() } } }); - mat_mul_zero_len!(range_mat); - mat_mul_zero_len!(range_mat64); + mat_mul_zero_len!(range_mat::); + mat_mul_zero_len!(range_mat::); mat_mul_zero_len!(range_i32); } @@ -243,7 +230,14 @@ fn gen_mat_mul() let sizes = vec![ (4, 4, 4), (8, 8, 8), - (17, 15, 16), + (10, 10, 10), + (8, 8, 1), + (1, 10, 10), + (10, 1, 10), + (10, 10, 1), + (1, 10, 1), + (10, 1, 1), + (1, 1, 10), (4, 17, 3), (17, 3, 22), (19, 18, 2), @@ -251,24 +245,41 @@ fn gen_mat_mul() (15, 16, 17), (67, 63, 62), ]; - // test different strides - for &s1 in &[1, 2, -1, -2] { - for &s2 in &[1, 2, -1, -2] { - for &(m, k, n) in &sizes { - let a = range_mat64(m, k); - let b = range_mat64(k, n); - let mut c = range_mat64(m, n); + let strides = &[1, 2, -1, -2]; + let cf_order = [Order::C, Order::F]; + + // test different strides and memory orders + for (&s1, &s2) in iproduct!(strides, strides) { + for &(m, k, n) in &sizes { + for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) { + println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3); + let a = ArrayBuilder::new((m, k)).memory_order(ord1).build() * 0.5; + let b = ArrayBuilder::new((k, n)).memory_order(ord2).build(); + let mut c = ArrayBuilder::new((m, n)).memory_order(ord3).build(); + let mut answer = c.clone(); { - let a = a.slice(s![..;s1, ..;s2]); - let b = b.slice(s![..;s2, ..;s2]); - let mut cv = c.slice_mut(s![..;s1, ..;s2]); + let av; + let bv; + let mut cv; + + if s1 != 1 || s2 != 1 { + av = a.slice(s![..;s1, ..;s2]); + bv = b.slice(s![..;s2, ..;s2]); + cv = c.slice_mut(s![..;s1, ..;s2]); + } else { + // different stride cases for slicing versus not sliced (for axes of + // len=1); so test not sliced here. + av = a.view(); + bv = b.view(); + cv = c.view_mut(); + } - let answer_part = alpha * reference_mat_mul(&a, &b) + beta * &cv; + let answer_part = alpha * reference_mat_mul(&av, &bv) + beta * &cv; answer.slice_mut(s![..;s1, ..;s2]).assign(&answer_part); - general_mat_mul(alpha, &a, &b, beta, &mut cv); + general_mat_mul(alpha, &av, &bv, beta, &mut cv); } assert_relative_eq!(c, answer, epsilon = 1e-12, max_relative = 1e-7); } @@ -280,11 +291,11 @@ fn gen_mat_mul() #[test] fn gemm_64_1_f() { - let a = range_mat64(64, 64).reversed_axes(); + let a = range_mat::(64, 64).reversed_axes(); let (m, n) = a.dim(); // m x n times n x 1 == m x 1 - let x = range_mat64(n, 1); - let mut y = range_mat64(m, 1); + let x = range_mat::(n, 1); + let mut y = range_mat::(m, 1); let answer = reference_mat_mul(&a, &x) + &y; general_mat_mul(1.0, &a, &x, 1.0, &mut y); assert_relative_eq!(y, answer, epsilon = 1e-12, max_relative = 1e-7); @@ -366,11 +377,8 @@ fn gen_mat_vec_mul() for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, k) in &sizes { - for &rev in &[false, true] { - let mut a = range_mat64(m, k); - if rev { - a = a.reversed_axes(); - } + for order in [Order::C, Order::F] { + let a = ArrayBuilder::new((m, k)).memory_order(order).build(); let (m, k) = a.dim(); let b = range1_mat64(k); let mut c = range1_mat64(m); @@ -411,11 +419,8 @@ fn vec_mat_mul() for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, n) in &sizes { - for &rev in &[false, true] { - let mut b = range_mat64(m, n); - if rev { - b = b.reversed_axes(); - } + for order in [Order::C, Order::F] { + let b = ArrayBuilder::new((m, n)).memory_order(order).build(); let (m, n) = b.dim(); let a = range1_mat64(m); let mut c = range1_mat64(n); diff --git a/crates/ndarray-gen/Cargo.toml b/crates/ndarray-gen/Cargo.toml new file mode 100644 index 000000000..6818e4b65 --- /dev/null +++ b/crates/ndarray-gen/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "ndarray-gen" +version = "0.1.0" +edition = "2018" +publish = false + +[dependencies] +ndarray = { workspace = true, default-features = false } +num-traits = { workspace = true } diff --git a/crates/ndarray-gen/README.md b/crates/ndarray-gen/README.md new file mode 100644 index 000000000..7dd02320c --- /dev/null +++ b/crates/ndarray-gen/README.md @@ -0,0 +1,4 @@ + +## ndarray-gen + +Array generation functions, used for testing. diff --git a/crates/ndarray-gen/src/array_builder.rs b/crates/ndarray-gen/src/array_builder.rs new file mode 100644 index 000000000..a021e5252 --- /dev/null +++ b/crates/ndarray-gen/src/array_builder.rs @@ -0,0 +1,97 @@ +// Copyright 2024 bluss and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use ndarray::Array; +use ndarray::Dimension; +use ndarray::IntoDimension; +use ndarray::Order; + +use num_traits::Num; + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct ArrayBuilder +{ + dim: D, + memory_order: Order, + generator: ElementGenerator, +} + +/// How to generate elements +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum ElementGenerator +{ + Sequential, + Zero, +} + +impl Default for ArrayBuilder +{ + fn default() -> Self + { + Self::new(D::zeros(D::NDIM.unwrap_or(1))) + } +} + +impl ArrayBuilder +where D: Dimension +{ + pub fn new(dim: impl IntoDimension) -> Self + { + ArrayBuilder { + dim: dim.into_dimension(), + memory_order: Order::C, + generator: ElementGenerator::Sequential, + } + } + + pub fn memory_order(mut self, order: Order) -> Self + { + self.memory_order = order; + self + } + + pub fn generator(mut self, generator: ElementGenerator) -> Self + { + self.generator = generator; + self + } + + pub fn build(self) -> Array + where T: Num + Clone + { + let mut current = T::zero(); + let size = self.dim.size(); + let use_zeros = self.generator == ElementGenerator::Zero; + Array::from_iter((0..size).map(|_| { + let ret = current.clone(); + if !use_zeros { + current = ret.clone() + T::one(); + } + ret + })) + .into_shape_with_order((self.dim, self.memory_order)) + .unwrap() + } +} + +#[test] +fn test_order() +{ + let (m, n) = (12, 13); + let c = ArrayBuilder::new((m, n)) + .memory_order(Order::C) + .build::(); + let f = ArrayBuilder::new((m, n)) + .memory_order(Order::F) + .build::(); + + assert_eq!(c.shape(), &[m, n]); + assert_eq!(f.shape(), &[m, n]); + assert_eq!(c.strides(), &[n as isize, 1]); + assert_eq!(f.strides(), &[1, m as isize]); +} diff --git a/crates/ndarray-gen/src/lib.rs b/crates/ndarray-gen/src/lib.rs new file mode 100644 index 000000000..7f9ca89fc --- /dev/null +++ b/crates/ndarray-gen/src/lib.rs @@ -0,0 +1,12 @@ +#![no_std] +// Copyright 2024 bluss and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +/// Build ndarray arrays for test purposes + +pub mod array_builder; diff --git a/scripts/all-tests.sh b/scripts/all-tests.sh index 6f1fdf73a..4ececbcbd 100755 --- a/scripts/all-tests.sh +++ b/scripts/all-tests.sh @@ -19,6 +19,8 @@ cargo test -v --features "$FEATURES" $QC_FEAT cargo test -v -p ndarray -p ndarray-rand --release --features "$FEATURES" $QC_FEAT --lib --tests # BLAS tests +cargo test -p ndarray --lib -v --features blas +cargo test -p blas-mock-tests -v cargo test -p blas-tests -v --features blas-tests/openblas-system cargo test -p numeric-tests -v --features numeric-tests/test_blas diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index f3bedae71..243dc783b 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -24,15 +24,11 @@ use num_complex::{Complex32 as c32, Complex64 as c64}; #[cfg(feature = "blas")] use libc::c_int; -#[cfg(feature = "blas")] -use std::cmp; -#[cfg(feature = "blas")] -use std::mem::swap; #[cfg(feature = "blas")] use cblas_sys as blas_sys; #[cfg(feature = "blas")] -use cblas_sys::{CblasNoTrans, CblasRowMajor, CblasTrans, CBLAS_LAYOUT}; +use cblas_sys::{CblasNoTrans, CblasTrans, CBLAS_LAYOUT}; /// len of vector before we use blas #[cfg(feature = "blas")] @@ -379,8 +375,8 @@ use self::mat_mul_general as mat_mul_impl; #[cfg(feature = "blas")] fn mat_mul_impl( alpha: A, - lhs: &ArrayView2<'_, A>, - rhs: &ArrayView2<'_, A>, + a: &ArrayView2<'_, A>, + b: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>, ) where @@ -388,41 +384,56 @@ fn mat_mul_impl( { // size cutoff for using BLAS let cut = GEMM_BLAS_CUTOFF; - let ((mut m, a), (_, mut n)) = (lhs.dim(), rhs.dim()); - if !(m > cut || n > cut || a > cut) + let ((m, k), (k2, n)) = (a.dim(), b.dim()); + debug_assert_eq!(k, k2); + if !(m > cut || n > cut || k > cut) || !(same_type::() || same_type::() || same_type::() || same_type::()) { - return mat_mul_general(alpha, lhs, rhs, beta, c); + return mat_mul_general(alpha, a, b, beta, c); } - { - // Use `c` for c-order and `f` for an f-order matrix - // We can handle c * c, f * f generally and - // c * f and f * c if the `f` matrix is square. - let mut lhs_ = lhs.view(); - let mut rhs_ = rhs.view(); - let mut c_ = c.view_mut(); - let lhs_s0 = lhs_.strides()[0]; - let rhs_s0 = rhs_.strides()[0]; - let both_f = lhs_s0 == 1 && rhs_s0 == 1; - let mut lhs_trans = CblasNoTrans; - let mut rhs_trans = CblasNoTrans; - if both_f { - // A^t B^t = C^t => B A = C - let lhs_t = lhs_.reversed_axes(); - lhs_ = rhs_.reversed_axes(); - rhs_ = lhs_t; - c_ = c_.reversed_axes(); - swap(&mut m, &mut n); - } else if lhs_s0 == 1 && m == a { - lhs_ = lhs_.reversed_axes(); - lhs_trans = CblasTrans; - } else if rhs_s0 == 1 && a == n { - rhs_ = rhs_.reversed_axes(); - rhs_trans = CblasTrans; - } + + #[allow(clippy::never_loop)] // MSRV Rust 1.64 does not have break from block + 'blas_block: loop { + // Compute A B -> C + // We require for BLAS compatibility that: + // A, B, C are contiguous (stride=1) in their fastest dimension, + // but it can be either first or second axis (either rowmajor/"c" or colmajor/"f"). + // + // The "normal case" is CblasRowMajor for cblas. + // Select CblasRowMajor, CblasColMajor to fit C's memory order. + // + // Apply transpose to A, B as needed if they differ from the normal case. + // If C is CblasColMajor then transpose both A, B (again!) + + let (a_layout, a_axis, b_layout, b_axis, c_layout) = + match (get_blas_compatible_layout(a), + get_blas_compatible_layout(b), + get_blas_compatible_layout(c)) + { + (Some(a_layout), Some(b_layout), Some(c_layout @ MemoryOrder::C)) => { + (a_layout, a_layout.lead_axis(), + b_layout, b_layout.lead_axis(), c_layout) + }, + (Some(a_layout), Some(b_layout), Some(c_layout @ MemoryOrder::F)) => { + // CblasColMajor is the "other case" + // Mark a, b as having layouts opposite of what they were detected as, which + // ends up with the correct transpose setting w.r.t col major + (a_layout.opposite(), a_layout.lead_axis(), + b_layout.opposite(), b_layout.lead_axis(), c_layout) + }, + _ => break 'blas_block, + }; + + let a_trans = a_layout.to_cblas_transpose(); + let lda = blas_stride(&a, a_axis); + + let b_trans = b_layout.to_cblas_transpose(); + let ldb = blas_stride(&b, b_axis); + + let ldc = blas_stride(&c, c_layout.lead_axis()); macro_rules! gemm_scalar_cast { (f32, $var:ident) => { @@ -441,57 +452,40 @@ fn mat_mul_impl( macro_rules! gemm { ($ty:tt, $gemm:ident) => { - if blas_row_major_2d::<$ty, _>(&lhs_) - && blas_row_major_2d::<$ty, _>(&rhs_) - && blas_row_major_2d::<$ty, _>(&c_) - { - let (m, k) = match lhs_trans { - CblasNoTrans => lhs_.dim(), - _ => { - let (rows, cols) = lhs_.dim(); - (cols, rows) - } - }; - let n = match rhs_trans { - CblasNoTrans => rhs_.raw_dim()[1], - _ => rhs_.raw_dim()[0], - }; - // adjust strides, these may [1, 1] for column matrices - let lhs_stride = cmp::max(lhs_.strides()[0] as blas_index, k as blas_index); - let rhs_stride = cmp::max(rhs_.strides()[0] as blas_index, n as blas_index); - let c_stride = cmp::max(c_.strides()[0] as blas_index, n as blas_index); - + if same_type::() { // gemm is C ← αA^Op B^Op + βC // Where Op is notrans/trans/conjtrans unsafe { blas_sys::$gemm( - CblasRowMajor, - lhs_trans, - rhs_trans, + c_layout.to_cblas_layout(), + a_trans, + b_trans, m as blas_index, // m, rows of Op(a) n as blas_index, // n, cols of Op(b) k as blas_index, // k, cols of Op(a) gemm_scalar_cast!($ty, alpha), // alpha - lhs_.ptr.as_ptr() as *const _, // a - lhs_stride, // lda - rhs_.ptr.as_ptr() as *const _, // b - rhs_stride, // ldb + a.ptr.as_ptr() as *const _, // a + lda, // lda + b.ptr.as_ptr() as *const _, // b + ldb, // ldb gemm_scalar_cast!($ty, beta), // beta - c_.ptr.as_ptr() as *mut _, // c - c_stride, // ldc + c.ptr.as_ptr() as *mut _, // c + ldc, // ldc ); } return; } }; } + gemm!(f32, cblas_sgemm); gemm!(f64, cblas_dgemm); - gemm!(c32, cblas_cgemm); gemm!(c64, cblas_zgemm); + + break 'blas_block; } - mat_mul_general(alpha, lhs, rhs, beta, c) + mat_mul_general(alpha, a, b, beta, c) } /// C ← α A B + β C @@ -693,46 +687,51 @@ unsafe fn general_mat_vec_mul_impl( #[cfg(feature = "blas")] macro_rules! gemv { ($ty:ty, $gemv:ident) => { - if let Some(layout) = blas_layout::<$ty, _>(&a) { - if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) { - // Determine stride between rows or columns. Note that the stride is - // adjusted to at least `k` or `m` to handle the case of a matrix with a - // trivial (length 1) dimension, since the stride for the trivial dimension - // may be arbitrary. - let a_trans = CblasNoTrans; - let a_stride = match layout { - CBLAS_LAYOUT::CblasRowMajor => { - a.strides()[0].max(k as isize) as blas_index - } - CBLAS_LAYOUT::CblasColMajor => { - a.strides()[1].max(m as isize) as blas_index - } - }; - - // Low addr in memory pointers required for x, y - let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides); - let x_ptr = x.ptr.as_ptr().sub(x_offset); - let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.dim, &y.strides); - let y_ptr = y.ptr.as_ptr().sub(y_offset); - - let x_stride = x.strides()[0] as blas_index; - let y_stride = y.strides()[0] as blas_index; - - blas_sys::$gemv( - layout, - a_trans, - m as blas_index, // m, rows of Op(a) - k as blas_index, // n, cols of Op(a) - cast_as(&alpha), // alpha - a.ptr.as_ptr() as *const _, // a - a_stride, // lda - x_ptr as *const _, // x - x_stride, - cast_as(&beta), // beta - y_ptr as *mut _, // y - y_stride, - ); - return; + if same_type::() { + if let Some(layout) = get_blas_compatible_layout(&a) { + if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) { + // Determine stride between rows or columns. Note that the stride is + // adjusted to at least `k` or `m` to handle the case of a matrix with a + // trivial (length 1) dimension, since the stride for the trivial dimension + // may be arbitrary. + let a_trans = CblasNoTrans; + + let (a_stride, cblas_layout) = match layout { + MemoryOrder::C => { + (a.strides()[0].max(k as isize) as blas_index, + CBLAS_LAYOUT::CblasRowMajor) + } + MemoryOrder::F => { + (a.strides()[1].max(m as isize) as blas_index, + CBLAS_LAYOUT::CblasColMajor) + } + }; + + // Low addr in memory pointers required for x, y + let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides); + let x_ptr = x.ptr.as_ptr().sub(x_offset); + let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.dim, &y.strides); + let y_ptr = y.ptr.as_ptr().sub(y_offset); + + let x_stride = x.strides()[0] as blas_index; + let y_stride = y.strides()[0] as blas_index; + + blas_sys::$gemv( + cblas_layout, + a_trans, + m as blas_index, // m, rows of Op(a) + k as blas_index, // n, cols of Op(a) + cast_as(&alpha), // alpha + a.ptr.as_ptr() as *const _, // a + a_stride, // lda + x_ptr as *const _, // x + x_stride, + cast_as(&beta), // beta + y_ptr as *mut _, // y + y_stride, + ); + return; + } } } }; @@ -834,6 +833,8 @@ where } #[cfg(feature = "blas")] +#[derive(Copy, Clone)] +#[cfg_attr(test, derive(PartialEq, Eq, Debug))] enum MemoryOrder { C, @@ -841,29 +842,43 @@ enum MemoryOrder } #[cfg(feature = "blas")] -fn blas_row_major_2d(a: &ArrayBase) -> bool -where - S: Data, - A: 'static, - S::Elem: 'static, +impl MemoryOrder { - if !same_type::() { - return false; + #[inline] + /// Axis of leading stride (opposite of contiguous axis) + fn lead_axis(self) -> usize + { + match self { + MemoryOrder::C => 0, + MemoryOrder::F => 1, + } } - is_blas_2d(&a.dim, &a.strides, MemoryOrder::C) -} -#[cfg(feature = "blas")] -fn blas_column_major_2d(a: &ArrayBase) -> bool -where - S: Data, - A: 'static, - S::Elem: 'static, -{ - if !same_type::() { - return false; + /// Get opposite memory order + #[inline] + fn opposite(self) -> Self + { + match self { + MemoryOrder::C => MemoryOrder::F, + MemoryOrder::F => MemoryOrder::C, + } + } + + fn to_cblas_transpose(self) -> cblas_sys::CBLAS_TRANSPOSE + { + match self { + MemoryOrder::C => CblasNoTrans, + MemoryOrder::F => CblasTrans, + } + } + + fn to_cblas_layout(self) -> CBLAS_LAYOUT + { + match self { + MemoryOrder::C => CBLAS_LAYOUT::CblasRowMajor, + MemoryOrder::F => CBLAS_LAYOUT::CblasColMajor, + } } - is_blas_2d(&a.dim, &a.strides, MemoryOrder::F) } #[cfg(feature = "blas")] @@ -872,41 +887,102 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool let (m, n) = dim.into_pattern(); let s0 = stride[0] as isize; let s1 = stride[1] as isize; - let (inner_stride, outer_dim) = match order { - MemoryOrder::C => (s1, n), - MemoryOrder::F => (s0, m), + let (inner_stride, outer_stride, inner_dim, outer_dim) = match order { + MemoryOrder::C => (s1, s0, m, n), + MemoryOrder::F => (s0, s1, n, m), }; + if !(inner_stride == 1 || outer_dim == 1) { return false; } + if s0 < 1 || s1 < 1 { return false; } + if (s0 > blas_index::MAX as isize || s0 < blas_index::MIN as isize) || (s1 > blas_index::MAX as isize || s1 < blas_index::MIN as isize) { return false; } + + // leading stride must >= the dimension (no broadcasting/aliasing) + if inner_dim > 1 && (outer_stride as usize) < outer_dim { + return false; + } + if m > blas_index::MAX as usize || n > blas_index::MAX as usize { return false; } + true } +/// Get BLAS compatible layout if any (C or F, preferring the former) +#[cfg(feature = "blas")] +fn get_blas_compatible_layout(a: &ArrayBase) -> Option +where S: Data +{ + if is_blas_2d(&a.dim, &a.strides, MemoryOrder::C) { + Some(MemoryOrder::C) + } else if is_blas_2d(&a.dim, &a.strides, MemoryOrder::F) { + Some(MemoryOrder::F) + } else { + None + } +} + +/// `a` should be blas compatible. +/// axis: 0 or 1. +/// +/// Return leading stride (lda, ldb, ldc) of array #[cfg(feature = "blas")] -fn blas_layout(a: &ArrayBase) -> Option +fn blas_stride(a: &ArrayBase, axis: usize) -> blas_index +where S: Data +{ + debug_assert!(axis <= 1); + let other_axis = 1 - axis; + let len_this = a.shape()[axis]; + let len_other = a.shape()[other_axis]; + let stride = a.strides()[axis]; + + // if current axis has length == 1, then stride does not matter for ndarray + // but for BLAS we need a stride that makes sense, i.e. it's >= the other axis + + // cast: a should already be blas compatible + (if len_this <= 1 { + Ord::max(stride, len_other as isize) + } else { + stride + }) as blas_index +} + +#[cfg(test)] +#[cfg(feature = "blas")] +fn blas_row_major_2d(a: &ArrayBase) -> bool where S: Data, A: 'static, S::Elem: 'static, { - if blas_row_major_2d::(a) { - Some(CBLAS_LAYOUT::CblasRowMajor) - } else if blas_column_major_2d::(a) { - Some(CBLAS_LAYOUT::CblasColMajor) - } else { - None + if !same_type::() { + return false; + } + is_blas_2d(&a.dim, &a.strides, MemoryOrder::C) +} + +#[cfg(test)] +#[cfg(feature = "blas")] +fn blas_column_major_2d(a: &ArrayBase) -> bool +where + S: Data, + A: 'static, + S::Elem: 'static, +{ + if !same_type::() { + return false; } + is_blas_2d(&a.dim, &a.strides, MemoryOrder::F) } #[cfg(test)] @@ -964,4 +1040,64 @@ mod blas_tests assert!(!blas_row_major_2d::(&m)); assert!(blas_column_major_2d::(&m)); } + + #[test] + fn blas_row_major_2d_skip_rows_ok() + { + let m: Array2 = Array2::zeros((5, 5)); + let mv = m.slice(s![..;2, ..]); + assert!(blas_row_major_2d::(&mv)); + assert!(!blas_column_major_2d::(&mv)); + } + + #[test] + fn blas_row_major_2d_skip_columns_fail() + { + let m: Array2 = Array2::zeros((5, 5)); + let mv = m.slice(s![.., ..;2]); + assert!(!blas_row_major_2d::(&mv)); + assert!(!blas_column_major_2d::(&mv)); + } + + #[test] + fn blas_col_major_2d_skip_columns_ok() + { + let m: Array2 = Array2::zeros((5, 5).f()); + let mv = m.slice(s![.., ..;2]); + assert!(blas_column_major_2d::(&mv)); + assert!(!blas_row_major_2d::(&mv)); + } + + #[test] + fn blas_col_major_2d_skip_rows_fail() + { + let m: Array2 = Array2::zeros((5, 5).f()); + let mv = m.slice(s![..;2, ..]); + assert!(!blas_column_major_2d::(&mv)); + assert!(!blas_row_major_2d::(&mv)); + } + + #[test] + fn blas_too_short_stride() + { + // leading stride must be longer than the other dimension + // Example, in a 5 x 5 matrix, the leading stride must be >= 5 for BLAS. + + const N: usize = 5; + const MAXSTRIDE: usize = N + 2; + let mut data = [0; MAXSTRIDE * N]; + let mut iter = 0..data.len(); + data.fill_with(|| iter.next().unwrap()); + + for stride in 1..=MAXSTRIDE { + let m = ArrayView::from_shape((N, N).strides((stride, 1)), &data).unwrap(); + eprintln!("{:?}", m); + + if stride < N { + assert_eq!(get_blas_compatible_layout(&m), None); + } else { + assert_eq!(get_blas_compatible_layout(&m), Some(MemoryOrder::C)); + } + } + } } diff --git a/tests/oper.rs b/tests/oper.rs index 294a762c6..5e3e669d0 100644 --- a/tests/oper.rs +++ b/tests/oper.rs @@ -5,13 +5,17 @@ use ndarray::linalg::general_mat_mul; use ndarray::linalg::kron; use ndarray::prelude::*; +#[cfg(feature = "approx")] +use ndarray::Order; use ndarray::{rcarr1, rcarr2}; use ndarray::{Data, LinalgScalar}; use ndarray::{Ix, Ixs}; -use num_traits::Zero; +use ndarray_gen::array_builder::ArrayBuilder; use approx::assert_abs_diff_eq; use defmac::defmac; +use num_traits::Num; +use num_traits::Zero; fn test_oper(op: &str, a: &[f32], b: &[f32], c: &[f32]) { @@ -271,31 +275,20 @@ fn product() } } -fn range_mat(m: Ix, n: Ix) -> Array2 +fn range_mat(m: Ix, n: Ix) -> Array2 { - Array::linspace(0., (m * n) as f32 - 1., m * n) - .into_shape_with_order((m, n)) - .unwrap() -} - -fn range_mat64(m: Ix, n: Ix) -> Array2 -{ - Array::linspace(0., (m * n) as f64 - 1., m * n) - .into_shape_with_order((m, n)) - .unwrap() + ArrayBuilder::new((m, n)).build() } #[cfg(feature = "approx")] fn range1_mat64(m: Ix) -> Array1 { - Array::linspace(0., m as f64 - 1., m) + ArrayBuilder::new(m).build() } fn range_i32(m: Ix, n: Ix) -> Array2 { - Array::from_iter(0..(m * n) as i32) - .into_shape_with_order((m, n)) - .unwrap() + ArrayBuilder::new((m, n)).build() } // simple, slow, correct (hopefully) mat mul @@ -332,8 +325,8 @@ where fn mat_mul() { let (m, n, k) = (8, 8, 8); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut b = b / 4.; { let mut c = b.column_mut(0); @@ -351,8 +344,8 @@ fn mat_mul() assert_eq!(ab, af.dot(&bf)); let (m, n, k) = (10, 5, 11); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut b = b / 4.; { let mut c = b.column_mut(0); @@ -370,8 +363,8 @@ fn mat_mul() assert_eq!(ab, af.dot(&bf)); let (m, n, k) = (10, 8, 1); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut b = b / 4.; { let mut c = b.column_mut(0); @@ -395,8 +388,8 @@ fn mat_mul() fn mat_mul_order() { let (m, n, k) = (8, 8, 8); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut af = Array::zeros(a.dim().f()); let mut bf = Array::zeros(b.dim().f()); af.assign(&a); @@ -415,8 +408,8 @@ fn mat_mul_order() fn mat_mul_shape_mismatch() { let (m, k, k2, n) = (8, 8, 9, 8); - let a = range_mat(m, k); - let b = range_mat(k2, n); + let a = range_mat::(m, k); + let b = range_mat::(k2, n); a.dot(&b); } @@ -426,9 +419,9 @@ fn mat_mul_shape_mismatch() fn mat_mul_shape_mismatch_2() { let (m, k, k2, n) = (8, 8, 8, 8); - let a = range_mat(m, k); - let b = range_mat(k2, n); - let mut c = range_mat(m, n + 1); + let a = range_mat::(m, k); + let b = range_mat::(k2, n); + let mut c = range_mat::(m, n + 1); general_mat_mul(1., &a, &b, 1., &mut c); } @@ -438,7 +431,7 @@ fn mat_mul_shape_mismatch_2() fn mat_mul_broadcast() { let (m, n, k) = (16, 16, 16); - let a = range_mat(m, n); + let a = range_mat::(m, n); let x1 = 1.; let x = Array::from(vec![x1]); let b0 = x.broadcast((n, k)).unwrap(); @@ -458,8 +451,8 @@ fn mat_mul_broadcast() fn mat_mul_rev() { let (m, n, k) = (16, 16, 16); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut rev = Array::zeros(b.dim()); let mut rev = rev.slice_mut(s![..;-1, ..]); rev.assign(&b); @@ -488,8 +481,8 @@ fn mat_mut_zero_len() } } }); - mat_mul_zero_len!(range_mat); - mat_mul_zero_len!(range_mat64); + mat_mul_zero_len!(range_mat::); + mat_mul_zero_len!(range_mat::); mat_mul_zero_len!(range_i32); } @@ -528,9 +521,9 @@ fn scaled_add_2() for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, k, n, q) in &sizes { - let mut a = range_mat64(m, k); + let mut a = range_mat::(m, k); let mut answer = a.clone(); - let c = range_mat64(n, q); + let c = range_mat::(n, q); { let mut av = a.slice_mut(s![..;s1, ..;s2]); @@ -570,7 +563,7 @@ fn scaled_add_3() for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, k, n, q) in &sizes { - let mut a = range_mat64(m, k); + let mut a = range_mat::(m, k); let mut answer = a.clone(); let cdim = if n == 1 { vec![q] } else { vec![n, q] }; let cslice: Vec = if n == 1 { @@ -582,7 +575,7 @@ fn scaled_add_3() ] }; - let c = range_mat64(n, q).into_shape_with_order(cdim).unwrap(); + let c = range_mat::(n, q).into_shape_with_order(cdim).unwrap(); { let mut av = a.slice_mut(s![..;s1, ..;s2]); @@ -619,9 +612,9 @@ fn gen_mat_mul() for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, k, n) in &sizes { - let a = range_mat64(m, k); - let b = range_mat64(k, n); - let mut c = range_mat64(m, n); + let a = range_mat::(m, k); + let b = range_mat::(k, n); + let mut c = range_mat::(m, n); let mut answer = c.clone(); { @@ -645,11 +638,11 @@ fn gen_mat_mul() #[test] fn gemm_64_1_f() { - let a = range_mat64(64, 64).reversed_axes(); + let a = range_mat::(64, 64).reversed_axes(); let (m, n) = a.dim(); // m x n times n x 1 == m x 1 - let x = range_mat64(n, 1); - let mut y = range_mat64(m, 1); + let x = range_mat::(n, 1); + let mut y = range_mat::(m, 1); let answer = reference_mat_mul(&a, &x) + &y; general_mat_mul(1.0, &a, &x, 1.0, &mut y); approx::assert_relative_eq!(y, answer, epsilon = 1e-12, max_relative = 1e-7); @@ -728,11 +721,8 @@ fn gen_mat_vec_mul() for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, k) in &sizes { - for &rev in &[false, true] { - let mut a = range_mat64(m, k); - if rev { - a = a.reversed_axes(); - } + for order in [Order::C, Order::F] { + let a = ArrayBuilder::new((m, k)).memory_order(order).build(); let (m, k) = a.dim(); let b = range1_mat64(k); let mut c = range1_mat64(m); @@ -794,11 +784,8 @@ fn vec_mat_mul() for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, n) in &sizes { - for &rev in &[false, true] { - let mut b = range_mat64(m, n); - if rev { - b = b.reversed_axes(); - } + for order in [Order::C, Order::F] { + let b = ArrayBuilder::new((m, n)).memory_order(order).build(); let (m, n) = b.dim(); let a = range1_mat64(m); let mut c = range1_mat64(n);