From eee5c5565ad8ced685fbf295b5356b4726483f26 Mon Sep 17 00:00:00 2001
From: glcoder0 <glcoder@mail.com>
Date: Sat, 24 Apr 2021 12:12:29 +0100
Subject: [PATCH 1/2] Generalise MulAcc trait

---
 src/mul_acc.rs     | 31 ++++++++++++++++++++++++++-----
 src/sparse/prod.rs | 42 +++++++++++++++++++++++++++++++++---------
 2 files changed, 59 insertions(+), 14 deletions(-)

diff --git a/src/mul_acc.rs b/src/mul_acc.rs
index 0e3a9f39..751fc260 100644
--- a/src/mul_acc.rs
+++ b/src/mul_acc.rs
@@ -12,16 +12,19 @@
 /// to provide the most performant implementation. For instance, we could have
 /// a default implementation for numeric types that are `Clone`, but it would
 /// make possibly unnecessary copies.
-pub trait MulAcc {
+pub trait MulAcc<A = Self, B = A> {
     /// Multiply and accumulate in this variable, formally `*self += a * b`.
-    fn mul_acc(&mut self, a: &Self, b: &Self);
+    fn mul_acc(&mut self, a: &A, b: &B);
 }
 
-impl<N> MulAcc for N
+/// Default for types which supports `mul_add`
+impl<N, A, B> MulAcc<A, B> for N
 where
-    N: Copy + num_traits::MulAdd<Output = N>,
+    N: Copy,
+    B: Copy,
+    A: num_traits::MulAdd<B, N, Output = N> + Copy,
 {
-    fn mul_acc(&mut self, a: &Self, b: &Self) {
+    fn mul_acc(&mut self, a: &A, b: &B) {
         *self = a.mul_add(*b, *self);
     }
 }
@@ -38,4 +41,22 @@ mod tests {
         a.mul_acc(&b, &c);
         assert_eq!(a, 7.);
     }
+
+    #[derive(Debug, Copy, Clone, Default)]
+    struct Wrapped<T: Default + Copy + std::fmt::Debug>(T);
+
+    impl MulAcc<Wrapped<i8>, Wrapped<i16>> for Wrapped<i32> {
+        fn mul_acc(&mut self, a: &Wrapped<i8>, b: &Wrapped<i16>) {
+            self.0 = self.0 + a.0 as i32 * b.0 as i32;
+        }
+    }
+
+    #[test]
+    fn mul_acc_mixed_param_sizes() {
+        let mut a = Wrapped::<i32>(0x40000007i32);
+        let b = Wrapped::<i8>(0x20i8);
+        let c = Wrapped::<i16>(0x3000i16);
+        a.mul_acc(&b, &c);
+        assert_eq!(a.0, 0x40060007i32);
+    }
 }
diff --git a/src/sparse/prod.rs b/src/sparse/prod.rs
index 74f04af4..1b52c6aa 100644
--- a/src/sparse/prod.rs
+++ b/src/sparse/prod.rs
@@ -10,19 +10,43 @@ use num_traits::Num;
 /// Compute the dot product of two sparse vectors, using binary search to find matching indices.
 ///
 /// Runs in O(MlogN) time, where M and N are the number of non-zero entries in each vector.
-pub fn csvec_dot_by_binary_search<N, I>(
-    vec1: CsVecViewI<N, I>,
-    vec2: CsVecViewI<N, I>,
+pub fn csvec_dot_by_binary_search<N, I, A, B>(
+    vec1: CsVecViewI<A, I>,
+    vec2: CsVecViewI<B, I>,
 ) -> N
 where
     I: SpIndex,
-    N: crate::MulAcc + num_traits::Zero,
+    N: crate::MulAcc<A, B> + num_traits::Zero,
 {
-    let (mut idx1, mut val1, mut idx2, mut val2) = if vec1.nnz() < vec2.nnz() {
-        (vec1.indices(), vec1.data(), vec2.indices(), vec2.data())
+    // Check vec1.nnz<vec2.nnz
+    // Reverse the dot product vec1 and vec2, but preserve possibly non-commutative MulAcc
+    // through a lamba.
+    if vec1.nnz() > vec2.nnz() {
+        csvec_dot_by_binary_search_impl(vec2, vec1, |acc: &mut N, a, b| {
+            acc.mul_acc(b, a)
+        })
     } else {
-        (vec2.indices(), vec2.data(), vec1.indices(), vec1.data())
-    };
+        csvec_dot_by_binary_search_impl(vec1, vec2, |acc: &mut N, a, b| {
+            acc.mul_acc(a, b)
+        })
+    }
+}
+
+/// Inner routine of `csvec_dot_by_binary_search`, removes need for commutative `MulAcc`
+pub(crate) fn csvec_dot_by_binary_search_impl<N, I, A, B, F>(
+    vec1: CsVecViewI<A, I>,
+    vec2: CsVecViewI<B, I>,
+    mul_acc: F,
+) -> N
+where
+    F: Fn(&mut N, &A, &B),
+    I: SpIndex,
+    N: num_traits::Zero,
+{
+    assert!(vec1.nnz() <= vec2.nnz());
+    // vec1.nnz is smaller
+    let (mut idx1, mut val1, mut idx2, mut val2) =
+        (vec1.indices(), vec1.data(), vec2.indices(), vec2.data());
 
     let mut sum = N::zero();
     while !idx1.is_empty() && !idx2.is_empty() {
@@ -34,7 +58,7 @@ where
             Err(i) => (false, i),
         };
         if found {
-            sum.mul_acc(&val1[0], &val2[i]);
+            mul_acc(&mut sum, &val1[0], &val2[i]);
         }
         idx1 = &idx1[1..];
         val1 = &val1[1..];

From 5a4e6e7a8ff419ac521f64d55e3d92bfae88ee50 Mon Sep 17 00:00:00 2001
From: Magnus Ulimoen <magnus@ulimoen.dev>
Date: Mon, 26 Apr 2021 20:19:35 +0200
Subject: [PATCH 2/2] Add to changelog

---
 changelog.rst | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/changelog.rst b/changelog.rst
index c4aa4e3c..f9bb0346 100644
--- a/changelog.rst
+++ b/changelog.rst
@@ -2,6 +2,9 @@
 Changelog
 =========
 
+- Unreleased
+  - ``MulAcc`` is generalised to allow different output types from input
+
 - 0.10.0
   - support more scalar types for scalar/matrix multiplication
   - refactor the handling of ``CsMatBase``'s ``indptr`` member to be able to