Skip to content

Commit

Permalink
feat: allow more SIMD in array operations
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Oct 24, 2023
1 parent 5b87e48 commit 2565230
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 69 deletions.
1 change: 1 addition & 0 deletions proptest-regressions/math.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
cc ea2a2598ee637946e47a8a744d25f76239671b82c5971a56b14ce5ee06838cb4 # shrinks to x = 4.8329699435311735, y = 9.38911339170414
cc cf16a8d08e8ee8f7f3d3cfd60840e136ac51d130dffcd42db1a9a68d7e51f394 # shrinks to (x, y) = ([2.9394791070664547e110, 0.0], [inf, 0.0]), a = -2.4153502104628106e222
cc 28897b64919482133f3885c3de51da0895409d23c9dd503a7b51a3e949bda307 # shrinks to (x1, x2, x3, y1, y2) = ([0.0], [0.0], [-4.0946726283401733e139], [0.0], [1.3157422010991668e73])
cc acf6caef8a89a75ddab31ec3e391850723a625084df032aec2b650c2f95ba1fb # shrinks to (x, y) = ([0.0, 0.0, 0.0, 1.2271235629394547e205, 0.0, 0.0, -0.0, 0.0], [0.0, 0.0, 0.0, 7.121658452243713e81, 0.0, 0.0, 0.0, 0.0]), a = -6.261465657118442e-124
90 changes: 49 additions & 41 deletions src/cpu_math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,12 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
}

fn array_all_finite_and_nonzero(&mut self, array: &Self::Array) -> bool {
array
.col_ref(0)
.iter()
.all(|&x| x.is_finite() & (x != 0f64))
self.arch.dispatch(|| {
array
.col_ref(0)
.iter()
.all(|&x| x.is_finite() & (x != 0f64))
})
}

fn array_mult(&mut self, array1: &Self::Array, array2: &Self::Array, dest: &mut Self::Array) {
Expand Down Expand Up @@ -152,16 +154,18 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
value: &Self::Array,
diff_scale: f64, // 1 / self.count
) {
izip!(
mean.col_mut(0).iter_mut(),
variance.col_mut(0).iter_mut(),
value.col_ref(0)
)
.for_each(|(mean, mut var, x)| {
let diff = x - *mean;
*mean += diff * diff_scale;
*var += diff * diff;
});
self.arch.dispatch(|| {
izip!(
mean.col_mut(0).iter_mut(),
variance.col_mut(0).iter_mut(),
value.col_ref(0)
)
.for_each(|(mean, var, x)| {
let diff = x - *mean;
*mean += diff * diff_scale;
*var += diff * diff;
});
})
}

fn array_update_var_inv_std_draw_grad(
Expand All @@ -173,24 +177,26 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
fill_invalid: Option<f64>,
clamp: (f64, f64),
) {
izip!(
variance_out.col_mut(0).iter_mut(),
inv_std.col_mut(0).iter_mut(),
draw_var.col_ref(0).iter(),
grad_var.col_ref(0).iter(),
)
.for_each(|(var_out, inv_std_out, &draw_var, &grad_var)| {
let val = (draw_var / grad_var).sqrt();
if (!val.is_finite()) | (val == 0f64) {
if let Some(fill_val) = fill_invalid {
*var_out = fill_val;
*inv_std_out = fill_val.recip().sqrt();
self.arch.dispatch(|| {
izip!(
variance_out.col_mut(0).iter_mut(),
inv_std.col_mut(0).iter_mut(),
draw_var.col_ref(0).iter(),
grad_var.col_ref(0).iter(),
)
.for_each(|(var_out, inv_std_out, &draw_var, &grad_var)| {
let val = (draw_var / grad_var).sqrt();
if (!val.is_finite()) | (val == 0f64) {
if let Some(fill_val) = fill_invalid {
*var_out = fill_val;
*inv_std_out = fill_val.recip().sqrt();
}
} else {
let val = val.clamp(clamp.0, clamp.1);
*var_out = val;
*inv_std_out = val.recip().sqrt();
}
} else {
let val = val.clamp(clamp.0, clamp.1);
*var_out = val;
*inv_std_out = val.recip().sqrt();
}
});
});
}

Expand All @@ -202,16 +208,18 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
fill_invalid: f64,
clamp: (f64, f64),
) {
izip!(
variance_out.col_mut(0).iter_mut(),
inv_std.col_mut(0).iter_mut(),
gradient.col_ref(0).iter(),
)
.for_each(|(var_out, inv_std_out, &grad_var)| {
let val = grad_var.abs().clamp(clamp.0, clamp.1).recip();
let val = if val.is_finite() { val } else { fill_invalid };
*var_out = val;
*inv_std_out = val.recip().sqrt();
self.arch.dispatch(|| {
izip!(
variance_out.col_mut(0).iter_mut(),
inv_std.col_mut(0).iter_mut(),
gradient.col_ref(0).iter(),
)
.for_each(|(var_out, inv_std_out, &grad_var)| {
let val = grad_var.abs().clamp(clamp.0, clamp.1).recip();
let val = if val.is_finite() { val } else { fill_invalid };
*var_out = val;
*inv_std_out = val.recip().sqrt();
});
});
}
}
Expand Down
20 changes: 0 additions & 20 deletions src/mass_matrix.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
use itertools::izip;
use multiversion::multiversion;

use crate::{
math_base::Math,
nuts::Collector,
Expand Down Expand Up @@ -71,23 +68,6 @@ impl<M: Math> DiagMassMatrix<M> {
}
}

#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))]
fn update_diag(
variance_out: &mut [f64],
inv_std_out: &mut [f64],
new_variance: impl Iterator<Item = Option<f64>>,
) {
izip!(variance_out, inv_std_out, new_variance).for_each(|(var, inv_std, x)| {
if let Some(x) = x {
assert!(x.is_finite(), "Illegal value on mass matrix: {}", x);
assert!(x > 0f64, "Illegal value on mass matrix: {}", x);
//assert!(*var != x, "No change in mass matrix from {} to {}", *var, x);
*var = x;
*inv_std = (1. / x).sqrt();
};
});
}

impl<M: Math> MassMatrix<M> for DiagMassMatrix<M> {
fn update_velocity(&self, math: &mut M, state: &mut InnerState<M>) {
math.array_mult(&self.variance, &state.p, &mut state.v);
Expand Down
2 changes: 1 addition & 1 deletion src/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ mod tests {
let mut y = y.clone();
axpy(&x[..], &mut y[..], a);
for ((&x, y), out) in x.iter().zip(orig).zip(y) {
assert_approx_eq(out, a * x + y);
assert_approx_eq(out, a.mul_add(x, y));
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/nuts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,7 @@ mod tests {
fn to_arrow() {
let ndim = 10;
let func = NormalLogp::new(ndim, 3.);
let mut math = CpuMath::new(func);
let math = CpuMath::new(func);

let settings = SamplerArgs::default();
let mut rng = thread_rng();
Expand Down
9 changes: 3 additions & 6 deletions src/state.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use std::{
cell::RefCell,
fmt::Debug,
marker::PhantomData,
ops::{Deref, DerefMut},
ops::Deref,
rc::{Rc, Weak},
};

Expand Down Expand Up @@ -54,7 +53,6 @@ pub(crate) struct InnerState<M: Math> {
pub(crate) idx_in_trajectory: i64,
pub(crate) kinetic_energy: f64,
pub(crate) potential_energy: f64,
_phantom_todo: PhantomData<M>,
}

pub(crate) struct InnerStateReusable<M: Math> {
Expand All @@ -74,7 +72,6 @@ impl<'pool, M: Math> InnerStateReusable<M> {
idx_in_trajectory: 0,
kinetic_energy: 0.,
potential_energy: 0.,
_phantom_todo: PhantomData::default(),
},
reuser: Rc::downgrade(&Rc::clone(&owner.storage)),
}
Expand Down Expand Up @@ -225,7 +222,7 @@ mod tests {
fn crate_pool() {
let logp = NormalLogp::new(10, 0.2);
let mut math = CpuMath::new(logp);
let mut pool = StatePool::new(&mut math, 10);
let pool = StatePool::new(&mut math, 10);
let mut state = pool.new_state(&mut math);
assert!(state.p.nrows() == 10);
assert!(state.p.ncols() == 1);
Expand All @@ -241,7 +238,7 @@ mod tests {
let dim = 10;
let logp = NormalLogp::new(dim, 0.2);
let mut math = CpuMath::new(logp);
let mut pool = StatePool::new(&mut math, 10);
let pool = StatePool::new(&mut math, 10);
let a = pool.new_state(&mut math);

assert_eq!(a.idx_in_trajectory, 0);
Expand Down

0 comments on commit 2565230

Please sign in to comment.