Skip to content

Commit

Permalink
genericize simd uniform int
Browse files Browse the repository at this point in the history
remove some debug stuff

remove bernoulli
  • Loading branch information
TheIronBorn committed Jul 9, 2022
1 parent e614fd7 commit acd5020
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 80 deletions.
19 changes: 0 additions & 19 deletions src/distributions/bernoulli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,25 +154,6 @@ impl Distribution<bool> for Bernoulli {
}
}

/// Requires nightly Rust and the [`simd_support`] feature
///
/// [`simd_support`]: https://github.com/rust-random/rand#crate-features
#[cfg(feature = "simd_support")]
impl<const LANES: usize> Distribution<Mask<i64, LANES>> for Bernoulli
where
LaneCount<LANES>: SupportedLaneCount,
Standard: Distribution<Simd<u64, LANES>>,
{
// TODO: revisit for https://github.com/rust-random/rand/issues/1227
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Mask<i64, LANES> {
if self.p_int == ALWAYS_TRUE {
return Mask::splat(true);
}
rng.gen().lanes_lt(Simd::splat(self.p_int))
}
}

#[cfg(test)]
mod test {
use super::Bernoulli;
Expand Down
2 changes: 1 addition & 1 deletion src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ mod float;
mod integer;
mod other;
mod slice;
pub mod utils;
mod utils;
#[cfg(feature = "alloc")]
mod weighted_index;

Expand Down
95 changes: 36 additions & 59 deletions src/distributions/uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,10 @@ use core::time::Duration;
use core::ops::{Range, RangeInclusive};

use crate::distributions::float::IntoFloat;
use crate::distributions::utils::{BoolAsSIMD, IntAsSIMD, FloatAsSIMD, FloatSIMDUtils, WideningMultiply};
use crate::distributions::utils::{BoolAsSIMD, FloatAsSIMD, FloatSIMDUtils, IntAsSIMD, WideningMultiply};
use crate::distributions::Distribution;
#[cfg(feature = "simd_support")]
use crate::distributions::Standard;
use crate::{Rng, RngCore};

#[cfg(not(feature = "std"))]
Expand Down Expand Up @@ -571,21 +573,30 @@ uniform_int_impl! { u128, u128, u128 }

#[cfg(feature = "simd_support")]
macro_rules! uniform_simd_int_impl {
($ty:ident, $unsigned:ident, $u_scalar:ident) => {
($ty:ident, $unsigned:ident) => {
// The "pick the largest zone that can fit in an `u32`" optimization
// is less useful here. Multiple lanes complicate things, we don't
// know the PRNG's minimal output size, and casting to a larger vector
// is generally a bad idea for SIMD performance. The user can still
// implement it manually.

// TODO: look into `Uniform::<u32x4>::new(0u32, 100)` functionality
// perhaps `impl SampleUniform for $u_scalar`?
impl SampleUniform for $ty {
type Sampler = UniformInt<$ty>;
impl<const LANES: usize> SampleUniform for Simd<$ty, LANES>
where
LaneCount<LANES>: SupportedLaneCount,
Simd<$unsigned, LANES>:
WideningMultiply<Output = (Simd<$unsigned, LANES>, Simd<$unsigned, LANES>)>,
Standard: Distribution<Simd<$unsigned, LANES>>,
{
type Sampler = UniformInt<Simd<$ty, LANES>>;
}

impl UniformSampler for UniformInt<$ty> {
type X = $ty;
impl<const LANES: usize> UniformSampler for UniformInt<Simd<$ty, LANES>>
where
LaneCount<LANES>: SupportedLaneCount,
Simd<$unsigned, LANES>:
WideningMultiply<Output = (Simd<$unsigned, LANES>, Simd<$unsigned, LANES>)>,
Standard: Distribution<Simd<$unsigned, LANES>>,
{
type X = Simd<$ty, LANES>;

#[inline] // if the range is constant, this helps LLVM to do the
// calculations at compile-time.
Expand All @@ -609,13 +620,13 @@ macro_rules! uniform_simd_int_impl {
let high = *high_b.borrow();
assert!(low.lanes_le(high).all(),
"Uniform::new_inclusive called with `low > high`");
let unsigned_max = Simd::splat(::core::$u_scalar::MAX);
let unsigned_max = Simd::splat(::core::$unsigned::MAX);

// NOTE: these may need to be replaced with explicitly
// wrapping operations if `packed_simd` changes
let range: $unsigned = ((high - low) + Simd::splat(1)).cast();
// NOTE: all `Simd` operations are inherently wrapping,
// see https://doc.rust-lang.org/std/simd/struct.Simd.html
let range: Simd<$unsigned, LANES> = ((high - low) + Simd::splat(1)).cast();
// `% 0` will panic at runtime.
let not_full_range = range.lanes_gt($unsigned::splat(0));
let not_full_range = range.lanes_gt(Simd::splat(0));
// replacing 0 with `unsigned_max` allows a faster `select`
// with bitwise OR
let modulo = not_full_range.select(range, unsigned_max);
Expand All @@ -634,8 +645,8 @@ macro_rules! uniform_simd_int_impl {
}

fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
let range: $unsigned = self.range.cast();
let zone: $unsigned = self.z.cast();
let range: Simd<$unsigned, LANES> = self.range.cast();
let zone: Simd<$unsigned, LANES> = self.z.cast();

// This might seem very slow, generating a whole new
// SIMD vector for every sample rejection. For most uses
Expand All @@ -646,19 +657,19 @@ macro_rules! uniform_simd_int_impl {
// rejection. The replacement method does however add a little
// overhead. Benchmarking or calculating probabilities might
// reveal contexts where this replacement method is slower.
let mut v: $unsigned = rng.gen();
let mut v: Simd<$unsigned, LANES> = rng.gen();
loop {
let (hi, lo) = v.wmul(range);
let mask = lo.lanes_le(zone);
if mask.all() {
let hi: $ty = hi.cast();
let hi: Simd<$ty, LANES> = hi.cast();
// wrapping addition
let result = self.low + hi;
// `select` here compiles to a blend operation
// When `range.eq(0).none()` the compare and blend
// operations are avoided.
let v: $ty = v.cast();
return range.lanes_gt($unsigned::splat(0)).select(result, v);
let v: Simd<$ty, LANES> = v.cast();
return range.lanes_gt(Simd::splat(0)).select(result, v);
}
// Replace only the failing lanes
v = mask.select(v, rng.gen());
Expand All @@ -668,50 +679,16 @@ macro_rules! uniform_simd_int_impl {
};

// bulk implementation
($(($unsigned:ident, $signed:ident),)+ $u_scalar:ident) => {
($(($unsigned:ident, $signed:ident)),+) => {
$(
uniform_simd_int_impl!($unsigned, $unsigned, $u_scalar);
uniform_simd_int_impl!($signed, $unsigned, $u_scalar);
uniform_simd_int_impl!($unsigned, $unsigned);
uniform_simd_int_impl!($signed, $unsigned);
)+
};
}

#[cfg(feature = "simd_support")]
uniform_simd_int_impl! {
(u64x2, i64x2),
(u64x4, i64x4),
(u64x8, i64x8),
u64
}

#[cfg(feature = "simd_support")]
uniform_simd_int_impl! {
(u32x2, i32x2),
(u32x4, i32x4),
(u32x8, i32x8),
(u32x16, i32x16),
u32
}

#[cfg(feature = "simd_support")]
uniform_simd_int_impl! {
(u16x2, i16x2),
(u16x4, i16x4),
(u16x8, i16x8),
(u16x16, i16x16),
(u16x32, i16x32),
u16
}

#[cfg(feature = "simd_support")]
uniform_simd_int_impl! {
(u8x4, i8x4),
(u8x8, i8x8),
(u8x16, i8x16),
(u8x32, i8x32),
(u8x64, i8x64),
u8
}
uniform_simd_int_impl! { (u8, i8), (u16, i16), (u32, i32), (u64, i64) }

impl SampleUniform for char {
type Sampler = UniformChar;
Expand Down Expand Up @@ -1183,7 +1160,7 @@ mod tests {
_ => panic!("`UniformDurationMode` was not serialized/deserialized correctly")
}
}

#[test]
#[cfg(feature = "serde1")]
fn test_uniform_serialization() {
Expand Down
1 change: 0 additions & 1 deletion src/distributions/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

//! Math helper functions

#[cfg(feature = "simd_support")] use core::mem;
#[cfg(feature = "simd_support")] use core::simd::*;


Expand Down

0 comments on commit acd5020

Please sign in to comment.