From eee5c5565ad8ced685fbf295b5356b4726483f26 Mon Sep 17 00:00:00 2001 From: glcoder0 <glcoder@mail.com> Date: Sat, 24 Apr 2021 12:12:29 +0100 Subject: [PATCH 1/2] Generalise MulAcc trait --- src/mul_acc.rs | 31 ++++++++++++++++++++++++++----- src/sparse/prod.rs | 42 +++++++++++++++++++++++++++++++++--------- 2 files changed, 59 insertions(+), 14 deletions(-) diff --git a/src/mul_acc.rs b/src/mul_acc.rs index 0e3a9f39..751fc260 100644 --- a/src/mul_acc.rs +++ b/src/mul_acc.rs @@ -12,16 +12,19 @@ /// to provide the most performant implementation. For instance, we could have /// a default implementation for numeric types that are `Clone`, but it would /// make possibly unnecessary copies. -pub trait MulAcc { +pub trait MulAcc<A = Self, B = A> { /// Multiply and accumulate in this variable, formally `*self += a * b`. - fn mul_acc(&mut self, a: &Self, b: &Self); + fn mul_acc(&mut self, a: &A, b: &B); } -impl<N> MulAcc for N +/// Default for types which support `mul_add` +impl<N, A, B> MulAcc<A, B> for N where - N: Copy + num_traits::MulAdd<Output = N>, + N: Copy, + B: Copy, + A: num_traits::MulAdd<B, N, Output = N> + Copy, { - fn mul_acc(&mut self, a: &Self, b: &Self) { + fn mul_acc(&mut self, a: &A, b: &B) { *self = a.mul_add(*b, *self); } } @@ -38,4 +41,22 @@ mod tests { a.mul_acc(&b, &c); assert_eq!(a, 7.); } + + #[derive(Debug, Copy, Clone, Default)] + struct Wrapped<T: Default + Copy + std::fmt::Debug>(T); + + impl MulAcc<Wrapped<i8>, Wrapped<i16>> for Wrapped<i32> { + fn mul_acc(&mut self, a: &Wrapped<i8>, b: &Wrapped<i16>) { + self.0 = self.0 + a.0 as i32 * b.0 as i32; + } + } + + #[test] + fn mul_acc_mixed_param_sizes() { + let mut a = Wrapped::<i32>(0x40000007i32); + let b = Wrapped::<i8>(0x20i8); + let c = Wrapped::<i16>(0x3000i16); + a.mul_acc(&b, &c); + assert_eq!(a.0, 0x40060007i32); + } } diff --git a/src/sparse/prod.rs b/src/sparse/prod.rs index 74f04af4..1b52c6aa 100644 --- a/src/sparse/prod.rs +++ b/src/sparse/prod.rs @@ -10,19 +10,43 @@ use 
num_traits::Num; /// Compute the dot product of two sparse vectors, using binary search to find matching indices. /// /// Runs in O(MlogN) time, where M and N are the number of non-zero entries in each vector. -pub fn csvec_dot_by_binary_search<N, I>( - vec1: CsVecViewI<N, I>, - vec2: CsVecViewI<N, I>, +pub fn csvec_dot_by_binary_search<N, I, A, B>( + vec1: CsVecViewI<A, I>, + vec2: CsVecViewI<B, I>, ) -> N where I: SpIndex, - N: crate::MulAcc + num_traits::Zero, + N: crate::MulAcc<A, B> + num_traits::Zero, { - let (mut idx1, mut val1, mut idx2, mut val2) = if vec1.nnz() < vec2.nnz() { - (vec1.indices(), vec1.data(), vec2.indices(), vec2.data()) + // Check vec1.nnz<vec2.nnz + // Reverse the dot product vec1 and vec2, but preserve possibly non-commutative MulAcc + // through a lambda. + if vec1.nnz() > vec2.nnz() { + csvec_dot_by_binary_search_impl(vec2, vec1, |acc: &mut N, a, b| { + acc.mul_acc(b, a) + }) } else { - (vec2.indices(), vec2.data(), vec1.indices(), vec1.data()) - }; + csvec_dot_by_binary_search_impl(vec1, vec2, |acc: &mut N, a, b| { + acc.mul_acc(a, b) + }) + } +} + +/// Inner routine of `csvec_dot_by_binary_search`, removes need for commutative `MulAcc` +pub(crate) fn csvec_dot_by_binary_search_impl<N, I, A, B, F>( + vec1: CsVecViewI<A, I>, + vec2: CsVecViewI<B, I>, + mul_acc: F, +) -> N +where + F: Fn(&mut N, &A, &B), + I: SpIndex, + N: num_traits::Zero, +{ + assert!(vec1.nnz() <= vec2.nnz()); + // vec1.nnz is smaller + let (mut idx1, mut val1, mut idx2, mut val2) = + (vec1.indices(), vec1.data(), vec2.indices(), vec2.data()); let mut sum = N::zero(); while !idx1.is_empty() && !idx2.is_empty() { @@ -34,7 +58,7 @@ where Err(i) => (false, i), }; if found { - sum.mul_acc(&val1[0], &val2[i]); + mul_acc(&mut sum, &val1[0], &val2[i]); } idx1 = &idx1[1..]; val1 = &val1[1..]; From 5a4e6e7a8ff419ac521f64d55e3d92bfae88ee50 Mon Sep 17 00:00:00 2001 From: Magnus Ulimoen <magnus@ulimoen.dev> Date: Mon, 26 Apr 2021 20:19:35 +0200 Subject: [PATCH 2/2] Add to 
changelog --- changelog.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/changelog.rst b/changelog.rst index c4aa4e3c..f9bb0346 100644 --- a/changelog.rst +++ b/changelog.rst @@ -2,6 +2,9 @@ Changelog ========= +- Unreleased + - ``MulAcc`` is generalised to allow different output types from input + - 0.10.0 - support more scalar types for scalar/matrix multiplication - refactor the handling of ``CsMatBase``'s ``indptr`` member to be able to