Skip to content

Commit d992a13

Browse files
committed
blas: Test that matrix multiply calls BLAS
Add a crate with a mock blas implementation, so that we can assert that cblas_sgemm etc are called (depending on memory layout).
1 parent a5d6eaf commit d992a13

File tree

5 files changed

+210
-2
lines changed

5 files changed

+210
-2
lines changed

Cargo.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ approx = { workspace = true, optional = true }
3535
rayon = { version = "1.10.0", optional = true }
3636

3737
# Use via the `blas` crate feature
38-
cblas-sys = { version = "0.1.4", optional = true, default-features = false }
38+
cblas-sys = { workspace = true, optional = true }
3939
libc = { version = "0.2.82", optional = true }
4040

4141
matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm"] }
@@ -90,7 +90,7 @@ default-members = [
9090
"crates/ndarray-gen",
9191
"crates/numeric-tests",
9292
"crates/serialization-tests",
93-
# exclude blas-tests that depends on BLAS install
93+
# exclude blas-tests and blas-mock-tests that activate "blas" feature
9494
]
9595

9696
[workspace.dependencies]
@@ -106,6 +106,7 @@ quickcheck = { version = "1.0", default-features = false }
106106
rand = { version = "0.8.0", features = ["small_rng"] }
107107
rand_distr = { version = "0.4.0" }
108108
itertools = { version = "0.13.0", default-features = false, features = ["use_std"] }
109+
cblas-sys = { version = "0.1.4", default-features = false }
109110

110111
[profile.bench]
111112
debug = true

crates/blas-mock-tests/Cargo.toml

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
[package]
2+
name = "blas-mock-tests"
3+
version = "0.1.0"
4+
edition = "2018"
5+
publish = false
6+
7+
[lib]
8+
test = false
9+
doc = false
10+
doctest = false
11+
12+
[dependencies]
13+
ndarray = { workspace = true, features = ["approx", "blas"] }
14+
ndarray-gen = { workspace = true }
15+
cblas-sys = { workspace = true }
16+
17+
[dev-dependencies]
18+
itertools = { workspace = true }

crates/blas-mock-tests/src/lib.rs

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
//! Mock interfaces to BLAS
2+
3+
use core::cell::RefCell;
4+
use core::ffi::{c_double, c_float, c_int};
5+
use std::thread_local;
6+
7+
use cblas_sys::{c_double_complex, c_float_complex, CBLAS_LAYOUT, CBLAS_TRANSPOSE};
8+
9+
thread_local! {
10+
/// This counter is incremented every time a gemm function is called
11+
pub static CALL_COUNT: RefCell<usize> = RefCell::new(0);
12+
}
13+
14+
#[rustfmt::skip]
15+
#[no_mangle]
16+
#[allow(unused)]
17+
pub unsafe extern "C" fn cblas_sgemm(
18+
layout: CBLAS_LAYOUT,
19+
transa: CBLAS_TRANSPOSE,
20+
transb: CBLAS_TRANSPOSE,
21+
m: c_int,
22+
n: c_int,
23+
k: c_int,
24+
alpha: c_float,
25+
a: *const c_float,
26+
lda: c_int,
27+
b: *const c_float,
28+
ldb: c_int,
29+
beta: c_float,
30+
c: *mut c_float,
31+
ldc: c_int
32+
) {
33+
CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1);
34+
}
35+
36+
#[rustfmt::skip]
37+
#[no_mangle]
38+
#[allow(unused)]
39+
pub unsafe extern "C" fn cblas_dgemm(
40+
layout: CBLAS_LAYOUT,
41+
transa: CBLAS_TRANSPOSE,
42+
transb: CBLAS_TRANSPOSE,
43+
m: c_int,
44+
n: c_int,
45+
k: c_int,
46+
alpha: c_double,
47+
a: *const c_double,
48+
lda: c_int,
49+
b: *const c_double,
50+
ldb: c_int,
51+
beta: c_double,
52+
c: *mut c_double,
53+
ldc: c_int
54+
) {
55+
CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1);
56+
}
57+
58+
#[rustfmt::skip]
59+
#[no_mangle]
60+
#[allow(unused)]
61+
pub unsafe extern "C" fn cblas_cgemm(
62+
layout: CBLAS_LAYOUT,
63+
transa: CBLAS_TRANSPOSE,
64+
transb: CBLAS_TRANSPOSE,
65+
m: c_int,
66+
n: c_int,
67+
k: c_int,
68+
alpha: *const c_float_complex,
69+
a: *const c_float_complex,
70+
lda: c_int,
71+
b: *const c_float_complex,
72+
ldb: c_int,
73+
beta: *const c_float_complex,
74+
c: *mut c_float_complex,
75+
ldc: c_int
76+
) {
77+
CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1);
78+
}
79+
80+
#[rustfmt::skip]
81+
#[no_mangle]
82+
#[allow(unused)]
83+
pub unsafe extern "C" fn cblas_zgemm(
84+
layout: CBLAS_LAYOUT,
85+
transa: CBLAS_TRANSPOSE,
86+
transb: CBLAS_TRANSPOSE,
87+
m: c_int,
88+
n: c_int,
89+
k: c_int,
90+
alpha: *const c_double_complex,
91+
a: *const c_double_complex,
92+
lda: c_int,
93+
b: *const c_double_complex,
94+
ldb: c_int,
95+
beta: *const c_double_complex,
96+
c: *mut c_double_complex,
97+
ldc: c_int
98+
) {
99+
CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1);
100+
}
+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
extern crate ndarray;
2+
3+
use ndarray::prelude::*;
4+
5+
use blas_mock_tests::CALL_COUNT;
6+
use ndarray::linalg::general_mat_mul;
7+
use ndarray::Order;
8+
use ndarray_gen::array_builder::ArrayBuilder;
9+
10+
use itertools::iproduct;
11+
12+
#[test]
13+
fn test_gen_mat_mul_uses_blas()
14+
{
15+
let alpha = 1.0;
16+
let beta = 0.0;
17+
18+
let sizes = vec![
19+
(8, 8, 8),
20+
(10, 10, 10),
21+
(8, 8, 1),
22+
(1, 10, 10),
23+
(10, 1, 10),
24+
(10, 10, 1),
25+
(1, 10, 1),
26+
(10, 1, 1),
27+
(1, 1, 10),
28+
(4, 17, 3),
29+
(17, 3, 22),
30+
(19, 18, 2),
31+
(16, 17, 15),
32+
(15, 16, 17),
33+
(67, 63, 62),
34+
];
35+
let strides = &[1, 2, -1, -2];
36+
let cf_order = [Order::C, Order::F];
37+
38+
// test different strides and memory orders
39+
for &(m, k, n) in &sizes {
40+
for (&s1, &s2) in iproduct!(strides, strides) {
41+
for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) {
42+
println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3);
43+
44+
let a = ArrayBuilder::new((m, k)).memory_order(ord1).build();
45+
let b = ArrayBuilder::new((k, n)).memory_order(ord2).build();
46+
let mut c = ArrayBuilder::new((m, n)).memory_order(ord3).build();
47+
48+
{
49+
let av;
50+
let bv;
51+
let mut cv;
52+
53+
if s1 != 1 || s2 != 1 {
54+
av = a.slice(s![..;s1, ..;s2]);
55+
bv = b.slice(s![..;s2, ..;s2]);
56+
cv = c.slice_mut(s![..;s1, ..;s2]);
57+
} else {
58+
// different stride cases for slicing versus not sliced (for axes of
59+
// len=1); so test not sliced here.
60+
av = a.view();
61+
bv = b.view();
62+
cv = c.view_mut();
63+
}
64+
65+
let pre_count = CALL_COUNT.with(|ctx| *ctx.borrow());
66+
general_mat_mul(alpha, &av, &bv, beta, &mut cv);
67+
let after_count = CALL_COUNT.with(|ctx| *ctx.borrow());
68+
let ncalls = after_count - pre_count;
69+
debug_assert!(ncalls <= 1);
70+
71+
let always_uses_blas = s1 == 1 && s2 == 1;
72+
73+
if always_uses_blas {
74+
assert_eq!(ncalls, 1, "Contiguous arrays should use blas, orders={:?}", (ord1, ord2, ord3));
75+
}
76+
77+
let should_use_blas = av.strides().iter().all(|&s| s > 0)
78+
&& bv.strides().iter().all(|&s| s > 0)
79+
&& cv.strides().iter().all(|&s| s > 0)
80+
&& av.strides().iter().any(|&s| s == 1)
81+
&& bv.strides().iter().any(|&s| s == 1)
82+
&& cv.strides().iter().any(|&s| s == 1);
83+
assert_eq!(should_use_blas, ncalls > 0);
84+
}
85+
}
86+
}
87+
}
88+
}

scripts/all-tests.sh

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ cargo test -v -p ndarray -p ndarray-rand --release --features "$FEATURES" $QC_FE
2020

2121
# BLAS tests
2222
cargo test -p ndarray --lib -v --features blas
23+
cargo test -p blas-mock-tests -v
2324
cargo test -p blas-tests -v --features blas-tests/openblas-system
2425
cargo test -p numeric-tests -v --features numeric-tests/test_blas
2526

0 commit comments

Comments
 (0)