fully generic MulAcc<A,B> - groundwork for Fully Generic Matrix/Vector element, Product and Accumulator Types - #284

Merged
merged 2 commits into from
Apr 26, 2021
3 changes: 3 additions & 0 deletions changelog.rst
@@ -2,6 +2,9 @@
 Changelog
 =========

+- Unreleased
+  - ``MulAcc`` is generalised to allow the output (accumulator) type to differ from the input types
+
 - 0.10.0
   - support more scalar types for scalar/matrix multiplication
   - refactor the handling of ``CsMatBase``'s ``indptr`` member to be able to
31 changes: 26 additions & 5 deletions src/mul_acc.rs
@@ -12,16 +12,19 @@
 /// to provide the most performant implementation. For instance, we could have
 /// a default implementation for numeric types that are `Clone`, but it would
 /// make possibly unnecessary copies.
-pub trait MulAcc {
+pub trait MulAcc<A = Self, B = A> {
     /// Multiply and accumulate in this variable, formally `*self += a * b`.
-    fn mul_acc(&mut self, a: &Self, b: &Self);
+    fn mul_acc(&mut self, a: &A, b: &B);
 }

-impl<N> MulAcc for N
+/// Default for types which support `mul_add`
+impl<N, A, B> MulAcc<A, B> for N
 where
-    N: Copy + num_traits::MulAdd<Output = N>,
+    N: Copy,
+    B: Copy,
+    A: num_traits::MulAdd<B, N, Output = N> + Copy,
 {
-    fn mul_acc(&mut self, a: &Self, b: &Self) {
+    fn mul_acc(&mut self, a: &A, b: &B) {
         *self = a.mul_add(*b, *self);
     }
 }
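
The default type parameters `A = Self, B = A` suggest that code written against the old, parameterless `MulAcc` bound keeps compiling unchanged. The following is a minimal sketch of that idea, not the crate's code: the trait is redefined locally so the snippet builds on its own, the implementation uses `f64::mul_add` from the standard library instead of `num_traits::MulAdd`, and `accumulate_pairs` is a hypothetical helper introduced only for illustration.

```rust
/// Local stand-in for the trait in the hunk above, redefined here so the
/// sketch compiles without the crate or `num_traits`.
pub trait MulAcc<A = Self, B = A> {
    /// Multiply and accumulate in this variable, formally `*self += a * b`.
    fn mul_acc(&mut self, a: &A, b: &B);
}

// Same-type case: `MulAcc` written without parameters resolves to
// `MulAcc<f64, f64>` because of the defaults `A = Self, B = A`.
impl MulAcc for f64 {
    fn mul_acc(&mut self, a: &f64, b: &f64) {
        // `f64::mul_add` computes `a * b + self` as a fused operation.
        *self = a.mul_add(*b, *self);
    }
}

/// Hypothetical helper (not part of the crate): it uses the parameterless
/// bound `N: MulAcc`, exactly as pre-existing generic code would.
pub fn accumulate_pairs<N: MulAcc + Copy>(init: N, pairs: &[(N, N)]) -> N {
    let mut acc = init;
    for (a, b) in pairs {
        // `a` and `b` are `&N` here, matching `mul_acc(&mut self, &A, &B)`.
        acc.mul_acc(a, b);
    }
    acc
}
```

Calling `accumulate_pairs(1.0_f64, &[(2.0, 3.0), (0.5, 4.0)])` accumulates `1 + 2*3 + 0.5*4`. The bound never mentions `A` or `B`, which is the backward-compatibility property the defaults appear to be chosen for.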
@@ -38,4 +41,22 @@ mod tests {
         a.mul_acc(&b, &c);
         assert_eq!(a, 7.);
     }
+
+    #[derive(Debug, Copy, Clone, Default)]
+    struct Wrapped<T: Default + Copy + std::fmt::Debug>(T);
+
+    impl MulAcc<Wrapped<i8>, Wrapped<i16>> for Wrapped<i32> {
+        fn mul_acc(&mut self, a: &Wrapped<i8>, b: &Wrapped<i16>) {
+            self.0 = self.0 + a.0 as i32 * b.0 as i32;
+        }
+    }
+
+    #[test]
+    fn mul_acc_mixed_param_sizes() {
+        let mut a = Wrapped::<i32>(0x40000007i32);
+        let b = Wrapped::<i8>(0x20i8);
+        let c = Wrapped::<i16>(0x3000i16);
+        a.mul_acc(&b, &c);
+        assert_eq!(a.0, 0x40060007i32);
+    }
 }
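
The test above multiplies an `i8` by an `i16` and accumulates into an `i32`; the product `0x20 * 0x3000 = 0x60000` would not fit in either input type. A rough sketch of the same idea on primitive integers follows. It assumes a locally redefined trait rather than the crate's, and `dense_dot` is a hypothetical helper, not an API of the crate.

```rust
/// Local stand-in for the generalised trait, redefined so the sketch is
/// self-contained.
pub trait MulAcc<A = Self, B = A> {
    /// Multiply and accumulate in this variable, formally `*self += a * b`.
    fn mul_acc(&mut self, a: &A, b: &B);
}

// Mixed-width case: the accumulator is wider than both operands.
impl MulAcc<i8, i16> for i32 {
    fn mul_acc(&mut self, a: &i8, b: &i16) {
        // Widen both operands before multiplying, so the product is computed
        // in `i32` and cannot overflow the narrower input types.
        *self += i32::from(*a) * i32::from(*b);
    }
}

/// Hypothetical dense dot product (illustration only): the element types
/// `A` and `B` and the accumulator type `N` are all independent.
pub fn dense_dot<N, A, B>(init: N, lhs: &[A], rhs: &[B]) -> N
where
    N: MulAcc<A, B>,
{
    let mut acc = init;
    // `zip` stops at the shorter slice; extra elements are ignored.
    for (a, b) in lhs.iter().zip(rhs) {
        acc.mul_acc(a, b);
    }
    acc
}
```

With `dense_dot(0i32, &[0x20i8, -1], &[0x3000i16, 2])` the first product alone exceeds `i16::MAX`, yet the sum is exact because every step runs in `i32`. The accumulator itself can still overflow on long inputs; this sketch does not guard against that.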
42 changes: 33 additions & 9 deletions src/sparse/prod.rs
@@ -10,19 +10,43 @@ use num_traits::Num;
 /// Compute the dot product of two sparse vectors, using binary search to find matching indices.
 ///
 /// Runs in O(MlogN) time, where M and N are the number of non-zero entries in each vector.
-pub fn csvec_dot_by_binary_search<N, I>(
-    vec1: CsVecViewI<N, I>,
-    vec2: CsVecViewI<N, I>,
+pub fn csvec_dot_by_binary_search<N, I, A, B>(
+    vec1: CsVecViewI<A, I>,
+    vec2: CsVecViewI<B, I>,
 ) -> N
 where
     I: SpIndex,
-    N: crate::MulAcc + num_traits::Zero,
+    N: crate::MulAcc<A, B> + num_traits::Zero,
 {
-    let (mut idx1, mut val1, mut idx2, mut val2) = if vec1.nnz() < vec2.nnz() {
-        (vec1.indices(), vec1.data(), vec2.indices(), vec2.data())
+    // The inner routine requires vec1.nnz() <= vec2.nnz(). If that does not hold,
+    // swap vec1 and vec2, and restore the operand order through a lambda so that
+    // a possibly non-commutative MulAcc is preserved.
+    if vec1.nnz() > vec2.nnz() {
+        csvec_dot_by_binary_search_impl(vec2, vec1, |acc: &mut N, a, b| {
+            acc.mul_acc(b, a)
+        })
     } else {
-        (vec2.indices(), vec2.data(), vec1.indices(), vec1.data())
-    };
+        csvec_dot_by_binary_search_impl(vec1, vec2, |acc: &mut N, a, b| {
+            acc.mul_acc(a, b)
+        })
+    }
+}
+
+/// Inner routine of `csvec_dot_by_binary_search`; taking `mul_acc` as a closure removes the need for a commutative `MulAcc`
+pub(crate) fn csvec_dot_by_binary_search_impl<N, I, A, B, F>(
+    vec1: CsVecViewI<A, I>,
+    vec2: CsVecViewI<B, I>,
+    mul_acc: F,
+) -> N
+where
+    F: Fn(&mut N, &A, &B),
+    I: SpIndex,
+    N: num_traits::Zero,
+{
+    assert!(vec1.nnz() <= vec2.nnz());
+    // vec1.nnz is smaller
+    let (mut idx1, mut val1, mut idx2, mut val2) =
+        (vec1.indices(), vec1.data(), vec2.indices(), vec2.data());

     let mut sum = N::zero();
     while !idx1.is_empty() && !idx2.is_empty() {
@@ -34,7 +58,7 @@ where
             Err(i) => (false, i),
         };
         if found {
-            sum.mul_acc(&val1[0], &val2[i]);
+            mul_acc(&mut sum, &val1[0], &val2[i]);
         }
         idx1 = &idx1[1..];
         val1 = &val1[1..];
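
The swap-plus-closure pattern in `src/sparse/prod.rs` can be sketched on plain sorted slices. This is an approximation under stated assumptions, not the crate's implementation: `sparse_dot` and `sparse_dot_impl` are hypothetical names, indices are plain `usize` slices instead of `CsVecViewI`, the zero value is passed in explicitly instead of using `num_traits::Zero`, and the tail of the loop (advancing `idx2` to position `i`) is a guess at the part the diff does not show.

```rust
/// Local stand-in for the generalised trait.
pub trait MulAcc<A = Self, B = A> {
    fn mul_acc(&mut self, a: &A, b: &B);
}

// Only this operand order is implemented; `MulAcc<i16, i8>` does not exist.
impl MulAcc<i8, i16> for i32 {
    fn mul_acc(&mut self, a: &i8, b: &i16) {
        *self += i32::from(*a) * i32::from(*b);
    }
}

/// Hypothetical inner routine: walks the shorter index list and
/// binary-searches each index in the longer one.
fn sparse_dot_impl<N, A, B, F>(
    mut idx1: &[usize],
    mut val1: &[A],
    mut idx2: &[usize],
    mut val2: &[B],
    zero: N,
    mul_acc: F,
) -> N
where
    F: Fn(&mut N, &A, &B),
{
    assert!(idx1.len() <= idx2.len());
    let mut sum = zero;
    while !idx1.is_empty() && !idx2.is_empty() {
        let (found, i) = match idx2.binary_search(&idx1[0]) {
            Ok(i) => (true, i),
            Err(i) => (false, i),
        };
        if found {
            mul_acc(&mut sum, &val1[0], &val2[i]);
        }
        idx1 = &idx1[1..];
        val1 = &val1[1..];
        // Assumption: entries before `i` are smaller than every remaining
        // index in `idx1`, so they can be dropped from the search range.
        idx2 = &idx2[i..];
        val2 = &val2[i..];
    }
    sum
}

/// Hypothetical outer routine: swaps the vectors when the first is longer,
/// and swaps the closure arguments back so `mul_acc` always sees `(A, B)`.
pub fn sparse_dot<N, A, B>(
    idx_a: &[usize],
    val_a: &[A],
    idx_b: &[usize],
    val_b: &[B],
    zero: N,
) -> N
where
    N: MulAcc<A, B>,
{
    if idx_a.len() > idx_b.len() {
        sparse_dot_impl(idx_b, val_b, idx_a, val_a, zero, |acc: &mut N, b, a| {
            acc.mul_acc(a, b)
        })
    } else {
        sparse_dot_impl(idx_a, val_a, idx_b, val_b, zero, |acc: &mut N, a, b| {
            acc.mul_acc(a, b)
        })
    }
}
```

Because `i32` implements only `MulAcc<i8, i16>`, the swapped branch would not type-check if the inner routine called `mul_acc` directly on the swapped operands. Passing the closure keeps the inner routine agnostic of operand order, which seems to be the motivation for the split in this change.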