Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix using BLAS for all compatible cases of memory layout #1419

Merged
merged 8 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ approx = { workspace = true, optional = true }
rayon = { version = "1.10.0", optional = true }

# Use via the `blas` crate feature
cblas-sys = { version = "0.1.4", optional = true, default-features = false }
cblas-sys = { workspace = true, optional = true }
libc = { version = "0.2.82", optional = true }

matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm"] }
Expand All @@ -47,7 +47,8 @@ rawpointer = { version = "0.2" }
defmac = "0.2"
quickcheck = { workspace = true }
approx = { workspace = true, default-features = true }
itertools = { version = "0.13.0", default-features = false, features = ["use_std"] }
itertools = { workspace = true }
ndarray-gen = { workspace = true }

[features]
default = ["std"]
Expand All @@ -73,6 +74,7 @@ matrixmultiply-threading = ["matrixmultiply/threading"]

portable-atomic-critical-section = ["portable-atomic/critical-section"]


[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
portable-atomic = { version = "1.6.0" }
portable-atomic-util = { version = "0.2.0", features = [ "alloc" ] }
Expand All @@ -85,14 +87,16 @@ members = [
default-members = [
".",
"ndarray-rand",
"crates/ndarray-gen",
"crates/numeric-tests",
"crates/serialization-tests",
# exclude blas-tests that depends on BLAS install
# exclude blas-tests and blas-mock-tests that activate "blas" feature
]

[workspace.dependencies]
ndarray = { version = "0.16", path = "." }
ndarray = { version = "0.16", path = ".", default-features = false }
ndarray-rand = { path = "ndarray-rand" }
ndarray-gen = { path = "crates/ndarray-gen" }

num-integer = { version = "0.1.39", default-features = false }
num-traits = { version = "0.2", default-features = false }
Expand All @@ -101,6 +105,8 @@ approx = { version = "0.5", default-features = false }
quickcheck = { version = "1.0", default-features = false }
rand = { version = "0.8.0", features = ["small_rng"] }
rand_distr = { version = "0.4.0" }
itertools = { version = "0.13.0", default-features = false, features = ["use_std"] }
cblas-sys = { version = "0.1.4", default-features = false }

[profile.bench]
debug = true
Expand Down
18 changes: 18 additions & 0 deletions crates/blas-mock-tests/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[package]
name = "blas-mock-tests"
version = "0.1.0"
edition = "2018"
publish = false

[lib]
test = false
doc = false
doctest = false

[dependencies]
ndarray = { workspace = true, features = ["approx", "blas"] }
ndarray-gen = { workspace = true }
cblas-sys = { workspace = true }

[dev-dependencies]
itertools = { workspace = true }
100 changes: 100 additions & 0 deletions crates/blas-mock-tests/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
//! Mock interfaces to BLAS

use core::cell::RefCell;
use core::ffi::{c_double, c_float, c_int};
use std::thread_local;

use cblas_sys::{c_double_complex, c_float_complex, CBLAS_LAYOUT, CBLAS_TRANSPOSE};

thread_local! {
/// This counter is incremented every time a gemm function is called
pub static CALL_COUNT: RefCell<usize> = RefCell::new(0);
}

#[rustfmt::skip]
#[no_mangle]
#[allow(unused)]
pub unsafe extern "C" fn cblas_sgemm(
layout: CBLAS_LAYOUT,
transa: CBLAS_TRANSPOSE,
transb: CBLAS_TRANSPOSE,
m: c_int,
n: c_int,
k: c_int,
alpha: c_float,
a: *const c_float,
lda: c_int,
b: *const c_float,
ldb: c_int,
beta: c_float,
c: *mut c_float,
ldc: c_int
) {
CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1);
}

#[rustfmt::skip]
#[no_mangle]
#[allow(unused)]
pub unsafe extern "C" fn cblas_dgemm(
layout: CBLAS_LAYOUT,
transa: CBLAS_TRANSPOSE,
transb: CBLAS_TRANSPOSE,
m: c_int,
n: c_int,
k: c_int,
alpha: c_double,
a: *const c_double,
lda: c_int,
b: *const c_double,
ldb: c_int,
beta: c_double,
c: *mut c_double,
ldc: c_int
) {
CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1);
}

#[rustfmt::skip]
#[no_mangle]
#[allow(unused)]
pub unsafe extern "C" fn cblas_cgemm(
layout: CBLAS_LAYOUT,
transa: CBLAS_TRANSPOSE,
transb: CBLAS_TRANSPOSE,
m: c_int,
n: c_int,
k: c_int,
alpha: *const c_float_complex,
a: *const c_float_complex,
lda: c_int,
b: *const c_float_complex,
ldb: c_int,
beta: *const c_float_complex,
c: *mut c_float_complex,
ldc: c_int
) {
CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1);
}

#[rustfmt::skip]
#[no_mangle]
#[allow(unused)]
pub unsafe extern "C" fn cblas_zgemm(
layout: CBLAS_LAYOUT,
transa: CBLAS_TRANSPOSE,
transb: CBLAS_TRANSPOSE,
m: c_int,
n: c_int,
k: c_int,
alpha: *const c_double_complex,
a: *const c_double_complex,
lda: c_int,
b: *const c_double_complex,
ldb: c_int,
beta: *const c_double_complex,
c: *mut c_double_complex,
ldc: c_int
) {
CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1);
}
88 changes: 88 additions & 0 deletions crates/blas-mock-tests/tests/use-blas.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
extern crate ndarray;

use ndarray::prelude::*;

use blas_mock_tests::CALL_COUNT;
use ndarray::linalg::general_mat_mul;
use ndarray::Order;
use ndarray_gen::array_builder::ArrayBuilder;

use itertools::iproduct;

#[test]
fn test_gen_mat_mul_uses_blas()
{
let alpha = 1.0;
let beta = 0.0;

let sizes = vec![
(8, 8, 8),
(10, 10, 10),
(8, 8, 1),
(1, 10, 10),
(10, 1, 10),
(10, 10, 1),
(1, 10, 1),
(10, 1, 1),
(1, 1, 10),
(4, 17, 3),
(17, 3, 22),
(19, 18, 2),
(16, 17, 15),
(15, 16, 17),
(67, 63, 62),
];
let strides = &[1, 2, -1, -2];
let cf_order = [Order::C, Order::F];

// test different strides and memory orders
for &(m, k, n) in &sizes {
for (&s1, &s2) in iproduct!(strides, strides) {
for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) {
println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3);

let a = ArrayBuilder::new((m, k)).memory_order(ord1).build();
let b = ArrayBuilder::new((k, n)).memory_order(ord2).build();
let mut c = ArrayBuilder::new((m, n)).memory_order(ord3).build();

{
let av;
let bv;
let mut cv;

if s1 != 1 || s2 != 1 {
av = a.slice(s![..;s1, ..;s2]);
bv = b.slice(s![..;s2, ..;s2]);
cv = c.slice_mut(s![..;s1, ..;s2]);
} else {
// different stride cases for slicing versus not sliced (for axes of
// len=1); so test not sliced here.
av = a.view();
bv = b.view();
cv = c.view_mut();
}

let pre_count = CALL_COUNT.with(|ctx| *ctx.borrow());
general_mat_mul(alpha, &av, &bv, beta, &mut cv);
let after_count = CALL_COUNT.with(|ctx| *ctx.borrow());
let ncalls = after_count - pre_count;
debug_assert!(ncalls <= 1);

let always_uses_blas = s1 == 1 && s2 == 1;

if always_uses_blas {
assert_eq!(ncalls, 1, "Contiguous arrays should use blas, orders={:?}", (ord1, ord2, ord3));
}

let should_use_blas = av.strides().iter().all(|&s| s > 0)
&& bv.strides().iter().all(|&s| s > 0)
&& cv.strides().iter().all(|&s| s > 0)
&& av.strides().iter().any(|&s| s == 1)
&& bv.strides().iter().any(|&s| s == 1)
&& cv.strides().iter().any(|&s| s == 1);
assert_eq!(should_use_blas, ncalls > 0);
}
}
}
}
}
4 changes: 3 additions & 1 deletion crates/blas-tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ doc = false
doctest = false

[dependencies]
ndarray = { workspace = true, features = ["approx"] }
ndarray = { workspace = true, features = ["approx", "blas"] }
ndarray-gen = { workspace = true }

blas-src = { version = "0.10", optional = true }
openblas-src = { version = "0.10", optional = true }
Expand All @@ -23,6 +24,7 @@ defmac = "0.2"
approx = { workspace = true }
num-traits = { workspace = true }
num-complex = { workspace = true }
itertools = { workspace = true }

[features]
# Just for making an example and to help testing, , multiple different possible
Expand Down
Loading