diff --git a/changelog.rst b/changelog.rst
index c4aa4e3c..f9bb0346 100644
--- a/changelog.rst
+++ b/changelog.rst
@@ -2,6 +2,9 @@
 Changelog
 =========
 
+- Unreleased
+    - ``MulAcc`` is generalised to allow different output types from input
+
 - 0.10.0
     - support more scalar types for scalar/matrix multiplication
     - refactor the handling of ``CsMatBase``'s ``indptr`` member to be able to
diff --git a/src/mul_acc.rs b/src/mul_acc.rs
index 0e3a9f39..751fc260 100644
--- a/src/mul_acc.rs
+++ b/src/mul_acc.rs
@@ -12,16 +12,19 @@
 /// to provide the most performant implementation. For instance, we could have
 /// a default implementation for numeric types that are `Clone`, but it would
 /// make possibly unnecessary copies.
-pub trait MulAcc {
+pub trait MulAcc<A = Self, B = A> {
     /// Multiply and accumulate in this variable, formally `*self += a * b`.
-    fn mul_acc(&mut self, a: &Self, b: &Self);
+    fn mul_acc(&mut self, a: &A, b: &B);
 }
 
-impl<N> MulAcc for N
+/// Default for types which support `mul_add`
+impl<N, A, B> MulAcc<A, B> for N
 where
-    N: Copy + num_traits::MulAdd<Output = N>,
+    N: Copy,
+    B: Copy,
+    A: num_traits::MulAdd<B, N, Output = N> + Copy,
 {
-    fn mul_acc(&mut self, a: &Self, b: &Self) {
+    fn mul_acc(&mut self, a: &A, b: &B) {
         *self = a.mul_add(*b, *self);
     }
 }
@@ -38,4 +41,22 @@ mod tests {
         a.mul_acc(&b, &c);
         assert_eq!(a, 7.);
     }
+
+    #[derive(Debug, Copy, Clone, Default)]
+    struct Wrapped<T>(T);
+
+    impl MulAcc<Wrapped<i8>, Wrapped<i16>> for Wrapped<i32> {
+        fn mul_acc(&mut self, a: &Wrapped<i8>, b: &Wrapped<i16>) {
+            self.0 = self.0 + a.0 as i32 * b.0 as i32;
+        }
+    }
+
+    #[test]
+    fn mul_acc_mixed_param_sizes() {
+        let mut a = Wrapped::<i32>(0x40000007i32);
+        let b = Wrapped::<i8>(0x20i8);
+        let c = Wrapped::<i16>(0x3000i16);
+        a.mul_acc(&b, &c);
+        assert_eq!(a.0, 0x40060007i32);
+    }
 }
diff --git a/src/sparse/prod.rs b/src/sparse/prod.rs
index 74f04af4..1b52c6aa 100644
--- a/src/sparse/prod.rs
+++ b/src/sparse/prod.rs
@@ -10,19 +10,43 @@ use num_traits::Num;
 /// Compute the dot product of two sparse vectors, using binary search to find matching indices.
 ///
 /// Runs in O(MlogN) time, where M and N are the number of non-zero entries in each vector.
-pub fn csvec_dot_by_binary_search<N, I>(
-    vec1: CsVecViewI<N, I>,
-    vec2: CsVecViewI<N, I>,
+pub fn csvec_dot_by_binary_search<N, A, B, I>(
+    vec1: CsVecViewI<A, I>,
+    vec2: CsVecViewI<B, I>,
 ) -> N
 where
     I: SpIndex,
-    N: crate::MulAcc + num_traits::Zero,
+    N: crate::MulAcc<A, B> + num_traits::Zero,
 {
-    let (mut idx1, mut val1, mut idx2, mut val2) = if vec1.nnz() < vec2.nnz() {
-        (vec1.indices(), vec1.data(), vec2.indices(), vec2.data())
+    // Check vec1.nnz <= vec2.nnz so the shorter vector is walked linearly,
+    // swapping both the vectors and the closure operands otherwise so that
+    // `mul_acc` still receives its arguments in (A, B) order.
+    if vec1.nnz() > vec2.nnz() {
+        csvec_dot_by_binary_search_impl(vec2, vec1, |acc: &mut N, a, b| {
+            acc.mul_acc(b, a)
+        })
     } else {
-        (vec2.indices(), vec2.data(), vec1.indices(), vec1.data())
-    };
+        csvec_dot_by_binary_search_impl(vec1, vec2, |acc: &mut N, a, b| {
+            acc.mul_acc(a, b)
+        })
+    }
+}
+
+/// Inner routine of `csvec_dot_by_binary_search`; removes the need for a commutative `MulAcc`
+pub(crate) fn csvec_dot_by_binary_search_impl<N, A, B, F, I>(
+    vec1: CsVecViewI<A, I>,
+    vec2: CsVecViewI<B, I>,
+    mul_acc: F,
+) -> N
+where
+    F: Fn(&mut N, &A, &B),
+    I: SpIndex,
+    N: num_traits::Zero,
+{
+    assert!(vec1.nnz() <= vec2.nnz());
+    // vec1.nnz is smaller
+    let (mut idx1, mut val1, mut idx2, mut val2) =
+        (vec1.indices(), vec1.data(), vec2.indices(), vec2.data());
 
     let mut sum = N::zero();
     while !idx1.is_empty() && !idx2.is_empty() {
@@ -34,7 +58,7 @@ where
             Err(i) => (false, i),
         };
         if found {
-            sum.mul_acc(&val1[0], &val2[i]);
+            mul_acc(&mut sum, &val1[0], &val2[i]);
         }
         idx1 = &idx1[1..];
         val1 = &val1[1..];
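
Note on the `MulAcc` generalisation above: the accumulator (`Self`) can now differ from both input types `A` and `B`, and the defaults keep the old single-type bound `MulAcc` compiling. A minimal downstream sketch of what this enables, accumulating single-precision products into a double-precision total; the `Sample` and `Acc` newtypes are hypothetical and not part of sprs:

```rust
use sprs::MulAcc;

/// Hypothetical single-precision input wrapper.
#[derive(Debug, Clone, Copy)]
struct Sample(f32);

/// Hypothetical double-precision accumulator.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
struct Acc(f64);

// Accumulate f32 * f32 products in f64.
impl MulAcc<Sample, Sample> for Acc {
    fn mul_acc(&mut self, a: &Sample, b: &Sample) {
        // Widen before multiplying so the accumulation happens entirely in f64.
        self.0 += f64::from(a.0) * f64::from(b.0);
    }
}

fn main() {
    let mut acc = Acc::default();
    acc.mul_acc(&Sample(1.5), &Sample(2.0));
    acc.mul_acc(&Sample(0.25), &Sample(4.0));
    assert_eq!(acc, Acc(4.0));
}
```

As with the `Wrapped` test in the diff, using local newtypes is what keeps such an impl from overlapping with the blanket `MulAdd`-based implementation.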
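
For the `csvec_dot_by_binary_search` change, here is a rough usage sketch of the homogeneous case, which still goes through the blanket `MulAdd` impl. The `sprs::prod` re-export path is an assumption here; only the output type `N` needs an annotation, since it no longer has to equal the element types.

```rust
use sprs::CsVecI;
// Assumed path: the function is `pub` in `src/sparse/prod.rs`; this sketch
// assumes it stays reachable as `sprs::prod::csvec_dot_by_binary_search`.
use sprs::prod::csvec_dot_by_binary_search;

fn main() {
    // Indices must be sorted and are paired with their values.
    let v1 = CsVecI::new(8, vec![0_usize, 2, 5], vec![1.0_f64, 2.0, 3.0]);
    let v2 = CsVecI::new(8, vec![2_usize, 5, 7], vec![4.0_f64, 5.0, 6.0]);
    // The element types `A` and `B` come from the vector views; `N` comes
    // from the annotation. Overlap at indices 2 and 5: 2*4 + 3*5 = 23.
    let dot: f64 = csvec_dot_by_binary_search(v1.view(), v2.view());
    assert_eq!(dot, 23.0);
}
```

Mixed element types work the same way once a suitable `MulAcc<A, B>` impl, like the sketch above, is in scope for the chosen accumulator type.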