Skip to content

Commit 33e2a58

Browse files
authored
Merge pull request #1421 from rust-ndarray/blas-simplify
Refactor and simplify BLAS gemm call further
2 parents 1df6c32 + 876ad01 commit 33e2a58

File tree

7 files changed

+127
-156
lines changed

7 files changed

+127
-156
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ keywords = ["array", "data-structure", "multidimensional", "matrix", "blas"]
2020
categories = ["data-structures", "science"]
2121

2222
exclude = ["docgen/images/*"]
23+
resolver = "2"
2324

2425
[lib]
2526
name = "ndarray"

crates/blas-mock-tests/Cargo.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ doc = false
1010
doctest = false
1111

1212
[dependencies]
13-
ndarray = { workspace = true, features = ["approx", "blas"] }
14-
ndarray-gen = { workspace = true }
1513
cblas-sys = { workspace = true }
1614

1715
[dev-dependencies]
16+
ndarray = { workspace = true, features = ["approx", "blas"] }
17+
ndarray-gen = { workspace = true }
1818
itertools = { workspace = true }

crates/blas-tests/tests/oper.rs

+10-6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use ndarray::linalg::general_mat_vec_mul;
1212
use ndarray::Order;
1313
use ndarray::{Data, Ix, LinalgScalar};
1414
use ndarray_gen::array_builder::ArrayBuilder;
15+
use ndarray_gen::array_builder::ElementGenerator;
1516

1617
use approx::assert_relative_eq;
1718
use defmac::defmac;
@@ -230,7 +231,6 @@ fn gen_mat_mul()
230231
let sizes = vec![
231232
(4, 4, 4),
232233
(8, 8, 8),
233-
(10, 10, 10),
234234
(8, 8, 1),
235235
(1, 10, 10),
236236
(10, 1, 10),
@@ -241,19 +241,23 @@ fn gen_mat_mul()
241241
(4, 17, 3),
242242
(17, 3, 22),
243243
(19, 18, 2),
244-
(16, 17, 15),
245244
(15, 16, 17),
246-
(67, 63, 62),
245+
(67, 50, 62),
247246
];
248247
let strides = &[1, 2, -1, -2];
249248
let cf_order = [Order::C, Order::F];
249+
let generator = [ElementGenerator::Sequential, ElementGenerator::Checkerboard];
250250

251251
// test different strides and memory orders
252-
for (&s1, &s2) in iproduct!(strides, strides) {
252+
for (&s1, &s2, &gen) in iproduct!(strides, strides, &generator) {
253253
for &(m, k, n) in &sizes {
254254
for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) {
255-
println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3);
256-
let a = ArrayBuilder::new((m, k)).memory_order(ord1).build() * 0.5;
255+
println!("Case s1={}, s2={}, gen={:?}, orders={:?}, {:?}, {:?}", s1, s2, gen, ord1, ord2, ord3);
256+
let a = ArrayBuilder::new((m, k))
257+
.memory_order(ord1)
258+
.generator(gen)
259+
.build()
260+
* 0.5;
257261
let b = ArrayBuilder::new((k, n)).memory_order(ord2).build();
258262
let mut c = ArrayBuilder::new((m, n)).memory_order(ord3).build();
259263

crates/ndarray-gen/src/array_builder.rs

+8-9
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pub struct ArrayBuilder<D: Dimension>
2626
pub enum ElementGenerator
2727
{
2828
Sequential,
29+
Checkerboard,
2930
Zero,
3031
}
3132

@@ -64,16 +65,14 @@ where D: Dimension
6465
pub fn build<T>(self) -> Array<T, D>
6566
where T: Num + Clone
6667
{
67-
let mut current = T::zero();
68+
let zero = T::zero();
6869
let size = self.dim.size();
69-
let use_zeros = self.generator == ElementGenerator::Zero;
70-
Array::from_iter((0..size).map(|_| {
71-
let ret = current.clone();
72-
if !use_zeros {
73-
current = ret.clone() + T::one();
74-
}
75-
ret
76-
}))
70+
(match self.generator {
71+
ElementGenerator::Sequential =>
72+
Array::from_iter(core::iter::successors(Some(zero), |elt| Some(elt.clone() + T::one())).take(size)),
73+
ElementGenerator::Checkerboard => Array::from_iter([T::one(), zero].iter().cycle().take(size).cloned()),
74+
ElementGenerator::Zero => Array::zeros(size),
75+
})
7776
.into_shape_with_order((self.dim, self.memory_order))
7877
.unwrap()
7978
}

scripts/cross-tests.sh

+1
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ QC_FEAT=--features=ndarray-rand/quickcheck
1111

1212
cross build -v --features="$FEATURES" $QC_FEAT --target=$TARGET
1313
cross test -v --no-fail-fast --features="$FEATURES" $QC_FEAT --target=$TARGET
14+
cross test -v -p blas-mock-tests

scripts/makechangelog.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# Will produce some duplicates for PRs integrated using rebase,
99
# but those will not occur with current merge queue.
1010

11-
git log --first-parent --pretty="format:%H" "$@" | while read commit_sha
11+
git log --first-parent --pretty="tformat:%H" "$@" | while IFS= read -r commit_sha
1212
do
1313
gh api "/repos/:owner/:repo/commits/${commit_sha}/pulls" \
1414
-q ".[] | \"- \(.title) by [@\(.user.login)](\(.user.html_url)) [#\(.number)](\(.html_url))\""

0 commit comments

Comments
 (0)